diff --git a/official/core/train_lib.py b/official/core/train_lib.py index 9f703678e02b75bf4ab49dd2d52cc36a5f0b4271..450637043bd4b11fa0bcaa4ba83fb602d0d66bb9 100644 --- a/official/core/train_lib.py +++ b/official/core/train_lib.py @@ -40,7 +40,8 @@ def run_experiment( model_dir: str, run_post_eval: bool = False, save_summary: bool = True, - trainer: Optional[base_trainer.Trainer] = None + trainer: Optional[base_trainer.Trainer] = None, + controller_cls=orbit.Controller ) -> Tuple[tf.keras.Model, Mapping[str, Any]]: """Runs train/eval configured by the experiment params. @@ -56,6 +57,8 @@ def run_experiment( save_summary: Whether to save train and validation summary. trainer: the base_trainer.Trainer instance. It should be created within the strategy.scope(). + controller_cls: The controller class to manage the train and eval process. + Must be an orbit.Controller subclass. Returns: A 2-tuple of (model, eval_logs). @@ -87,7 +90,7 @@ def run_experiment( else: checkpoint_manager = None - controller = orbit.Controller( + controller = controller_cls( strategy=distribution_strategy, trainer=trainer if 'train' in mode else None, evaluator=trainer,