Creates settings for the Estimator, which takes a model and trains it.

setEstimator(
  learningRate = "auto",
  weightDecay = 0,
  batchSize = 512,
  epochs = 30,
  device = "cpu",
  optimizer = torch$optim$AdamW,
  scheduler = list(fun = torch$optim$lr_scheduler$ReduceLROnPlateau,
    params = list(patience = 1)),
  criterion = torch$nn$BCEWithLogitsLoss,
  earlyStopping = list(useEarlyStopping = TRUE, params = list(patience = 4)),
  metric = "auc",
  accumulationSteps = NULL,
  seed = NULL
)

Arguments

learningRate

The learning rate to use: a number, or `auto` (the default).

weightDecay

The weight decay to use (the optimizer's `weight_decay` value).

batchSize

The number of samples per batch.

epochs

The maximum number of epochs to train for; training can stop earlier when early stopping is enabled.

device

The device to train on, such as `cpu` (the default) or `cuda`. Can be a string or a function that evaluates to the device at runtime.

optimizer

The optimizer to use; the default is `torch$optim$AdamW`.

scheduler

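The learning rate scheduler to use, given as a list with the scheduler function `fun` and its arguments `params`. The default, `ReduceLROnPlateau` with `patience = 1`, lowers the learning rate when the metric stops improving.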

criterion

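The loss function to use. The default, `torch$nn$BCEWithLogitsLoss`, is binary cross-entropy applied to the raw model outputs (logits).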

earlyStopping

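Whether to use early stopping, which stops training when the metric is no longer improving. Given as a list with `useEarlyStopping` (`TRUE` or `FALSE`) and `params`, where `patience` sets how long to wait without improvement before stopping.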

metric

The metric used by the scheduler and by early stopping: either `auc`, `loss`, or a custom metric. A custom metric must be a list with three elements: `fun`, a function that takes predictions and labels and returns a score; `mode`, either `min` or `max`, stating whether lower or higher scores are better; and `name`.

accumulationSteps

The number of steps over which to accumulate gradients before updating the weights. Can also be a function that is evaluated at runtime.

seed

The seed used to initialize the model weights.
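For the `metric` argument, a custom metric could be sketched as below, using average precision purely as an illustration. This assumes `fun` receives predictions and labels as plain numeric vectors with labels coded 0/1; if the package passes tensors instead, they would need converting first. Average precision depends only on the ranking of the predictions, so it gives the same value whether the predictions are probabilities or logits. The names `averagePrecision` and `apMetric` are hypothetical.

```r
# Hypothetical custom metric: average precision (higher is better).
# Assumes prediction and labels are numeric vectors, labels coded 0/1.
averagePrecision <- function(prediction, labels) {
  prediction <- as.numeric(prediction)
  labels <- as.numeric(labels)
  ord <- order(prediction, decreasing = TRUE)  # rank by predicted risk; ties not handled specially
  labels <- labels[ord]
  tp <- cumsum(labels)                         # true positives at each cutoff
  precision <- tp / seq_along(labels)          # precision at each cutoff
  sum(precision[labels == 1]) / sum(labels)    # mean precision at each positive
}

# The list shape described for a custom metric: fun, mode and name.
apMetric <- list(fun = averagePrecision, mode = "max", name = "averagePrecision")

apMetric$fun(c(0.9, 0.8, 0.3, 0.1), c(1, 0, 1, 0))
```

Such a list would then be passed as `metric = apMetric` to `setEstimator()`. `mode = "max"` is chosen because a larger average precision is better.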
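The `patience` setting in `earlyStopping` can be illustrated with the common rule: stop once the metric has gone `patience` consecutive evaluations without beating its best value so far. This is a minimal sketch of that generic rule, not the Estimator's own code; `stopEpoch` is a hypothetical helper, and the package may count or compare slightly differently.

```r
# Generic patience-based early stopping (illustrative, not the package's code).
# scores: one metric value per epoch; mode: "max" if higher is better, else "min".
# Returns the epoch at which training would stop, or NA if it never triggers.
stopEpoch <- function(scores, patience = 4, mode = "max") {
  better <- if (mode == "max") function(a, b) a > b else function(a, b) a < b
  best <- scores[1]
  waited <- 0
  for (epoch in seq_along(scores)[-1]) {
    if (better(scores[epoch], best)) {
      best <- scores[epoch]   # new best value: reset the counter
      waited <- 0
    } else {
      waited <- waited + 1    # no improvement this epoch
      if (waited >= patience) return(epoch)
    }
  }
  NA_integer_
}

stopEpoch(c(0.70, 0.72, 0.74, 0.73, 0.73, 0.72, 0.71, 0.75), patience = 4)
```

Here the best AUC-like score is reached at epoch 3, and four epochs without improvement follow, so training would stop at epoch 7 even though epoch 8 would have improved.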
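For `accumulationSteps`, the idea behind gradient accumulation can be checked numerically: averaging the gradients of several equal-sized sub-batches gives the same update as a single gradient over the whole batch, while fewer samples are processed at once. The sketch below uses a plain least-squares model in base R rather than a torch model; `mseGrad` is a hypothetical helper, and it does not describe how the Estimator itself splits batches.

```r
set.seed(1)
X <- cbind(1, matrix(rnorm(16 * 2), ncol = 2))  # 16 samples: intercept + 2 features
y <- rnorm(16)
w <- c(0.1, -0.2, 0.3)

# Gradient of 0.5 * mean((X %*% w - y)^2) with respect to w.
mseGrad <- function(X, y, w) drop(crossprod(X, X %*% w - y)) / nrow(X)

fullGrad <- mseGrad(X, y, w)

# Accumulate over 4 equal sub-batches of 4 samples, scaling each by 1/4.
accumulationSteps <- 4
idx <- split(seq_len(nrow(X)),
             rep(seq_len(accumulationSteps), each = nrow(X) / accumulationSteps))
accGrad <- numeric(length(w))
for (rows in idx) {
  accGrad <- accGrad + mseGrad(X[rows, , drop = FALSE], y[rows], w) / accumulationSteps
}

# One weight update with the accumulated gradient equals the full-batch update.
lr <- 0.01
wAcc <- w - lr * accGrad
wFull <- w - lr * fullGrad
all.equal(wAcc, wFull)
```

In a torch training loop the same effect is usually obtained by dividing each sub-batch loss by the number of accumulation steps before calling backward, and stepping the optimizer only after the last sub-batch.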