提交 c7734283 编写于 作者: A Abdullah Rashwan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 446234160
上级 6d458bcc
......@@ -15,7 +15,7 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional, Tuple, List
# Import libraries
......@@ -40,6 +40,8 @@ def run_experiment(
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
......@@ -55,6 +57,8 @@ def run_experiment(
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
......@@ -90,6 +94,13 @@ def run_experiment(
else:
checkpoint_manager = None
train_actions = [] if not train_actions else train_actions
train_actions += actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
eval_actions += actions.get_eval_actions(params, trainer, model_dir)
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
......@@ -103,9 +114,8 @@ def run_experiment(
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None,
train_actions=actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
eval_actions=actions.get_eval_actions(params, trainer, model_dir))
train_actions=train_actions,
eval_actions=eval_actions)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册