提交 8ca78e39 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Adds an option to control evaluation in training loop.

PiperOrigin-RevId: 313708781
上级 8ea058b9
......@@ -343,6 +343,7 @@ class DistributedExecutor(object):
SummaryWriter] = SummaryWriter,
init_checkpoint: Callable[[tf.keras.Model], Any] = None,
custom_callbacks: List[tf.keras.callbacks.Callback] = None,
continuous_eval: bool = False,
save_config: bool = True):
"""Runs distributed training.
......@@ -362,8 +363,10 @@ class DistributedExecutor(object):
custom_callbacks: A list of Keras Callbacks objects to run during
training. More specifically, `on_batch_begin()`, `on_batch_end()`,
methods are invoked during training.
continuous_eval: If `True`, will continously run evaluation on every
available checkpoints. If `False`, will do the evaluation once after the
final step.
save_config: bool. Whether to save params to model_dir.
Returns:
The training loss and eval metrics.
"""
......@@ -414,6 +417,7 @@ class DistributedExecutor(object):
# input pipeline ops in worker task.
train_iterator = self._get_input_iterator(train_input_fn, strategy)
train_loss = None
train_metric_result = None
eval_metric_result = None
tf.keras.backend.set_learning_phase(1)
with strategy.scope():
......@@ -530,7 +534,7 @@ class DistributedExecutor(object):
checkpoint_name.format(step=current_step))
last_save_checkpoint_step = current_step
if test_step:
if continuous_eval and current_step < total_steps and test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
......@@ -562,7 +566,7 @@ class DistributedExecutor(object):
self.train_summary_writer.close()
self.eval_summary_writer.close()
return train_loss, eval_metric_result
return train_metric_result, eval_metric_result
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册