提交 54fb237b 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 380858205
上级 02ff7788
......@@ -40,7 +40,8 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
......@@ -56,6 +57,8 @@ def run_experiment(
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
......@@ -87,7 +90,7 @@ def run_experiment(
else:
checkpoint_manager = None
controller = orbit.Controller(
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册