提交 5741cef6 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Support multiple metrics in CTL.

PiperOrigin-RevId: 306751755
上级 3c227a73
......@@ -30,9 +30,9 @@ import tensorflow as tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.utils import hyperparams_flags
FLAGS = flags.FLAGS
......@@ -59,6 +59,45 @@ def _no_metric():
return None
def metrics_as_dict(metric):
  """Normalizes input metric(s) into a dict keyed by metric name.

  Note: despite accepting several container shapes, the result is always a
  dictionary (the previous docstring incorrectly said "list").

  Args:
    metric: metric(s) to normalize. May be a single
      `tf.keras.metrics.Metric`, a list of them, a dict mapping names to
      metric objects, a falsy value (e.g. `None`), or any other single
      metric-like object implementing the required methods.

  Returns:
    A dict mapping metric names to metric objects. Empty dict if `metric`
    is falsy.
  """
  if isinstance(metric, tf.keras.metrics.Metric):
    return {metric.name: metric}
  if isinstance(metric, list):
    # Assumes every element exposes a `.name` attribute (Keras metrics do).
    return {m.name: m for m in metric}
  if isinstance(metric, dict):
    return metric
  if not metric:
    return {}
  # Fallback: a single metric-like object with no usable `.name`; give it a
  # generic key so downstream reporting still works.
  return {'metric': metric}
def metric_results(metric):
  """Collects results from the given metric(s).

  Args:
    metric: a metric, list/dict of metrics, or falsy value, as accepted by
      `metrics_as_dict`.

  Returns:
    A dict mapping each metric name to its current result as a Python float.
  """
  results = {}
  for name, metric_obj in metrics_as_dict(metric).items():
    results[name] = metric_obj.result().numpy().astype(float)
  return results
def reset_states(metric):
  """Resets the internal state of the given metric(s).

  Args:
    metric: a metric, list/dict of metrics, or falsy value, as accepted by
      `metrics_as_dict`. Falsy input is a no-op.
  """
  for metric_obj in metrics_as_dict(metric).values():
    metric_obj.reset_states()
class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
......@@ -185,6 +224,7 @@ class DistributedExecutor(object):
loss_fn,
optimizer,
metric=None):
metrics = metrics_as_dict(metric)
def _replicated_step(inputs):
"""Replicated training step."""
......@@ -195,11 +235,8 @@ class DistributedExecutor(object):
prediction_loss = loss_fn(labels, outputs)
loss = tf.reduce_mean(prediction_loss)
loss = loss / strategy.num_replicas_in_sync
if isinstance(metric, tf.keras.metrics.Metric):
metric.update_state(labels, outputs)
else:
logging.error('train metric is not an instance of '
'tf.keras.metrics.Metric.')
for m in metrics.values():
m.update_state(labels, outputs)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
......@@ -235,6 +272,7 @@ class DistributedExecutor(object):
Args:
iterator: an iterator that yields input tensors.
num_steps: the number of steps in the loop.
Returns:
The loss tensor.
......@@ -259,6 +297,7 @@ class DistributedExecutor(object):
def _create_test_step(self, strategy, model, metric):
"""Creates a distributed test step."""
metrics = metrics_as_dict(metric)
@tf.function
def test_step(iterator):
......@@ -266,22 +305,20 @@ class DistributedExecutor(object):
if not metric:
logging.info('Skip test_step because metric is None (%s)', metric)
return None, None
if not isinstance(metric, tf.keras.metrics.Metric):
raise ValueError(
'Metric must be an instance of tf.keras.metrics.Metric '
'for running in test_step. Actual {}'.format(metric))
def _test_step_fn(inputs):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
for m in metrics.values():
m.update_state(labels, model_outputs)
return labels, model_outputs
return strategy.run(_test_step_fn, args=(next(iterator),))
return test_step
def train(self,
train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
eval_input_fn: Callable[[params_dict.ParamsDict],
......@@ -422,8 +459,9 @@ class DistributedExecutor(object):
test_step = self._create_test_step(strategy, model, metric=eval_metric)
# Step-0 operations
_save_checkpoint(
checkpoint, model_dir, checkpoint_name.format(step=current_step))
if current_step == 0 and not latest_checkpoint_file:
_save_checkpoint(
checkpoint, model_dir, checkpoint_name.format(step=current_step))
if test_step:
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(
......@@ -432,7 +470,7 @@ class DistributedExecutor(object):
'Step: %s evalation metric = %s.', current_step, eval_metric_result)
test_summary_writer(
metrics=eval_metric_result, step=optimizer.iterations)
eval_metric.reset_states()
reset_states(eval_metric)
logging.info('Training started')
last_save_checkpoint_step = current_step
......@@ -454,12 +492,7 @@ class DistributedExecutor(object):
raise ValueError('total loss is NaN.')
if train_metric:
train_metric_result = train_metric.result()
if isinstance(train_metric, tf.keras.metrics.Metric):
train_metric_result = tf.nest.map_structure(
lambda x: x.numpy().astype(float), train_metric_result)
if not isinstance(train_metric_result, dict):
train_metric_result = {'metric': train_metric_result}
train_metric_result = metric_results(train_metric)
train_metric_result.update(train_loss)
else:
train_metric_result = train_loss
......@@ -496,9 +529,9 @@ class DistributedExecutor(object):
# Re-initialize evaluation metric, except the last step.
if eval_metric and current_step < total_steps:
eval_metric.reset_states()
reset_states(eval_metric)
if train_metric and current_step < total_steps:
train_metric.reset_states()
reset_states(train_metric)
# Reaches the end of training and saves the last checkpoint.
if last_save_checkpoint_step < total_steps:
......@@ -534,9 +567,7 @@ class DistributedExecutor(object):
except (StopIteration, tf.errors.OutOfRangeError):
break
metric_result = metric.result()
if isinstance(metric, tf.keras.metrics.Metric):
metric_result = metric_result.numpy().astype(float)
metric_result = metric_results(metric)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric_result)
return metric_result
......@@ -653,7 +684,7 @@ class DistributedExecutor(object):
logging.info('Step: %s evalation metric = %s.', current_step,
eval_metric_result)
summary_writer(metrics=eval_metric_result, step=current_step)
eval_metric.reset_states()
reset_states(eval_metric)
return eval_metric_result, current_step
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册