From 54fb237bbd5a3167c78a206986496f652de98f23 Mon Sep 17 00:00:00 2001 From: Yeqing Li Date: Tue, 22 Jun 2021 11:47:59 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 380858205 --- official/core/train_lib.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/official/core/train_lib.py b/official/core/train_lib.py index 9f703678e..450637043 100644 --- a/official/core/train_lib.py +++ b/official/core/train_lib.py @@ -40,7 +40,8 @@ def run_experiment( model_dir: str, run_post_eval: bool = False, save_summary: bool = True, - trainer: Optional[base_trainer.Trainer] = None + trainer: Optional[base_trainer.Trainer] = None, + controller_cls=orbit.Controller ) -> Tuple[tf.keras.Model, Mapping[str, Any]]: """Runs train/eval configured by the experiment params. @@ -56,6 +57,8 @@ def run_experiment( save_summary: Whether to save train and validation summary. trainer: the base_trainer.Trainer instance. It should be created within the strategy.scope(). + controller_cls: The controller class to manage the train and eval process. + Must be a orbit.Controller subclass. Returns: A 2-tuple of (model, eval_logs). @@ -87,7 +90,7 @@ def run_experiment( else: checkpoint_manager = None - controller = orbit.Controller( + controller = controller_cls( strategy=distribution_strategy, trainer=trainer if 'train' in mode else None, evaluator=trainer, -- GitLab