查看原文
其他

mlr3实战:决策树和xgboost预测房价

阿越就是我 医学和生信笔记 2023-02-25
关注公众号,发送R语言,获取学习资料!
点击上方 关注我们


前面用10篇推文详细介绍了mlr3包的基础使用及进阶方法。

今天学习用一个简单的例子说明mlr3的实战用法。

预测King Country地区的房价,将学习使用mlr3及其生态进行数据预处理、建模、重抽样、超参数调优等内容。用到了决策树以及xgboost

加载数据和R包

library(mlr3verse)
## 载入需要的程辑包:mlr3
set.seed(123# 设置种子数,数据可重复
lgr::get_logger("mlr3")$set_threshold("warn"# 减少屏幕日志
lgr::get_logger("bbotk")$set_threshold("warn")

data("kc_housing", package = "mlr3data"# 加载数据

数据探索

str(kc_housing)

## 'data.frame': 21613 obs. of  20 variables:
##  $ date         : POSIXct, format: "2014-10-13" "2014-12-09" ...
##  $ price        : num  221900 538000 180000 604000 510000 ...
##  $ bedrooms     : int  3 3 2 4 3 4 3 3 3 3 ...
##  $ bathrooms    : num  1 2.25 1 3 2 4.5 2.25 1.5 1 2.5 ...
##  $ sqft_living  : int  1180 2570 770 1960 1680 5420 1715 1060 1780 1890 ...
##  $ sqft_lot     : int  5650 7242 10000 5000 8080 101930 6819 9711 7470 6560 ...
##  $ floors       : num  1 2 1 1 1 1 2 1 1 2 ...
##  $ waterfront   : logi  FALSE FALSE FALSE FALSE FALSE FALSE ...
##  $ view         : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ condition    : int  3 3 3 5 3 3 3 3 3 3 ...
##  $ grade        : int  7 7 6 7 8 11 7 7 7 7 ...
##  $ sqft_above   : int  1180 2170 770 1050 1680 3890 1715 1060 1050 1890 ...
##  $ sqft_basement: int  NA 400 NA 910 NA 1530 NA NA 730 NA ...
##  $ yr_built     : int  1955 1951 1933 1965 1987 2001 1995 1963 1960 2003 ...
##  $ yr_renovated : int  NA 1991 NA NA NA NA NA NA NA NA ...
##  $ zipcode      : int  98178 98125 98028 98136 98074 98053 98003 98198 98146 98038 ...
##  $ lat          : num  47.5 47.7 47.7 47.5 47.6 ...
##  $ long         : num  -122 -122 -122 -122 -122 ...
##  $ sqft_living15: int  1340 1690 2720 1360 1800 4760 2238 1650 1780 2390 ...
##  $ sqft_lot15   : int  5650 7639 8062 5000 7503 101930 6819 9711 8113 7570 ...
##  - attr(*, "index")= int(0)
dim(kc_housing) # 21613,20

## [1] 21613    20
summary(kc_housing)

##       date                         price            bedrooms     
##  Min.   :2014-05-02 00:00:00   Min.   :  75000   Min.   : 0.000  
##  1st Qu.:2014-07-22 00:00:00   1st Qu.: 321950   1st Qu.: 3.000  
##  Median :2014-10-16 00:00:00   Median : 450000   Median : 3.000  
##  Mean   :2014-10-29 03:58:09   Mean   : 540088   Mean   : 3.371  
##  3rd Qu.:2015-02-17 00:00:00   3rd Qu.: 645000   3rd Qu.: 4.000  
##  Max.   :2015-05-27 00:00:00   Max.   :7700000   Max.   :33.000  
##                                                                  
##    bathrooms      sqft_living       sqft_lot           floors     
##  Min.   :0.000   Min.   :  290   Min.   :    520   Min.   :1.000  
##  1st Qu.:1.750   1st Qu.: 1427   1st Qu.:   5040   1st Qu.:1.000  
##  Median :2.250   Median : 1910   Median :   7618   Median :1.500  
##  Mean   :2.115   Mean   : 2080   Mean   :  15107   Mean   :1.494  
##  3rd Qu.:2.500   3rd Qu.: 2550   3rd Qu.:  10688   3rd Qu.:2.000  
##  Max.   :8.000   Max.   :13540   Max.   :1651359   Max.   :3.500  
##                                                                   
##  waterfront           view          condition         grade       
##  Mode :logical   Min.   :0.0000   Min.   :1.000   Min.   : 1.000  
##  FALSE:21450     1st Qu.:0.0000   1st Qu.:3.000   1st Qu.: 7.000  
##  TRUE :163       Median :0.0000   Median :3.000   Median : 7.000  
##                  Mean   :0.2343   Mean   :3.409   Mean   : 7.657  
##                  3rd Qu.:0.0000   3rd Qu.:4.000   3rd Qu.: 8.000  
##                  Max.   :4.0000   Max.   :5.000   Max.   :13.000  
##                                                                   
##    sqft_above   sqft_basement       yr_built     yr_renovated      zipcode     
##  Min.   : 290   Min.   :  10.0   Min.   :1900   Min.   :1934    Min.   :98001  
##  1st Qu.:1190   1st Qu.: 450.0   1st Qu.:1951   1st Qu.:1987    1st Qu.:98033  
##  Median :1560   Median : 700.0   Median :1975   Median :2000    Median :98065  
##  Mean   :1788   Mean   : 742.4   Mean   :1971   Mean   :1996    Mean   :98078  
##  3rd Qu.:2210   3rd Qu.: 980.0   3rd Qu.:1997   3rd Qu.:2007    3rd Qu.:98118  
##  Max.   :9410   Max.   :4820.0   Max.   :2015   Max.   :2015    Max.   :98199  
##                 NA's   :13126                   NA's   :20699                  
##       lat             long        sqft_living15    sqft_lot15    
##  Min.   :47.16   Min.   :-122.5   Min.   : 399   Min.   :   651  
##  1st Qu.:47.47   1st Qu.:-122.3   1st Qu.:1490   1st Qu.:  5100  
##  Median :47.57   Median :-122.2   Median :1840   Median :  7620  
##  Mean   :47.56   Mean   :-122.2   Mean   :1987   Mean   : 12768  
##  3rd Qu.:47.68   3rd Qu.:-122.1   3rd Qu.:2360   3rd Qu.: 10083  
##  Max.   :47.78   Max.   :-121.3   Max.   :6210   Max.   :871200  
## 

数据预处理

price是结果变量(target),其余是预测变量(feature)。

首先要把日期型变量date变为数值型,然后以最早的日期为标准变成数值,以天为单位。

把邮政编码变为因子型。

增加新列renovates,记录房子是否翻修过。

增加新列has_basement,记录有无地下室情况。

把房子价格单位从1

删除有缺失值的行。

library(anytime)
dates <- anytime(kc_housing$date)
kc_housing$date <- as.numeric(difftime(dates, min(dates), units = "days"))

kc_housing$renovated <- as.numeric(!is.na(kc_housing$yr_renovated))
kc_housing$has_basement <- as.numeric(!is.na(kc_housing$sqft_basement))
kc_housing$yr_renovated <- NULL
kc_housing$sqft_basement <- NULL

kc_housing$price <- kc_housing$price / 1000

简单画图看一下:

library(ggplot2)

ggplot(kc_housing, aes(x = price)) + geom_density() + theme_minimal()
plot of chunk unnamed-chunk-4

创建任务

task <- as_task_regr(kc_housing, target = "price")
task
## <TaskRegr:kc_housing> (21613 x 20)
## * Target: price
## * Properties: -
## * Features (19):
##   - int (11): bedrooms, condition, grade, sqft_above, sqft_living,
##     sqft_living15, sqft_lot, sqft_lot15, view, yr_built, zipcode
##   - dbl (7): bathrooms, date, floors, has_basement, lat, long,
##     renovated
##   - lgl (1): waterfront
autoplot(task)+facet_wrap(~ condition)
plot of chunk unnamed-chunk-6
# 变量间关系
autoplot(task$clone()$select(task$feature_names[c(3,17)]),type = "pairs")
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
plot of chunk unnamed-chunk-7

划分数据开始建模

split <- partition(task, ratio = 0.7)
train_idx <- split$train
test_idx <- split$test

task_train <- task$clone()$filter(train_idx)
task_test <- task$clone()$filter(test_idx)

决策树

#  先不用zipcode这一列
task_nozip <- task_train$clone()$select(setdiff(task$feature_names, "zipcode"))

# 建模
lrn <- lrn("regr.rpart")
lrn$train(task_nozip, row_ids = train_idx)

# 可视化决策树
library(rpart.plot)
## 载入需要的程辑包:rpart
rpart.plot(lrn$model)
plot of chunk unnamed-chunk-9

可以看到决策树在grade/sqft_living/lat等水平上进行了分支,下面画一个地图,看看经纬度对价格的影响。

library(ggmap)
## Google's Terms of Service: https://cloud.google.com/maps-platform/terms/.

## Please cite ggmap if you use it! See citation("ggmap") for details.
qmplot(long, lat, maptype = "watercolor", color = log(price),
  data = kc_housing[train_idx[1:3000],]) +
  scale_colour_viridis_c()
000012

很明显还是靠近水边的房子价格更贵!经纬度对房价影响也是有一点的。

下面看看不同邮政区域对价格的影响。

qmplot(long, lat, maptype = "watercolor", color = zipcode,
  data = kc_housing[train_idx[1:3000],]) + guides(color = "none")
000010

看起来不同邮政区域对价格有影响的。

下面用加上邮政区域的数据进行建模,使用3折交叉验证提高模型稳定性:

lrn_rpart <- lrn("regr.rpart")
cv3 <- rsmp("cv", folds = 3)

res <- resample(task_train, lrn_rpart, cv3, store_models = T)
res$aggregate(msr("regr.rmse"))
## regr.rmse 
##  221.0799

xgboost

lrn_xgboost <- lrn("regr.xgboost")

lrn_xgboost$param_set # 查看可以设置的超参数
## <ParamSet>
##                              id    class lower upper nlevels          default
##  1:                       alpha ParamDbl     0   Inf     Inf                0
##  2:               approxcontrib ParamLgl    NA    NA       2            FALSE
##  3:                  base_score ParamDbl  -Inf   Inf     Inf              0.5
##  4:                     booster ParamFct    NA    NA       3           gbtree
##  5:                   callbacks ParamUty    NA    NA     Inf        <list[0]>
##  6:           colsample_bylevel ParamDbl     0     1     Inf                1
##  7:            colsample_bynode ParamDbl     0     1     Inf                1
##  8:            colsample_bytree ParamDbl     0     1     Inf                1
##  9: disable_default_eval_metric ParamLgl    NA    NA       2            FALSE
## 10:       early_stopping_rounds ParamInt     1   Inf     Inf                 
## 11:                         eta ParamDbl     0     1     Inf              0.3
## 12:                 eval_metric ParamUty    NA    NA     Inf             rmse
## 13:            feature_selector ParamFct    NA    NA       5           cyclic
## 14:                       feval ParamUty    NA    NA     Inf                 
## 15:                       gamma ParamDbl     0   Inf     Inf                0
## 16:                 grow_policy ParamFct    NA    NA       2        depthwise
## 17:     interaction_constraints ParamUty    NA    NA     Inf   <NoDefault[3]>
## 18:              iterationrange ParamUty    NA    NA     Inf   <NoDefault[3]>
## 19:                      lambda ParamDbl     0   Inf     Inf                1
## 20:                 lambda_bias ParamDbl     0   Inf     Inf                0
## 21:                     max_bin ParamInt     2   Inf     Inf              256
## 22:              max_delta_step ParamDbl     0   Inf     Inf                0
## 23:                   max_depth ParamInt     0   Inf     Inf                6
## 24:                  max_leaves ParamInt     0   Inf     Inf                0
## 25:                    maximize ParamLgl    NA    NA       2                 
## 26:            min_child_weight ParamDbl     0   Inf     Inf                1
## 27:                     missing ParamDbl  -Inf   Inf     Inf               NA
## 28:        monotone_constraints ParamUty    NA    NA     Inf                0
## 29:              normalize_type ParamFct    NA    NA       2             tree
## 30:                     nrounds ParamInt     1   Inf     Inf   <NoDefault[3]>
## 31:                     nthread ParamInt     1   Inf     Inf                1
## 32:                  ntreelimit ParamInt     1   Inf     Inf                 
## 33:           num_parallel_tree ParamInt     1   Inf     Inf                1
## 34:                   objective ParamUty    NA    NA     Inf reg:squarederror
## 35:                    one_drop ParamLgl    NA    NA       2            FALSE
## 36:                outputmargin ParamLgl    NA    NA       2            FALSE
## 37:                 predcontrib ParamLgl    NA    NA       2            FALSE
## 38:                   predictor ParamFct    NA    NA       2    cpu_predictor
## 39:             predinteraction ParamLgl    NA    NA       2            FALSE
## 40:                    predleaf ParamLgl    NA    NA       2            FALSE
## 41:               print_every_n ParamInt     1   Inf     Inf                1
## 42:                process_type ParamFct    NA    NA       2          default
## 43:                   rate_drop ParamDbl     0     1     Inf                0
## 44:                refresh_leaf ParamLgl    NA    NA       2             TRUE
## 45:                     reshape ParamLgl    NA    NA       2            FALSE
## 46:                 sample_type ParamFct    NA    NA       2          uniform
## 47:             sampling_method ParamFct    NA    NA       2          uniform
## 48:                   save_name ParamUty    NA    NA     Inf                 
## 49:                 save_period ParamInt     0   Inf     Inf                 
## 50:            scale_pos_weight ParamDbl  -Inf   Inf     Inf                1
## 51:          seed_per_iteration ParamLgl    NA    NA       2            FALSE
## 52:  single_precision_histogram ParamLgl    NA    NA       2            FALSE
## 53:                  sketch_eps ParamDbl     0     1     Inf             0.03
## 54:                   skip_drop ParamDbl     0     1     Inf                0
## 55:                strict_shape ParamLgl    NA    NA       2            FALSE
## 56:                   subsample ParamDbl     0     1     Inf                1
## 57:                       top_k ParamInt     0   Inf     Inf                0
## 58:                    training ParamLgl    NA    NA       2            FALSE
## 59:                 tree_method ParamFct    NA    NA       5             auto
## 60:      tweedie_variance_power ParamDbl     1     2     Inf              1.5
## 61:                     updater ParamUty    NA    NA     Inf   <NoDefault[3]>
## 62:                     verbose ParamInt     0     2       3                1
## 63:                   watchlist ParamUty    NA    NA     Inf                 
## 64:                   xgb_model ParamUty    NA    NA     Inf                 
##                              id    class lower upper nlevels          default
##                      parents value
##  1:                               
##  2:                               
##  3:                               
##  4:                               
##  5:                               
##  6:                               
##  7:                               
##  8:                               
##  9:                               
## 10:                               
## 11:                               
## 12:                               
## 13:                  booster      
## 14:                               
## 15:                               
## 16:              tree_method      
## 17:                               
## 18:                               
## 19:                               
## 20:                  booster      
## 21:              tree_method      
## 22:                               
## 23:                               
## 24:              grow_policy      
## 25:                               
## 26:                               
## 27:                               
## 28:                               
## 29:                  booster      
## 30:                              1
## 31:                              1
## 32:                               
## 33:                               
## 34:                               
## 35:                  booster      
## 36:                               
## 37:                               
## 38:                               
## 39:                               
## 40:                               
## 41:                  verbose      
## 42:                               
## 43:                  booster      
## 44:                               
## 45:                               
## 46:                  booster      
## 47:                  booster      
## 48:                               
## 49:                               
## 50:                               
## 51:                               
## 52:              tree_method      
## 53:              tree_method      
## 54:                  booster      
## 55:                               
## 56:                               
## 57: booster,feature_selector      
## 58:                               
## 59:                  booster      
## 60:                objective      
## 61:                               
## 62:                              0
## 63:                               
## 64:                               
##                      parents value
search_space <- ps(
   eta = p_dbl(lower = 0.2, upper = .4),
  min_child_weight = p_dbl(lower = 1, upper = 20),
  subsample = p_dbl(lower = .7, upper = .8),
  colsample_bytree = p_dbl( lower = .9, upper = 1),
  colsample_bylevel = p_dbl(lower = .5, upper = .7),
  nrounds = p_int(lower = 1L, upper = 25)
)

at <- auto_tuner(
  method = "random_search",
  learner = lrn_xgboost,
  resampling = rsmp("holdout"),
  measure = msr("regr.rmse"),
  search_space = search_space,
  term_evals = 10,
  batch_size = 8
)

res <- resample(task_nozip, at, cv3, store_models = T)
res$aggregate() 
## regr.mse 
## 19122.95

效果比决策树好很多!


以上就是今天的内容,希望对你有帮助哦!欢迎点赞、在看、关注、转发

欢迎在评论区留言或直接添加我的微信!


欢迎关注公众号:医学和生信笔记

医学和生信笔记 公众号主要分享:1.医学小知识、肛肠科小知识;2.R语言和Python相关的数据分析、可视化、机器学习等;3.生物信息学学习资料和自己的学习笔记!


往期回顾




宽数据变为长数据的5种情况!

2022-03-13

长数据变为宽数据的7种情况!

2022-03-14

长宽数据转换的特殊情况

2022-03-15

ggplot2添加公式标签

2022-03-16

TCGA批量差异分析并可视化

2022-03-12

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存