提交 b21347fb 编写于 作者: S Scott Zhu 提交者: TensorFlower Gardener

Allow steps_per_epoch=-1 for PSS to run indefinitely.

PiperOrigin-RevId: 380058341
上级 216da119
......@@ -44,6 +44,31 @@ class DatasetCreatorModelFitTest(test_base.DatasetCreatorModelFitTestBase):
model = self._model_fit(strategy)
self.assertEqual(model.optimizer.iterations, 100)
def testModelFitwithStepsPerEpochNegativeOne(self, strategy):
    """`steps_per_epoch=-1` should run until the (finite) dataset raises.

    With a `ParameterServerStrategy` coordinator the indefinite loop is
    expected to surface the worker-side end-of-data as `OutOfRangeError`
    (or `CancelledError` during shutdown); other strategies simply finish
    the fit call when the dataset is exhausted.
    """

    def dataset_fn(input_context):
        del input_context  # sharding is irrelevant for this tiny dataset
        x = tf.random.uniform((10, 10))
        y = tf.random.uniform((10,))
        return tf.data.Dataset.from_tensor_slices(
            (x, y)).shuffle(10).batch(2)

    # Both branches issue the exact same fit call; keep it in one place
    # instead of duplicating the argument list (DRY).
    def run_fit():
        self._model_fit(
            strategy,
            steps_per_epoch=-1,
            x=dataset_creator.DatasetCreator(dataset_fn),
            validation_data=dataset_creator.DatasetCreator(dataset_fn),
        )

    if strategy._should_use_with_coordinator:  # pylint: disable=protected-access
        with self.assertRaises((tf.errors.OutOfRangeError,
                                tf.errors.CancelledError)):
            run_fit()
    else:
        run_fit()
def testModelFitWithNumpyData(self, strategy):
x = np.random.rand(100, 10)
y = np.random.rand(100, 1)
......
......@@ -89,11 +89,6 @@ class DatasetCreatorModelFitTestBase(tf.test.TestCase, parameterized.TestCase):
if not is_loss_float:
raise RuntimeError("loss is supposed to be in the logs and float.")
def on_train_end(self, logs=None):
    """Verify training stopped on the expected final epoch.

    Raises:
      RuntimeError: if the last epoch recorded by `on_epoch_*` callbacks
        (`self._prev_epoch`) is not 9, i.e. fewer/more epochs ran than
        the test configured.
    """
    last_epoch = self._prev_epoch
    if last_epoch == 9:
        return
    raise RuntimeError("Unexpected last epoch: {}".format(last_epoch))
with strategy.scope():
model = sequential.Sequential([core_layers.Dense(10)])
if with_normalization_layer:
......
......@@ -1281,8 +1281,18 @@ class DataHandler(object):
# TODO(b/150292341): Allow multiple async steps here.
return self._inferred_steps is None
def _log_indefinite_training_warning(self):
logging.warning("The training loop will run indefinitely since you have "
"set `steps_per_epoch=-1`. Please use batch-level "
"callbacks to save checkpoints or log training progress, "
"etc")
def _infer_steps(self, steps, dataset):
"""Infers steps_per_epoch needed to loop through a dataset."""
if steps == -1:
self._log_indefinite_training_warning()
return None
if steps is not None:
return steps
......@@ -1356,10 +1366,12 @@ class _ClusterCoordinatorDataHandler(DataHandler):
self._dataset = self._model._cluster_coordinator.create_per_worker_dataset( # pylint: disable=protected-access
per_worker_dataset_fn)
if steps_per_epoch is None:
raise ValueError(
"`steps_per_epoch` must be specified with `ParameterServerStrategy`.")
self._inferred_steps = steps_per_epoch
if steps_per_epoch == -1:
self._inferred_steps = None
self._log_indefinite_training_warning()
else:
self._inferred_steps = steps_per_epoch
def sync(self):
    """Block until all scheduled functions on the cluster coordinator finish."""
    coordinator = self._model._cluster_coordinator  # pylint: disable=protected-access
    coordinator.join()
......
......@@ -1019,9 +1019,11 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
`tf.data` dataset, and 'steps_per_epoch'
is None, the epoch will run until the input dataset is exhausted.
When passing an infinitely repeating dataset, you must specify the
`steps_per_epoch` argument. This argument is not supported with
array inputs. `steps_per_epoch=None` is not supported when using
`tf.distribute.experimental.ParameterServerStrategy`.
`steps_per_epoch` argument. If `steps_per_epoch=-1` the training
will run indefinitely with an infinitely repeating dataset.
This argument is not supported with array inputs.
When using `tf.distribute.experimental.ParameterServerStrategy`:
* `steps_per_epoch=None` is not supported.
validation_steps: Only relevant if `validation_data` is provided and
is a `tf.data` dataset. Total number of steps (batches of
samples) to draw before stopping when performing validation
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册