提交 acea25b9 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 295849975
上级 f3600cd1
......@@ -64,7 +64,7 @@ class SummaryWriter(object):
"""Simple SummaryWriter for writing dictionary of metrics.
Attributes:
_writer: The tf.SummaryWriter.
writer: The tf.SummaryWriter.
"""
def __init__(self, model_dir: Text, name: Text):
......@@ -74,7 +74,7 @@ class SummaryWriter(object):
model_dir: the model folder path.
name: the summary subfolder name.
"""
self._writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
"""Write metrics to summary with the given writer.
......@@ -88,10 +88,10 @@ class SummaryWriter(object):
logging.warning('Warning: summary writer prefer metrics as dictionary.')
metrics = {'metric': metrics}
with self._writer.as_default():
with self.writer.as_default():
for k, v in metrics.items():
tf.summary.scalar(k, v, step=step)
self._writer.flush()
self.writer.flush()
class DistributedExecutor(object):
......@@ -122,6 +122,9 @@ class DistributedExecutor(object):
self._strategy = strategy
self._checkpoint_name = 'ctl_step_{step}.ckpt'
self._is_multi_host = is_multi_host
self.train_summary_writer = None
self.eval_summary_writer = None
self.global_train_step = None
@property
def checkpoint_name(self):
......@@ -395,7 +398,10 @@ class DistributedExecutor(object):
eval_metric = eval_metric_fn()
train_metric = train_metric_fn()
train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
self.train_summary_writer = train_summary_writer.writer
test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
self.eval_summary_writer = test_summary_writer.writer
# Continue training loop.
train_step = self._create_train_step(
......@@ -406,6 +412,7 @@ class DistributedExecutor(object):
metric=train_metric)
test_step = None
if eval_input_fn and eval_metric:
self.global_train_step = model.optimizer.iterations
test_step = self._create_test_step(strategy, model, metric=eval_metric)
logging.info('Training started')
......@@ -549,6 +556,7 @@ class DistributedExecutor(object):
return True
summary_writer = summary_writer_fn(model_dir, 'eval')
self.eval_summary_writer = summary_writer.writer
# Read checkpoints from the given model directory
# until `eval_timeout` seconds elapses.
......@@ -615,6 +623,7 @@ class DistributedExecutor(object):
'checkpoint', checkpoint_path)
checkpoint.restore(checkpoint_path)
self.global_train_step = model.optimizer.iterations
eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
eval_metric_result = self._run_evaluation(test_step, current_step,
eval_metric, eval_iterator)
......
......@@ -70,6 +70,9 @@ RETINANET_CFG = {
'val_json_file': '',
'eval_file_pattern': '',
'input_sharding': True,
# When visualizing images, set evaluation batch size to 40 to avoid
# potential OOM.
'num_images_to_visualize': 0,
},
'predict': {
'predict_batch_size': 8,
......
......@@ -25,6 +25,7 @@ import os
import json
import tensorflow.compat.v2 as tf
from official.modeling.training import distributed_executor as executor
from official.vision.detection.utils import box_utils
class DetectionDistributedExecutor(executor.DistributedExecutor):
......@@ -38,13 +39,19 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
trainable_variables_filter=None,
**kwargs):
super(DetectionDistributedExecutor, self).__init__(**kwargs)
params = kwargs['params']
if predict_post_process_fn:
assert callable(predict_post_process_fn)
if trainable_variables_filter:
assert callable(trainable_variables_filter)
self._predict_post_process_fn = predict_post_process_fn
self._trainable_variables_filter = trainable_variables_filter
self.eval_steps = tf.Variable(
0,
trainable=False,
dtype=tf.int32,
synchronization=tf.VariableSynchronization.ON_READ,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=[])
def _create_replicated_step(self,
strategy,
......@@ -90,24 +97,41 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
"""Creates a distributed test step."""
@tf.function
def test_step(iterator):
def test_step(iterator, eval_steps):
"""Calculates evaluation metrics on distributed devices."""
def _test_step_fn(inputs):
def _test_step_fn(inputs, eval_steps):
"""Replicated accuracy calculation."""
inputs, labels = inputs
model_outputs = model(inputs, training=False)
if self._predict_post_process_fn:
labels, prediction_outputs = self._predict_post_process_fn(
labels, model_outputs)
num_remaining_visualizations = (
self._params.eval.num_images_to_visualize - eval_steps)
# If there are remaining number of visualizations that needs to be
# done, add next batch outputs for visualization.
#
# TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
# write correct slice of outputs to summary file.
if num_remaining_visualizations > 0:
box_utils.visualize_bounding_boxes(
inputs, prediction_outputs['detection_boxes'],
self.global_train_step, self.eval_summary_writer)
return labels, prediction_outputs
labels, outputs = strategy.experimental_run_v2(
_test_step_fn, args=(next(iterator),))
_test_step_fn, args=(
next(iterator),
eval_steps,
))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results,
labels)
eval_steps.assign_add(self._params.eval.batch_size)
return labels, outputs
return test_step
......@@ -115,6 +139,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
def _run_evaluation(self, test_step, current_training_step, metric,
test_iterator):
"""Runs validation steps and aggregate metrics."""
self.eval_steps.assign(0)
if not test_iterator or not metric:
logging.warning(
'Both test_iterator (%s) and metrics (%s) must not be None.',
......@@ -123,7 +148,7 @@ class DetectionDistributedExecutor(executor.DistributedExecutor):
logging.info('Running evaluation after step: %s.', current_training_step)
while True:
try:
labels, outputs = test_step(test_iterator)
labels, outputs = test_step(test_iterator, self.eval_steps)
if metric:
metric.update_state(labels, outputs)
except (StopIteration, tf.errors.OutOfRangeError):
......
......@@ -239,4 +239,5 @@ def main(argv):
if __name__ == '__main__':
assert tf.version.VERSION.startswith('2.')
tf.config.set_soft_device_placement(True)
app.run(main)
......@@ -26,6 +26,22 @@ EPSILON = 1e-8
BBOX_XFORM_CLIP = np.log(1000. / 16.)
def visualize_images_with_bounding_boxes(images, box_outputs, step,
summary_writer):
"""Records subset of evaluation images with bounding boxes."""
image_shape = tf.shape(images[0])
image_height = tf.cast(image_shape[0], tf.float32)
image_width = tf.cast(image_shape[1], tf.float32)
normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])
bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
bounding_box_color)
with summary_writer.as_default():
tf.summary.image('bounding_box_summary', image_summary, step=step)
summary_writer.flush()
def yxyx_to_xywh(boxes):
"""Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册