R/learning_rate_schedules.R

learning_rate_schedule_exponential_decay

A LearningRateSchedule that uses an exponential decay schedule

Description

A LearningRateSchedule that uses an exponential decay schedule

Usage

 
learning_rate_schedule_exponential_decay( 
  initial_learning_rate, 
  decay_steps, 
  decay_rate, 
  staircase = FALSE, 
  ..., 
  name = NULL 
) 

Arguments

Arguments Description
initial_learning_rate A scalar float32 or float64 Tensor or an R number. The initial learning rate.
decay_steps A scalar int32 or int64 Tensor or an R number. Must be positive. See the decay computation in Details.
decay_rate A scalar float32 or float64 Tensor or an R number. The decay rate.
staircase Boolean. If TRUE decay the learning rate at discrete intervals.
... For backwards and forwards compatibility
name String. Optional name of the operation. Defaults to "ExponentialDecay".

Details

When training a model, it is often useful to lower the learning rate as training progresses. This schedule applies an exponential decay function to an optimizer step, given an initial learning rate. The schedule is a 1-arg callable that produces a decayed learning rate when passed the current optimizer step. This can be useful for changing the learning rate value across different invocations of optimizer functions. It is computed as:

```
decayed_learning_rate <- function(step) {
  initial_learning_rate * decay_rate ^ (step / decay_steps)
}
```

If the argument `staircase` is `TRUE`, then `step / decay_steps` is an integer division (`%/%`) and the decayed learning rate follows a staircase function.

You can pass this schedule directly into an optimizer as the learning rate, as in the example that follows.

Example: when fitting a Keras model, decay every 100000 steps with a base of 0.96:

```
initial_learning_rate <- 0.1

# Build the schedule: multiply the rate by 0.96 once per 100000 steps
lr_schedule <- learning_rate_schedule_exponential_decay(
  initial_learning_rate,
  decay_steps = 100000,
  decay_rate = 0.96,
  staircase = TRUE)

# Pass the schedule to the optimizer in place of a fixed learning rate
model %>% compile(
  optimizer = optimizer_sgd(learning_rate = lr_schedule),
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy")

model %>% fit(data, labels, epochs = 5)
```
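
To see what the decay formula and the `staircase` option do to concrete step values, the computation can be reproduced in plain R. This is only an illustrative sketch of the arithmetic given in Details, using the same values as the example (0.1, 100000, 0.96). `exp_decay_lr` is a hypothetical helper written for this illustration; it is not part of the package and does not create a LearningRateSchedule object.

```r
# Hypothetical helper: plain-R version of the formula from Details.
# It only mirrors the arithmetic; it is not the package's implementation.
exp_decay_lr <- function(step, initial_learning_rate, decay_steps,
                         decay_rate, staircase = FALSE) {
  # With staircase = TRUE the exponent is an integer division (%/%),
  # so the rate changes only once every `decay_steps` steps.
  exponent <- if (staircase) step %/% decay_steps else step / decay_steps
  initial_learning_rate * decay_rate ^ exponent
}

steps <- c(0, 50000, 100000, 150000, 200000)

# Continuous decay: the rate shrinks a little at every step.
smooth <- exp_decay_lr(steps, 0.1, 100000, 0.96)

# Staircase decay: the rate is constant within each block of 100000 steps.
stepped <- exp_decay_lr(steps, 0.1, 100000, 0.96, staircase = TRUE)

print(data.frame(step = steps, smooth = smooth, stepped = stepped))
```

With these inputs the staircase variant stays at 0.1 for steps 0 and 50000, drops to 0.096 at step 100000, and reaches 0.09216 at step 200000. The continuous variant agrees with it at exact multiples of `decay_steps` and lies between the staircase levels everywhere else.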

See Also