Commit f753edca authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix truncated `steps_per_execution` which failed in distributed training.

The problem is that for a dataset with, e.g., 14 elements and `steps_per_execution=5`, the `DataAdapter.steps` iterator does the following (see the sketch after this list):

 1. Yield `0`,
 2. Yield `5`,
 3. Set `steps_per_execution` to `4`, yield `10`,
 4. Set `steps_per_execution` back to `5`.
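
For illustration only, here is a minimal standalone sketch of that truncation logic (the `steps_iterator` helper and its signature are hypothetical, not the actual `DataAdapter` code):

```python
import tensorflow as tf

def steps_iterator(num_elements, steps_per_execution_var):
  """Yields the first step of each execution, truncating the last one."""
  configured = steps_per_execution_var.numpy().item()
  step = 0
  while step < num_elements:
    remaining = num_elements - step
    if remaining < configured:
      # Shrink the shared variable so the final execution stops at the end
      # of the data (step 3 in the list above).
      steps_per_execution_var.assign(remaining)
    yield step
    step += configured
  # Restore the configured value for the next epoch (step 4 above).
  steps_per_execution_var.assign(configured)

steps_per_execution = tf.Variable(5)
for start in steps_iterator(14, steps_per_execution):
  print(start, steps_per_execution.numpy())  # 0 5, then 5 5, then 10 4
```

In eager, synchronous execution the temporary value `4` is what actually runs, so exactly 14 steps are executed.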

In distributed training, however, the steps are only enqueued, not executed immediately. So even though `steps_per_execution` is adjusted to `4` for the final step, and still holds `4` when the task is enqueued, it is set back to `5` before the task is actually run.

As a result, 15 steps are computed instead of 14.
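
The race can be reproduced with a hypothetical `enqueue` helper standing in for `ClusterCoordinator.schedule`, which only queues the closure instead of running it:

```python
import tensorflow as tf

steps_per_execution = tf.Variable(5)
pending = []

def enqueue(fn):
  # Stand-in for ClusterCoordinator.schedule: the function is queued here
  # and executed later by a worker, not at call time.
  pending.append(fn)

# The iterator truncates the final execution to 4 steps and enqueues it ...
steps_per_execution.assign(4)
enqueue(lambda: print('executing', steps_per_execution.numpy(), 'steps'))
# ... then restores the variable before the worker gets around to the task.
steps_per_execution.assign(5)

for fn in pending:
  fn()  # prints "executing 5 steps", one execution too long
```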

This change makes the number of steps a parameter of the internal `train_function`, `test_function`, and `predict_function`, and passes in a snapshot of the value of `steps_per_execution` taken at the time the task is enqueued, i.e. between steps 3 and 4 above.
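
A sketch of the fix under the same hypothetical `enqueue` helper: the Python value is read with `numpy().item()` at enqueue time and passed as an argument, so later mutations of the variable no longer affect the queued task.

```python
import tensorflow as tf

steps_per_execution = tf.Variable(5)
pending = []

def enqueue(fn, args):
  # Stand-in for ClusterCoordinator.schedule(fn, args=...).
  pending.append((fn, args))

def train_function(steps):
  print('executing', steps, 'steps')

steps_per_execution.assign(4)
# Snapshot the current value and pass it by value, mirroring
# args=(it, self._steps_per_execution.numpy().item()) in the diff below.
enqueue(train_function, (steps_per_execution.numpy().item(),))
steps_per_execution.assign(5)

for fn, args in pending:
  fn(*args)  # prints "executing 4 steps": the truncated value survives
```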

PiperOrigin-RevId: 395042946
Parent 580f0446
@@ -843,31 +843,27 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
       write_scalar_summaries(outputs, step=model._train_counter)  # pylint: disable=protected-access
       return outputs

-    if (self._steps_per_execution is None or
-        self._steps_per_execution.numpy().item() == 1):
-      def train_function(iterator):
-        """Runs a training execution with one step."""
-        return step_function(self, iterator)
-    else:
-      def train_function(iterator):
-        """Runs a training execution with multiple steps."""
-        for _ in tf.range(self._steps_per_execution):
+    def train_function(iterator, steps_per_execution):
+      """Runs a training execution with multiple steps."""
+      outputs = step_function(self, iterator)
+      if steps_per_execution > 1:
+        for _ in tf.range(steps_per_execution - 1):
           outputs = step_function(self, iterator)
-        return outputs
+      return outputs

     if not self.run_eagerly:
       train_function = tf.function(
           train_function, experimental_relax_shapes=True)
       self.train_tf_function = train_function

-    self.train_function = train_function

     if self._cluster_coordinator:
-      self.train_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
-          train_function, args=(iterator,))
+      self.train_function = lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
+          train_function,
+          args=(it, self._steps_per_execution.numpy().item()))
+    else:
+      self.train_function = lambda it: train_function(  # pylint: disable=g-long-lambda
+          it,
+          self._steps_per_execution.numpy().item())

     return self.train_function
@@ -1327,30 +1323,26 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
           outputs, self.distribute_strategy, reduction='first')
       return outputs

-    if (self._steps_per_execution is None or
-        self._steps_per_execution.numpy().item() == 1):
-      def test_function(iterator):
-        """Runs an evaluation execution with one step."""
-        return step_function(self, iterator)
-    else:
-      def test_function(iterator):
-        """Runs an evaluation execution with multiple steps."""
-        for _ in tf.range(self._steps_per_execution):
+    def test_function(iterator, steps_per_execution):
+      """Runs an evaluation execution with multiple steps."""
+      outputs = step_function(self, iterator)
+      if steps_per_execution > 1:
+        for _ in tf.range(steps_per_execution - 1):
           outputs = step_function(self, iterator)
-        return outputs
+      return outputs

     if not self.run_eagerly:
       test_function = tf.function(
           test_function, experimental_relax_shapes=True)

-    self.test_function = test_function

     if self._cluster_coordinator:
-      self.test_function = lambda iterator: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
-          test_function, args=(iterator,))
+      self.test_function = lambda it: self._cluster_coordinator.schedule(  # pylint: disable=g-long-lambda
+          test_function,
+          args=(it, self._steps_per_execution.numpy().item()))
+    else:
+      self.test_function = lambda it: test_function(  # pylint: disable=g-long-lambda
+          it,
+          self._steps_per_execution.numpy().item())

     return self.test_function
@@ -1582,33 +1574,30 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
           outputs, self.distribute_strategy, reduction='concat')
       return outputs

-    if (self._steps_per_execution is None or
-        self._steps_per_execution.numpy().item() == 1):
-      def predict_function(iterator):
-        """Runs an evaluation execution with one step."""
-        return step_function(self, iterator)
-    else:
-      def predict_function(iterator):
-        """Runs an evaluation execution with multiple steps."""
-        outputs = step_function(self, iterator)
-        for _ in tf.range(self._steps_per_execution - 1):
+    def predict_function(iterator, steps_per_execution):
+      """Runs an evaluation execution with multiple steps."""
+      outputs = step_function(self, iterator)
+      if steps_per_execution > 1:
+        for _ in tf.range(steps_per_execution - 1):
           tf.autograph.experimental.set_loop_options(
               shape_invariants=[(
                   t, tf_utils.get_tensor_spec(t, dynamic_batch=True).shape)
                   for t in tf.nest.flatten(outputs)])
           step_outputs = step_function(self, iterator)
-          outputs = tf.nest.map_structure(lambda t1, t2: concat([t1, t2]), outputs,
-                                          step_outputs)
-        return outputs
+          outputs = tf.nest.map_structure(lambda t1, t2: concat([t1, t2]),
+                                          outputs, step_outputs)
+      return outputs

     if not self.run_eagerly:
       predict_function = tf.function(
           predict_function, experimental_relax_shapes=True)

     self.predict_function = predict_function
+
+    steps_per_execution = (
+        lambda: self._steps_per_execution.numpy().item()  # pylint: disable=g-long-lambda
+        if self._steps_per_execution is not None else 1)
+    self.predict_function = lambda it: predict_function(  # pylint: disable=g-long-lambda
+        it, steps_per_execution())

     return self.predict_function

   @traceback_utils.filter_traceback