| Function | Works |
|---|---|
| `tidypredict_fit()`, `tidypredict_sql()`, `parse_model()` | ✔ |
| `tidypredict_to_column()` | ✔ |
| `tidypredict_test()` | ✔ |
| `tidypredict_interval()`, `tidypredict_sql_interval()` | ✗ |
| parsnip | ✔ |
tidypredict_ functions

library(xgboost)
#' Logistic-loss objective in the (preds, dtrain) -> list(grad, hess) form
#'
#' Applies the logistic (sigmoid) transform to `preds` and returns the first
#' and second derivatives of the log loss with respect to `preds`.
#'
#' @param preds Numeric vector of predictions on the log-odds scale.
#' @param dtrain Object from which `xgboost::getinfo(dtrain, "label")` can
#'   read the observed labels.
#' @return A list with elements `grad` (probability minus label) and `hess`
#'   (probability times one minus probability).
logregobj <- function(preds, dtrain) {
  observed <- xgboost::getinfo(dtrain, "label")
  # Map the log-odds predictions onto the probability scale.
  prob <- 1 / (1 + exp(-preds))
  list(
    grad = prob - observed,
    hess = prob * (1 - prob)
  )
}
# Training matrix for the binary model: every mtcars column except `am`
# (the 9th column) is a feature, and the 0/1 column `am` is the label.
xgb_bin_data <- xgboost::xgb.DMatrix(
  data = as.matrix(mtcars[, names(mtcars) != "am"]),
  label = mtcars[["am"]]
)
model <- xgboost::xgb.train(
  params = list(max_depth = 2, objective = "binary:logistic", base_score = 0.5),
  data = xgb_bin_data, nrounds = 50
)

Create the R formula
tidypredict_fit(model)
#> 1 - 1/(1 + exp(0 + case_when(wt >= 3.18000007 ~ -0.436363667, 
#>     (qsec < 19.1849995 | is.na(qsec)) & (wt < 3.18000007 | is.na(wt)) ~ 
#>         0.428571463, qsec >= 19.1849995 & (wt < 3.18000007 | 
#>         is.na(wt)) ~ 0) + case_when((wt < 3.01250005 | is.na(wt)) ~ 
#>     0.311573088, (hp < 222.5 | is.na(hp)) & wt >= 3.01250005 ~ 
#>     -0.392053694, hp >= 222.5 & wt >= 3.01250005 ~ -0.0240745768) + 
#>     case_when((gear < 3.5 | is.na(gear)) ~ -0.355945677, (wt < 
#>         3.01250005 | is.na(wt)) & gear >= 3.5 ~ 0.325712085, 
#>         wt >= 3.01250005 & gear >= 3.5 ~ -0.0384863913) + case_when((gear < 
#>     3.5 | is.na(gear)) ~ -0.309683114, (wt < 3.01250005 | is.na(wt)) & 
#>     gear >= 3.5 ~ 0.283893973, wt >= 3.01250005 & gear >= 3.5 ~ 
#>     -0.032039877) + case_when((gear < 3.5 | is.na(gear)) ~ -0.275577009, 
#>     (wt < 3.01250005 | is.na(wt)) & gear >= 3.5 ~ 0.252453178, 
#>     wt >= 3.01250005 & gear >= 3.5 ~ -0.0266750772) + case_when((gear < 
#>     3.5 | is.na(gear)) ~ -0.248323873, (qsec < 17.6599998 | is.na(qsec)) & 
#>     gear >= 3.5 ~ 0.261978835, qsec >= 17.6599998 & gear >= 3.5 ~ 
#>     -0.00959526002) + case_when((gear < 3.5 | is.na(gear)) ~ 
#>     -0.225384533, (wt < 3.01250005 | is.na(wt)) & gear >= 3.5 ~ 
#>     0.218285918, wt >= 3.01250005 & gear >= 3.5 ~ -0.0373593047) + 
#>     case_when((gear < 3.5 | is.na(gear)) ~ -0.205454513, (qsec < 
#>         18.7550011 | is.na(qsec)) & gear >= 3.5 ~ 0.196076646, 
#>         qsec >= 18.7550011 & gear >= 3.5 ~ -0.0544253439) + case_when((wt < 
#>     3.01250005 | is.na(wt)) ~ 0.149246693, (qsec < 17.4099998 | 
#>     is.na(qsec)) & wt >= 3.01250005 ~ 0.0354709327, qsec >= 17.4099998 & 
#>     wt >= 3.01250005 ~ -0.226075932) + case_when((gear < 3.5 | 
#>     is.na(gear)) ~ -0.184417158, (wt < 3.01250005 | is.na(wt)) & 
#>     gear >= 3.5 ~ 0.176768288, wt >= 3.01250005 & gear >= 3.5 ~ 
#>     -0.0237750355) + case_when((gear < 3.5 | is.na(gear)) ~ -0.168993726, 
#>     (qsec < 18.6049995 | is.na(qsec)) & gear >= 3.5 ~ 0.155569643, 
#>     qsec >= 18.6049995 & gear >= 3.5 ~ -0.0325752236) + case_when((wt < 
#>     3.01250005 | is.na(wt)) ~ 0.119126029, wt >= 3.01250005 ~ 
#>     -0.105012275) + case_when((qsec < 17.1749992 | is.na(qsec)) ~ 
#>     0.117254697, qsec >= 17.1749992 ~ -0.0994235724) + case_when((wt < 
#>     3.18000007 | is.na(wt)) ~ 0.097100094, wt >= 3.18000007 ~ 
#>     -0.10567718) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0824323222, wt >= 3.18000007 ~ -0.091120176) + case_when((qsec < 
#>     17.5100002 | is.na(qsec)) ~ 0.0854752287, qsec >= 17.5100002 ~ 
#>     -0.0764453933) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0749477893, wt >= 3.18000007 ~ -0.0799863264) + case_when((qsec < 
#>     17.7099991 | is.na(qsec)) ~ 0.0728750378, qsec >= 17.7099991 ~ 
#>     -0.0646049976) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0682478622, wt >= 3.18000007 ~ -0.0711427554) + case_when((wt < 
#>     3.18000007 | is.na(wt)) ~ 0.0579533465, wt >= 3.18000007 ~ 
#>     -0.0613371208) + case_when((qsec < 18.1499996 | is.na(qsec)) ~ 
#>     0.0595484748, qsec >= 18.1499996 ~ -0.0546668135) + case_when((wt < 
#>     3.18000007 | is.na(wt)) ~ 0.0535288528, wt >= 3.18000007 ~ 
#>     -0.0558333211) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0454574414, wt >= 3.18000007 ~ -0.048143398) + case_when((qsec < 
#>     18.5600014 | is.na(qsec)) ~ 0.0422042683, qsec >= 18.5600014 ~ 
#>     -0.0454404354) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0420555957, wt >= 3.18000007 ~ -0.0449385941) + case_when((qsec < 
#>     18.5600014 | is.na(qsec)) ~ 0.0393446013, qsec >= 18.5600014 ~ 
#>     -0.0425945036) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0391179025, wt >= 3.18000007 ~ -0.0420661867) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.0304145869, qsec >= 18.4099998 ~ 
#>     -0.031833414) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0362136625, wt >= 3.18000007 ~ -0.038949281) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.0295153651, qsec >= 18.4099998 ~ 
#>     -0.0307046026) + case_when((drat < 3.80999994 | is.na(drat)) ~ 
#>     -0.0306891855, drat >= 3.80999994 ~ 0.0288283136) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.0271221269, qsec >= 18.4099998 ~ 
#>     -0.0281750448) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
#>     0.0228891298, qsec >= 18.4099998 ~ -0.0238814205) + case_when((drat < 
#>     3.80999994 | is.na(drat)) ~ -0.0296511576, drat >= 3.80999994 ~ 
#>     0.0280048084) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
#>     0.0214707125, qsec >= 18.4099998 ~ -0.0224219449) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.0181306079, qsec >= 18.4099998 ~ 
#>     -0.0190209728) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
#>     0.0379650332, wt >= 3.18000007 ~ -0.0395050682) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.0194106717, qsec >= 18.4099998 ~ 
#>     -0.0202215631) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
#>     0.0164139606, qsec >= 18.4099998 ~ -0.0171694476) + case_when((qsec < 
#>     18.4099998 | is.na(qsec)) ~ 0.013879573, qsec >= 18.4099998 ~ 
#>     -0.0145772658) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
#>     0.0117362784, qsec >= 18.4099998 ~ -0.0123759825) + case_when((wt < 
#>     3.18000007 | is.na(wt)) ~ 0.0388614088, wt >= 3.18000007 ~ 
#>     -0.0400568396) + log(0.5/(1 - 0.5))))

Add the prediction to the original table
library(dplyr)
mtcars %>%
  tidypredict_to_column(model) %>%
  glimpse()
#> Rows: 32
#> Columns: 12
#> $ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2, 17.8,…
#> $ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4, 4, 8,…
#> $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140.8, 16…
#> $ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 180, 180…
#> $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92, 3.92,…
#> $ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.150, 3.…
#> $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.90, 18…
#> $ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0,…
#> $ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,…
#> $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 3, 3,…
#> $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1, 1, 2,…
#> $ fit  <dbl> 0.98576418, 0.98576418, 0.92735110, 0.01081509, 0.04639094, 0.010…

Confirm that the tidypredict results match the model’s predict() results.
The xg_df argument of tidypredict_test() expects the xgb.DMatrix data set.
parsnip fitted models are also supported by
tidypredict:
Here is an example of the model spec:
pm <- parse_model(model)
str(pm, 2)
#> List of 2
#>  $ general:List of 7
#>   ..$ model        : chr "xgb.Booster"
#>   ..$ type         : chr "xgb"
#>   ..$ niter        : num 50
#>   ..$ params       :List of 4
#>   ..$ feature_names: chr [1:10] "mpg" "cyl" "disp" "hp" ...
#>   ..$ nfeatures    : int 10
#>   ..$ version      : num 1
#>  $ trees  :List of 42
#>   ..$ 0 :List of 3
#>   ..$ 1 :List of 3
#>   ..$ 2 :List of 3
#>   ..$ 3 :List of 3
#>   ..$ 4 :List of 3
#>   ..$ 5 :List of 3
#>   ..$ 6 :List of 3
#>   ..$ 7 :List of 3
#>   ..$ 8 :List of 3
#>   ..$ 9 :List of 3
#>   ..$ 10:List of 3
#>   ..$ 11:List of 2
#>   ..$ 12:List of 2
#>   ..$ 13:List of 2
#>   ..$ 14:List of 2
#>   ..$ 15:List of 2
#>   ..$ 16:List of 2
#>   ..$ 17:List of 2
#>   ..$ 18:List of 2
#>   ..$ 19:List of 2
#>   ..$ 20:List of 2
#>   ..$ 21:List of 2
#>   ..$ 22:List of 2
#>   ..$ 23:List of 2
#>   ..$ 24:List of 2
#>   ..$ 25:List of 2
#>   ..$ 26:List of 2
#>   ..$ 27:List of 2
#>   ..$ 28:List of 2
#>   ..$ 29:List of 2
#>   ..$ 30:List of 2
#>   ..$ 31:List of 2
#>   ..$ 32:List of 2
#>   ..$ 33:List of 2
#>   ..$ 34:List of 2
#>   ..$ 35:List of 2
#>   ..$ 36:List of 2
#>   ..$ 37:List of 2
#>   ..$ 38:List of 2
#>   ..$ 39:List of 2
#>   ..$ 40:List of 2
#>   ..$ 41:List of 2
#>  - attr(*, "class")= chr [1:3] "parsed_model" "pm_xgb" "list"
str(pm$trees[1])
#> List of 1
#>  $ 0:List of 3
#>   ..$ :List of 2
#>   .. ..$ prediction: num -0.436
#>   .. ..$ path      :List of 1
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0.429
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE