Generate Predictions with an Estimator

Generate predicted labels / values for input data provided by input_fn().

# S3 method for tf_estimator
predict(object, input_fn, checkpoint_path = NULL,
  predict_keys = c("predictions", "classes", "class_ids", "logistic",
  "logits", "probabilities"), hooks = NULL, as_iterable = FALSE,
  simplify = TRUE, yield_single_examples = TRUE, ...)
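A hedged sketch of typical usage follows; the trained estimator `model`, the `mtcars` data frame, and the feature/response names are illustrative assumptions, not part of this reference page:

```r
library(tfestimators)

# Assumed: `model` is a tf_estimator (e.g. a linear regressor) that was
# previously trained on mtcars with the features below.
pred_input <- input_fn(mtcars, features = c(drat, cyl), response = mpg)

# With simplify = TRUE (the default), predictions come back as a tibble.
preds <- predict(model, input_fn = pred_input)
```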



Arguments

object
A TensorFlow estimator.


input_fn
An input function, typically generated by the input_fn() helper function.


checkpoint_path
The path to a specific model checkpoint to be used for prediction. If NULL (the default), the latest checkpoint in model_dir is used.


predict_keys
The types of predictions that should be produced, as an R list of keys. When this argument is not specified (the default), all possible predicted values are returned.
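For example, to request only one kind of prediction from a classifier (the `model` and `pred_input` objects here are hypothetical stand-ins for a trained estimator and its prediction input function):

```r
# Sketch: restrict the output to class predictions only
class_preds <- predict(model, input_fn = pred_input, predict_keys = "classes")
```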


hooks
A list of R functions, to be used as callbacks inside the prediction loop. By default, hook_history_saver(every_n_step = 10) and hook_progress_bar() will be attached if not provided, to save the metrics history and create the progress bar.


as_iterable
Boolean; should a raw Python generator be returned? When FALSE (the default), the predicted values will be consumed from the generator and returned as an R object.


simplify
Whether to simplify prediction results into a tibble, as opposed to a list. Defaults to TRUE.
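When the tibble representation is not wanted (for instance, when predictions contain tensors of unequal shape), the raw list can be kept instead. A sketch, again assuming hypothetical `model` and `pred_input` objects:

```r
# Sketch: keep the raw list structure instead of a tibble
raw_preds <- predict(model, input_fn = pred_input, simplify = FALSE)
str(raw_preds, max.level = 1)
```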


yield_single_examples
(Available since TensorFlow v1.7) If FALSE, yields the whole batch as returned by the model_fn instead of decomposing the batch into individual elements. This is useful if model_fn returns some tensors whose first dimension is not equal to the batch size.


...
Optional arguments passed on to the estimator's predict() method.


Value

Evaluated values of the predictions tensors.


Raises

ValueError: if a trained model could not be found in model_dir.

ValueError: if the batch lengths of predictions are not the same.

ValueError: if there is a conflict between predict_keys and predictions; for example, if predict_keys is not NULL but EstimatorSpec.predictions is not a dict.

See also