提交 54fb237b 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 380858205
上级 02ff7788
...@@ -40,7 +40,8 @@ def run_experiment( ...@@ -40,7 +40,8 @@ def run_experiment(
model_dir: str, model_dir: str,
run_post_eval: bool = False, run_post_eval: bool = False,
save_summary: bool = True, save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]: ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
...@@ -56,6 +57,8 @@ def run_experiment( ...@@ -56,6 +57,8 @@ def run_experiment(
save_summary: Whether to save train and validation summary. save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope(). strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns: Returns:
A 2-tuple of (model, eval_logs). A 2-tuple of (model, eval_logs).
...@@ -87,7 +90,7 @@ def run_experiment( ...@@ -87,7 +90,7 @@ def run_experiment(
else: else:
checkpoint_manager = None checkpoint_manager = None
controller = orbit.Controller( controller = controller_cls(
strategy=distribution_strategy, strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None, trainer=trainer if 'train' in mode else None,
evaluator=trainer, evaluator=trainer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册