-
Notifications
You must be signed in to change notification settings - Fork 111
/
selections.R
489 lines (443 loc) · 14.5 KB
/
selections.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
#' @name selections
#' @aliases selections
#' @aliases selection
#'
#' @title Methods for selecting variables in step functions
#'
#' @description
#'
#' Tips for selecting columns in step functions.
#'
#' @details
#' When selecting variables or model terms in `step`
#' functions, `dplyr`-like tools are used. The *selector* functions
#' can choose variables based on their name, current role, data
#' type, or any combination of these. The selectors are passed as
#' any other argument to the step. If the variables are explicitly
#' named in the step function, this might look like:
#'
#' \preformatted{
#' recipe( ~ ., data = USArrests) \%>\%
#' step_pca(Murder, Assault, UrbanPop, Rape, num_comp = 3)
#' }
#'
#' The first four arguments indicate which variables should be
#' used in the PCA while the last argument is a specific argument
#' to [step_pca()] about the number of components.
#'
#' Note that:
#'
#' \enumerate{
#' \item These arguments are not evaluated until the `prep`
#' function for the step is executed.
#' \item The `dplyr`-like syntax allows for negative signs to
#' exclude variables (e.g. `-Murder`) and the set of selectors will
#' be processed in order.
#' \item A leading exclusion in these arguments (e.g. `-Murder`)
#' has the effect of adding *all* variables to the list except the
#' excluded variable(s), ignoring role information.
#' }
#'
#' Select helpers from the `tidyselect` package can also be used:
#' [tidyselect::starts_with()], [tidyselect::ends_with()],
#' [tidyselect::contains()], [tidyselect::matches()],
#' [tidyselect::num_range()], [tidyselect::everything()],
#' [tidyselect::one_of()], [tidyselect::all_of()], and
#' [tidyselect::any_of()]
#'
#' Note that [tidyselect::everything()] and the other `tidyselect` functions
#' aren't restricted to predictors. They will thus select outcomes,
#' ID, and predictor columns alike. This is why these functions should be used
#' with care, and why [tidyselect::everything()] likely isn't what you need.
#'
#' For example:
#'
#' \preformatted{
#' recipe(Species ~ ., data = iris) \%>\%
#' step_center(starts_with("Sepal"), -contains("Width"))
#' }
#'
#' would only select `Sepal.Length`
#'
#' Columns of the design matrix that may not exist when the step
#' is coded can also be selected. For example, when using
#' `step_pca()`, the number of columns created by feature extraction
#' may not be known when subsequent steps are defined. In this
#' case, using `matches("^PC")` will select all of the columns
#' whose names start with "PC" *once those columns are created*.
#'
#' There are sets of recipes-specific functions that can be used to select
#' variables based on their role or type: [has_role()] and
#' [has_type()]. For convenience, there are also functions that are
#' more specific. The functions [all_numeric()] and [all_nominal()] select
#' based on type, with nominal variables including both character and factor;
#' the functions [all_predictors()] and [all_outcomes()] select based on role.
#' The functions [all_numeric_predictors()] and [all_nominal_predictors()]
#' select intersections of role and type. Any can be used in conjunction with
#' the previous functions described for selecting variables using their names.
#'
#' A selection like this:
#'
#' \preformatted{
#' data(biomass)
#' recipe(HHV ~ ., data = biomass) \%>\%
#' step_center(all_numeric(), -all_outcomes())
#' }
#'
#' is equivalent to:
#'
#' \preformatted{
#' data(biomass)
#' recipe(HHV ~ ., data = biomass) \%>\%
#' step_center(all_numeric_predictors())
#' }
#'
#' Both result in all the numeric predictors: carbon, hydrogen,
#' oxygen, nitrogen, and sulfur.
#'
#' If a role for a variable has not been defined, it will never be
#' selected using role-specific selectors.
#'
#' ## Interactions
#'
#' Selectors can be used in [step_interact()] in similar ways but
#' must be embedded in a model formula (as opposed to a sequence
#' of selectors). For example, the interaction specification
#' could be `~ starts_with("Species"):Sepal.Width`. This can be
#' useful if `Species` was converted to dummy variables
#' previously using [step_dummy()]. The implementation of
#' `step_interact()` is special, and is more restricted than
#' the other step functions. Only the selector functions from
#' recipes and tidyselect are allowed. User defined selector functions
#' will not be recognized. Additionally, the tidyselect domain specific
#' language is not recognized here, meaning that `&`, `|`, `!`, and `-`
#' will not work.
#'
#' @includeRmd man/rmd/selections.Rmd details
NULL
# ------------------------------------------------------------------------------
#' Evaluate a selection with tidyselect semantics specific to recipes
#'
#' @description
#' `recipes_eval_select()` is a recipes specific variant of
#' [tidyselect::eval_select()] enhanced with the ability to recognize recipes
#' selectors, such as [all_numeric_predictors()]. See [selections]
#' for more information about the unique recipes selectors.
#'
#' This is a developer tool that is only useful for creating new recipes steps.
#'
#' @inheritParams rlang::args_dots_empty
#'
#' @param quos A list of quosures describing the selection. This is generally
#' the `...` argument of your step function, captured with [rlang::enquos()]
#' and stored in the step object as the `terms` element.
#'
#' @param data A data frame to use as the context to evaluate the selection in.
#' This is generally the `training` data passed to the [prep()] method
#' of your step.
#'
#' @param info A data frame of term information describing each column's type
#' and role for use with the recipes selectors. This is generally the `info`
#' data passed to the [prep()] method of your step.
#'
#' @param allow_rename Should the renaming syntax `c(foo = bar)` be allowed?
#' This is rarely required, and is currently only used by [step_select()].
#' It is unlikely that your step will need renaming capabilities.
#'
#' @param check_case_weights Should selecting case weights throw an error?
#' Defaults to `TRUE`. This is rarely changed and only needed in [juice()],
#' [bake.recipe()], [update_role()], and [add_role()].
#'
#' @param call The execution environment of a currently running function, e.g.
#' `caller_env()`. The function will be mentioned in error messages as the
#' source of the error. See the call argument of [rlang::abort()] for more
#' information.
#'
#' @return
#' A named character vector containing the evaluated selection. The names are
#' always the same as the values, except when `allow_rename = TRUE`, in which
#' case the names reflect the new names chosen by the user.
#'
#' @seealso [developer_functions]
#'
#' @export
#' @examplesIf rlang::is_installed("modeldata")
#' library(rlang)
#' data(scat, package = "modeldata")
#'
#' rec <- recipe(Species ~ ., data = scat)
#'
#' info <- summary(rec)
#' info
#'
#' quos <- quos(all_numeric_predictors(), where(is.factor))
#'
#' recipes_eval_select(quos, scat, info)
recipes_eval_select <- function(quos, data, info, ..., allow_rename = FALSE,
                                check_case_weights = TRUE,
                                call = caller_env()) {
  check_dots_empty()

  if (rlang::is_missing(quos)) {
    cli::cli_abort("Argument {.arg quos} is missing, with no default.")
  }

  # Put the rows of `info` in the same order as the columns of `data`, so the
  # positions returned by `eval_select()` and by the recipes selectors agree.
  # A column of `data` with no row in `info` is an error.
  col_names <- names(data)
  located <- vctrs::vec_locate_matches(
    col_names,
    info$variable,
    no_match = "error"
  )
  ordered_info <- vec_slice(info, located$haystack)

  # Collapse everything except `variable` into one tibble per distinct
  # variable, giving a two-column (`variable`, `data`) nested tibble.
  keep_cols <- names(ordered_info) != "variable"
  details <- ordered_info[keep_cols]
  details <- tibble::new_tibble(details, nrow = vctrs::vec_size(details))
  groups <- vctrs::vec_split(details, by = ordered_info$variable)
  nested_info <- tibble::new_tibble(
    list(variable = groups$key, data = groups$val),
    nrow = length(groups$key)
  )

  # Expose the nested info to the selectors for the duration of this call.
  # This has to be called from this frame: the bindings are scoped to the
  # caller of `local_current_info()`.
  local_current_info(nested_info)

  selection <- expr(c(!!!quos))

  # Named selections are only meaningful when renaming is permitted.
  arg_names <- names(selection)
  if (!allow_rename && any(arg_names != "")) {
    offenders <- arg_names[arg_names != ""]
    cli::cli_abort(
      "The following argument{?s} {?was/were} specified but do{?es/} not exist: \\
{.arg {offenders}}.",
      call = call
    )
  }

  positions <- tidyselect::eval_select(
    expr = selection,
    data = data,
    allow_rename = allow_rename,
    error_call = call
  )

  # Hand back column names rather than positions: the names are reused on
  # both the training and the test set, where positions may differ.
  selected <- col_names[positions]

  if (check_case_weights &&
    any(selected %in% info$variable[info$role == "case_weights"])) {
    cli::cli_abort("Cannot select case weights variable.", call = call)
  }

  # With renaming enabled these carry the user's new names; otherwise they
  # match the values.
  names(selected) <- names(positions)
  selected
}
#' Role Selection
#'
#' @description
#'
#' `has_role()`, `all_predictors()`, and `all_outcomes()` can be used to
#' select variables in a formula that have certain roles.
#'
#' **In most cases**, the right approach for users will be to use the
#' predictor-specific selectors such as `all_numeric_predictors()` and
#' `all_nominal_predictors()`. In general you should be careful about using
#' `-all_outcomes()` if a `*_predictors()` selector would do what you want.
#'
#' Similarly, `has_type()`, `all_numeric()`, `all_integer()`, `all_double()`,
#' `all_nominal()`, `all_ordered()`, `all_unordered()`, `all_factor()`,
#' `all_string()`, `all_date()` and `all_datetime()` are used to select columns
#' based on their data type.
#'
#' `all_factor()` captures ordered and unordered factors, `all_string()`
#' captures characters, `all_unordered()` captures unordered factors and
#' characters, `all_ordered()` captures ordered factors, `all_nominal()`
#' captures characters, unordered and ordered factors.
#'
#' `all_integer()` captures integers, `all_double()` captures doubles,
#' `all_numeric()` captures all kinds of numeric.
#'
#' `all_date()` captures [Date()] variables, `all_datetime()` captures
#' [POSIXct()] variables.
#'
#' See [selections] for more details.
#'
#' `current_info()` is an internal function.
#'
#' All of these functions have limited utility outside of column selection
#' in step functions.
#'
#' @param match A single character string for the query. Exact
#' matching is used (i.e. regular expressions won't work).
#'
#' @return
#'
#' Selector functions return an integer vector.
#'
#' `current_info()` returns an environment with objects `vars` and `data`.
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(biomass, package = "modeldata")
#'
#' rec <- recipe(biomass) %>%
#' update_role(
#' carbon, hydrogen, oxygen, nitrogen, sulfur,
#' new_role = "predictor"
#' ) %>%
#' update_role(HHV, new_role = "outcome") %>%
#' update_role(sample, new_role = "id variable") %>%
#' update_role(dataset, new_role = "splitting indicator")
#'
#' recipe_info <- summary(rec)
#' recipe_info
#'
#' # Centering on all predictors except carbon
#' rec %>%
#' step_center(all_predictors(), -carbon) %>%
#' prep(training = biomass) %>%
#' bake(new_data = NULL)
#' @export
has_role <- function(match = "predictor") {
  # One entry per variable in the current selection context. An entry may
  # itself be a list (a variable can carry several roles), so it is
  # flattened before the exact comparison against `match`.
  is_hit <- purrr::map_lgl(
    peek_roles(),
    function(var_roles) any(unlist(var_roles) %in% match)
  )
  # Positions of the variables with at least one matching role.
  which(is_hit)
}
#' @export
#' @rdname has_role
has_type <- function(match = "numeric") {
  # One entry per variable in the current selection context; a variable is
  # kept when any of its recorded types equals an element of `match`.
  is_hit <- purrr::map_lgl(
    peek_types(),
    function(var_types) any(var_types %in% match)
  )
  # Positions of the variables with at least one matching type.
  which(is_hit)
}
peek_roles <- function() {
  # Roles of every variable in the current selection context.
  peek_info(col = "role")
}
peek_types <- function() {
  # Types of every variable in the current selection context.
  peek_info(col = "type")
}
peek_info <- function(col) {
  # `data` in the context environment holds one info object per variable;
  # pull column `col` out of each one and flatten it.
  per_variable <- current_info()$data
  purrr::map(per_variable, function(var_info) unlist(var_info[[col]]))
}
#' @export
#' @rdname has_role
all_outcomes <- function() {
  # Positions of the variables whose role is "outcome".
  has_role(match = "outcome")
}
#' @export
#' @rdname has_role
all_predictors <- function() {
  # Positions of the variables whose role is "predictor".
  has_role(match = "predictor")
}
#' @export
#' @rdname has_role
all_date <- function() {
  # Positions of the variables whose type is "date".
  has_type(match = "date")
}
#' @export
#' @rdname has_role
all_date_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "date" (type).
  predictor_pos <- has_role("predictor")
  date_pos <- has_type("date")
  intersect(predictor_pos, date_pos)
}
#' @export
#' @rdname has_role
all_datetime <- function() {
  # Positions of the variables whose type is "datetime".
  has_type(match = "datetime")
}
#' @export
#' @rdname has_role
all_datetime_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "datetime" (type).
  predictor_pos <- has_role("predictor")
  datetime_pos <- has_type("datetime")
  intersect(predictor_pos, datetime_pos)
}
#' @export
#' @rdname has_role
all_double <- function() {
  # Positions of the variables whose type is "double".
  has_type(match = "double")
}
#' @export
#' @rdname has_role
all_double_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "double" (type).
  predictor_pos <- has_role("predictor")
  double_pos <- has_type("double")
  intersect(predictor_pos, double_pos)
}
#' @export
#' @rdname has_role
all_factor <- function() {
  # Positions of the variables whose type is "factor".
  has_type(match = "factor")
}
#' @export
#' @rdname has_role
all_factor_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "factor" (type).
  predictor_pos <- has_role("predictor")
  factor_pos <- has_type("factor")
  intersect(predictor_pos, factor_pos)
}
#' @export
#' @rdname has_role
all_integer <- function() {
  # Positions of the variables whose type is "integer".
  has_type(match = "integer")
}
#' @export
#' @rdname has_role
all_integer_predictors <- function() {
  # Variables that are both a "predictor" (role) and an "integer" (type).
  predictor_pos <- has_role("predictor")
  integer_pos <- has_type("integer")
  intersect(predictor_pos, integer_pos)
}
#' @export
#' @rdname has_role
all_logical <- function() {
  # Positions of the variables whose type is "logical".
  has_type(match = "logical")
}
#' @export
#' @rdname has_role
all_logical_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "logical" (type).
  predictor_pos <- has_role("predictor")
  logical_pos <- has_type("logical")
  intersect(predictor_pos, logical_pos)
}
#' @export
#' @rdname has_role
all_nominal <- function() {
  # Positions of the variables whose type is "nominal".
  has_type(match = "nominal")
}
#' @export
#' @rdname has_role
all_nominal_predictors <- function() {
  # Variables that are both a "predictor" (role) and "nominal" (type).
  predictor_pos <- has_role("predictor")
  nominal_pos <- has_type("nominal")
  intersect(predictor_pos, nominal_pos)
}
#' @export
#' @rdname has_role
all_numeric <- function() {
  # Positions of the variables whose type is "numeric".
  has_type(match = "numeric")
}
#' @export
#' @rdname has_role
all_numeric_predictors <- function() {
  # Variables that are both a "predictor" (role) and "numeric" (type).
  predictor_pos <- has_role("predictor")
  numeric_pos <- has_type("numeric")
  intersect(predictor_pos, numeric_pos)
}
#' @export
#' @rdname has_role
all_ordered <- function() {
  # Positions of the variables whose type is "ordered".
  has_type(match = "ordered")
}
#' @export
#' @rdname has_role
all_ordered_predictors <- function() {
  # Variables that are both a "predictor" (role) and "ordered" (type).
  predictor_pos <- has_role("predictor")
  ordered_pos <- has_type("ordered")
  intersect(predictor_pos, ordered_pos)
}
#' @export
#' @rdname has_role
all_string <- function() {
  # Positions of the variables whose type is "string".
  has_type(match = "string")
}
#' @export
#' @rdname has_role
all_string_predictors <- function() {
  # Variables that are both a "predictor" (role) and a "string" (type).
  predictor_pos <- has_role("predictor")
  string_pos <- has_type("string")
  intersect(predictor_pos, string_pos)
}
#' @export
#' @rdname has_role
all_unordered <- function() {
  # Positions of the variables whose type is "unordered".
  has_type(match = "unordered")
}
#' @export
#' @rdname has_role
all_unordered_predictors <- function() {
  # Variables that are both a "predictor" (role) and "unordered" (type).
  predictor_pos <- has_role("predictor")
  unordered_pos <- has_type("unordered")
  intersect(predictor_pos, unordered_pos)
}
## Functions to get the current variable info for the selectors, modeled after
## the dplyr versions.
##
## `cur_info_env` is the shared environment the selectors read from. It is
## created once, from the empty environment, and starts with no bindings;
## `local_current_info()` temporarily binds `vars` and `data` in it, and
## `current_info()` returns it.
cur_info_env <- env(empty_env())
local_current_info <- function(nested_info, frame = parent.frame()) {
  # Bind the variable names and their per-variable info in the shared
  # context environment. `local_bindings()` scopes the bindings to `frame`,
  # which defaults to the caller of this function.
  variable_names <- nested_info$variable
  variable_info <- nested_info$data

  local_bindings(
    vars = variable_names,
    data = variable_info,
    .env = cur_info_env,
    .frame = frame
  )
}
#' @export
#' @rdname has_role
current_info <- function() {
  # Return the shared context environment. The abort only fires when
  # `cur_info_env` itself is NULL; an environment without `vars`/`data`
  # bindings is returned as is.
  if (is.null(cur_info_env)) {
    cli::cli_abort("Variable context not set.")
  }
  cur_info_env
}