R/training_run.R

tuning_run

Tune hyperparameters using training flags

Description

Run all combinations of the specified training flags. To reduce the number of runs, set the sample parameter, which causes only a random sample of the flag combinations to be run.
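As a rough illustration of what "all combinations" and "a random sample" mean here, the sketch below builds the full grid with base::expand.grid() and then keeps a random fraction of its rows. It uses base R only; the name sample_rate and the rounding rule (ceiling) are assumptions made for illustration, not a statement of how tuning_run() is implemented.

```r
# Flag values in the same style as the Examples section
flags <- list(
  dropout1 = c(0.2, 0.3, 0.4),
  dropout2 = c(0.2, 0.3, 0.4)
)

# Full grid: one row per flag combination (3 x 3 = 9 rows)
grid <- expand.grid(flags, stringsAsFactors = FALSE)

# Assumed sampling rule: keep a random fraction of the rows, rounded up
sample_rate <- 0.5   # hypothetical value, standing in for the sample argument
set.seed(42)         # only to make this sketch reproducible
keep <- sample.int(nrow(grid), size = ceiling(sample_rate * nrow(grid)))
sampled <- grid[keep, , drop = FALSE]

nrow(grid)     # 9 combinations in total
nrow(sampled)  # 5 combinations kept
```

Each row of the sampled grid corresponds to one training run, so halving the sampling rate roughly halves the number of runs.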

Usage

 
tuning_run( 
  file = "train.R", 
  context = "local", 
  config = Sys.getenv("R_CONFIG_ACTIVE", unset = "default"), 
  flags = NULL, 
  sample = NULL, 
  properties = NULL, 
  runs_dir = getOption("tfruns.runs_dir", "runs"), 
  artifacts_dir = getwd(), 
  echo = TRUE, 
  confirm = interactive(), 
  envir = parent.frame(), 
  encoding = getOption("encoding") 
) 

Arguments

Arguments Description
file Path to training script (defaults to “train.R”)
context Run context (defaults to “local”)
config The configuration to use. Defaults to the active configuration for the current environment (as specified by the R_CONFIG_ACTIVE environment variable), or default when unset.
flags Either a named list with flag values (multiple values can be provided for each flag) or a data frame that contains pre-generated combinations of flags (e.g. via base::expand.grid()). The latter can be useful for subsetting combinations. See ‘Examples’.
sample Sampling rate for flag combinations (defaults to running all combinations).
properties Named character vector with run properties. Properties are additional metadata about the run which will be subsequently available via ls_runs().
runs_dir Directory containing runs. Defaults to “runs” beneath the current working directory (or to the value of the tfruns.runs_dir R option if specified).
artifacts_dir Directory in which created and modified files are captured. Pass NULL to capture no artifacts.
echo Print expressions within training script
confirm Confirm before executing tuning run.
envir The environment in which the script should be evaluated
encoding The encoding of the training script; see file().
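The defaults for config and runs_dir shown under Usage are ordinary base R lookups with a fallback. The sketch below evaluates the same two expressions with and without the environment variable and option set; the values "production" and "experiments" are hypothetical and chosen only for illustration.

```r
# Remember the current state so the sketch leaves the session unchanged
old_env <- Sys.getenv("R_CONFIG_ACTIVE", unset = NA)
old_opt <- getOption("tfruns.runs_dir")

# With neither the variable nor the option set, the fallbacks apply
Sys.unsetenv("R_CONFIG_ACTIVE")
options(tfruns.runs_dir = NULL)
config_default <- Sys.getenv("R_CONFIG_ACTIVE", unset = "default")
runs_dir_default <- getOption("tfruns.runs_dir", "runs")

# Once set, they take precedence over the fallbacks
Sys.setenv(R_CONFIG_ACTIVE = "production")  # hypothetical configuration name
options(tfruns.runs_dir = "experiments")    # hypothetical directory name
config_set <- Sys.getenv("R_CONFIG_ACTIVE", unset = "default")
runs_dir_set <- getOption("tfruns.runs_dir", "runs")

# Restore the previous state
if (is.na(old_env)) {
  Sys.unsetenv("R_CONFIG_ACTIVE")
} else {
  Sys.setenv(R_CONFIG_ACTIVE = old_env)
}
options(tfruns.runs_dir = old_opt)
```

Passing config or runs_dir explicitly to tuning_run() bypasses both lookups.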

Value

Data frame with summary of all training runs performed during tuning.
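Because the result is a plain data frame, it can be ranked and subset with base R. The sketch below uses a small stand-in data frame rather than real output: the column names (run_dir, flag_dropout1, flag_dropout2, eval_acc) are assumptions modeled on the Examples section, and the values are made up purely for illustration.

```r
# Stand-in for the data frame returned by tuning_run().
# Column names and values are hypothetical, not real results.
runs <- data.frame(
  run_dir       = c("runs/a", "runs/b", "runs/c"),
  flag_dropout1 = c(0.2, 0.3, 0.4),
  flag_dropout2 = c(0.2, 0.2, 0.3),
  eval_acc      = c(0.971, 0.978, 0.965),
  stringsAsFactors = FALSE
)

# Rank runs from best to worst on the evaluation metric
ranked <- runs[order(runs$eval_acc, decreasing = TRUE), ]

# Pull out the flag values of the best run
best <- ranked[1, ]
best_flags <- best[, grep("^flag_", names(best)), drop = FALSE]
best_flags
```

The same pattern applies to any metric column present in the summary.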

Examples

library(tfruns) 
 
# using a list as input to the flags argument 
runs <- tuning_run( 
  system.file("examples/mnist_mlp/mnist_mlp.R", package = "tfruns"), 
  flags = list( 
    dropout1 = c(0.2, 0.3, 0.4), 
    dropout2 = c(0.2, 0.3, 0.4) 
  ) 
) 
runs[order(runs$eval_acc, decreasing = TRUE), ] 
 
# using a data frame as input to the flags argument: 
# start from the same combinations as above, then drop those 
# where the combined dropout rate exceeds 1 
# (with these values the maximum is 0.8, so every combination is kept) 
grid <- expand.grid( 
  dropout1 = c(0.2, 0.3, 0.4), 
  dropout2 = c(0.2, 0.3, 0.4) 
) 
grid$combined_dropout <- grid$dropout1 + grid$dropout2 
grid <- grid[grid$combined_dropout <= 1, ] 
runs <- tuning_run( 
  system.file("examples/mnist_mlp/mnist_mlp.R", package = "tfruns"), 
  flags = grid[, c("dropout1", "dropout2")] 
)