Commit a1699237 authored by Sourabh Bajaj, committed by TensorFlower Gardener

Support model.fit and evaluate in 2.0 with TPUStrategy using the experimental_run + train_on_batch API.

PiperOrigin-RevId: 251570029
Parent 0ba31190
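For context, a rough sketch of the user-facing flow this change enables: `model.fit` and `model.evaluate` running under TPUStrategy with TF 2.0 eager execution. This is an illustrative sketch, not code from this commit; the TPU name, model, and data are placeholders, and exact connection steps vary across 2.0-era releases.

import tensorflow as tf

# Placeholder TPU name; resolution details depend on the environment.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, activation='relu', input_shape=(8,)),
      tf.keras.layers.Dense(1),
  ])
  model.compile(optimizer='sgd', loss='mse')

# Steps must be inferable from the data or passed explicitly; see the
# infer_steps_for_dataset changes below.
x = tf.random.uniform((64, 8))
y = tf.random.uniform((64, 1))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8, drop_remainder=True)

model.fit(dataset, epochs=2, steps_per_epoch=8)
model.evaluate(dataset, steps=8)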
@@ -668,6 +668,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if kwargs is None:
       kwargs = {}
+    # Remove trailing Nones from args, as they are not replicatable.
+    # If there are Nones in the middle we can't do anything about them,
+    # so let those cases fail.
+    # For example, Keras model predict passes the targets as None. We want to
+    # handle it here so that client libraries don't have to, since other
+    # strategies can handle None values better.
+    while args and args[-1] is None:
+      args = args[:-1]
+
     # Used to re-structure flattened output tensors from `tpu.replicate()`
     # into a structured format.
     result = [[]]
...
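A standalone illustration of the trailing-`None` trimming added above; plain Python with a hypothetical helper name, independent of TensorFlow:

def trim_trailing_nones(args):
  # Drop Nones from the end only; Nones in the middle are kept and will
  # fail later, matching the behavior described in the comment above.
  args = tuple(args)
  while args and args[-1] is None:
    args = args[:-1]
  return args

assert trim_trailing_nones((1, 2, None)) == (1, 2)
assert trim_trailing_nones((1, None, 3, None)) == (1, None, 3)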
@@ -84,7 +84,7 @@ distribute_py_test(
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
     main = "distribute_strategy_test.py",
-    shard_count = 4,
+    shard_count = 5,
     tags = [
         "multi_and_single_gpu",
         "no_oss",  # TODO(b/117919883): Fix python error.
...
@@ -297,6 +297,11 @@ def strategy_minus_tpu_combinations():
 def tpu_strategy_combinations():
+  return combinations.combine(distribution=tpu_strategies,
+                              mode=['graph', 'eager'])
+
+
+def tpu_strategy_combinations_graph_only():
   return combinations.combine(distribution=tpu_strategies,
                               mode=['graph'])
@@ -313,8 +318,8 @@ def all_strategy_combinations_plus_cloning():
           cloning=[True, False]) +
       combinations.combine(
           distribution=tpu_strategies,
-          mode=['graph'],
-          cloning=[True, False]))
+          mode=['graph', 'eager'],
+          cloning=[False]))


 def all_strategy_minus_default_and_tpu_combinations():
@@ -334,8 +339,8 @@ def all_strategy_combinations_minus_default():
 def strategy_and_optimizer_combinations():
-  return combinations.times(
-      all_strategy_combinations(),
+  non_tpu_strategies = combinations.times(
+      strategy_minus_tpu_combinations(),
       # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
       combinations.combine(
           optimizer=[
@@ -353,6 +358,32 @@ def strategy_and_optimizer_combinations():
               strategy_combinations.rmsprop_optimizer_keras_v2_fn
           ],
           cloning=[True, False]))
+  # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
+  tpu_strategies_graph = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['graph'],
+      cloning=[True],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_v1_fn,
+          strategy_combinations.adam_optimizer_v1_fn,
+          strategy_combinations.gradient_descent_optimizer_v1_fn,
+          strategy_combinations.rmsprop_optimizer_v1_fn,
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  tpu_strategies_eager = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['eager'],
+      cloning=[False],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
 class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,

@@ -769,7 +800,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     self.assertAllEqual([6, 7], outs[1].shape)

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_partial_batch(self, distribution, batch_size):
     with self.cached_session():
@@ -812,7 +843,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                         rtol=1e-5)

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_partial_batch(self, distribution, cloning):
     with self.cached_session():
@@ -846,7 +877,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
                         atol=1e-5,
                         rtol=1e-5)

-  @combinations.generate(tpu_strategy_combinations())
+  @combinations.generate(tpu_strategy_combinations_graph_only())
   def test_no_target_model(self, distribution):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
@@ -872,7 +903,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       model.evaluate(inputs, steps=1)

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_partial_batch(
       self, distribution, cloning):
@@ -1192,6 +1223,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
   def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
                                                cloning):
     with self.cached_session():
       with distribution.scope():
         model = get_model()
@@ -1341,7 +1373,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
     self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_dataset_with_partial_batch(self, distribution,
                                                     batch_size):
@@ -1382,7 +1414,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
                         rtol=1e-5)

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_dataset_with_partial_batch(self, distribution, cloning):
     with self.cached_session():
@@ -1411,7 +1443,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
                         rtol=1e-5)

   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_dataset_with_partial_batch(
       self, distribution, cloning):
...
@@ -164,6 +164,15 @@ def unwrap_outputs(distribution_strategy, grouped_outputs,
         grouped_outputs[0], axis=None)
     all_outputs = flatten_per_replica_values(distribution_strategy,
                                              grouped_outputs[1:])
+    if (is_tpu_strategy(distribution_strategy) and
+        ops.executing_eagerly_outside_functions()):
+      # Keep one value per output in the TPU case, since all replicas
+      # produce the same output.
+      # We only do this in eager mode for now, since this function is used in
+      # both graph and eager mode, and in the graph case we currently don't
+      # use experimental_run. This should be removed once the graph code path
+      # converges as well.
+      all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
     return [loss] + all_outputs
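The stride slice above can be shown with plain values: with N replicas producing identical outputs, the flattened list repeats each output N times (output-major), so a stride-N slice keeps one copy of each. Illustrative values only, not TensorFlow code:

num_replicas_in_sync = 2
# Flattened per-replica outputs, output-major:
all_outputs = ['out0_r0', 'out0_r1', 'out1_r0', 'out1_r1']
deduped = all_outputs[::num_replicas_in_sync]
assert deduped == ['out0_r0', 'out1_r0']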
@@ -578,6 +587,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   """
   strategy = model._distribution_strategy
   inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
+  if is_tpu_strategy(strategy):
+    if sample_weights is not None:
+      raise ValueError('TPUStrategy does not support sample weights.')

   # When the inputs are dict, then we want to flatten it in the same order as
   # the input layers, such that the data are fed into the input layers in the
@@ -611,8 +623,8 @@ def is_distributing_by_cloning(model):
   """Decide whether this model is going to be distributed via cloning.

   We are going to distribute the model by cloning if the user has signaled
-  that intent by not setting `cloning=False` in `Model.compile()` unless we
-  are in graph mode or running on TPU.
+  that intent by setting `cloning=True` in `Model.compile()`, unless we are
+  in graph mode.

   Args:
     model: Keras model to distribute.
@@ -621,9 +633,15 @@ def is_distributing_by_cloning(model):
     True if the `model` is going to be distributed using cloning and False
     otherwise.
   """
+  if (is_tpu_strategy(model._distribution_strategy) and
+      context.executing_eagerly()):
+    if model._cloning:
+      logging.warning(
+          'Model cloning is not supported in TPU Strategy in eager mode; '
+          'the cloning argument will be ignored.')
+    return False
   return (model._cloning or model._compile_distribution or
-          not ops.executing_eagerly_outside_functions() or
-          K.is_tpu_strategy(model._distribution_strategy))
+          not ops.executing_eagerly_outside_functions())


 def _custom_compile_for_predict(model):
...
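A plain-Python paraphrase of the decision logic after this change; the real function inspects model and op state, and the function and parameter names here are illustrative:

def distributing_by_cloning(cloning, compile_distribution,
                            eager_outside_functions, is_tpu, eager):
  if is_tpu and eager:
    # Cloning is unsupported for TPU in eager mode; a True flag is
    # ignored with a warning.
    return False
  return cloning or compile_distribution or not eager_outside_functions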
@@ -90,11 +90,20 @@ def strategies_for_embedding_models():

 def test_combinations_for_embedding_model():
+  # TODO(sourabhbajaj): Enable tests for eager mode
+  eager_mode_strategies = [s for s in strategies_for_embedding_models()
+                           if not s.required_tpu]
   return (combinations.times(
       combinations.combine(
           distribution=strategies_for_embedding_models(),
           cloning=[True, False]),
-      (graph_mode_test_configuration() + eager_mode_test_configuration())))
+      (graph_mode_test_configuration())) +
+          combinations.times(
+              combinations.combine(
+                  distribution=eager_mode_strategies,
+                  cloning=[False]),
+              (eager_mode_test_configuration())))
 def test_combinations_with_tpu_strategies():

@@ -322,7 +331,7 @@ def compare_results(results_with_ds,
     return default_tolerance

-  for key in results_with_ds:
+  for key in sorted(results_with_ds.keys()):
     if (key.startswith('training_history') and
         isinstance(distribution, (tpu_strategy.TPUStrategy,
                                   tpu_strategy.TPUStrategyV1)) and
@@ -420,9 +429,9 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def get_model(self, distribution=None, cloning=None, input_shapes=None):
     raise NotImplementedError

-  def skip_unsupported_test_configuration(self, distribution):
-    if should_skip_tpu_with_eager(distribution):
-      self.skipTest('TPUStrategy does not support eager mode now.')
+  def skip_unsupported_test_configuration(self, distribution, cloning):
+    if should_skip_tpu_with_eager(distribution) and cloning:
+      self.skipTest('TPUStrategy does not support eager mode with cloning.')
     return

   def run_correctness_test(self,
@@ -443,7 +452,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
       self.skipTest('Test broken; see b/129793413 and b/117920141')
     with self.cached_session():
       self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)

       if partial_last_batch == 'eval':
         x_train, y_train, x_eval, y_eval, x_predict = (
@@ -540,7 +549,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def run_dynamic_lr_test(self, distribution, cloning=None):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)

       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning=cloning, input_shapes=get_shapes(x_train))
...
@@ -153,7 +153,7 @@ class TestDistributionStrategyDnnMetricCorrectness(
   def run_metric_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)

       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning, distribution=distribution)
@@ -195,7 +195,7 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
   def run_eval_metrics_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)

       model = self.get_model(cloning, distribution=distribution)
@@ -266,11 +266,17 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
                            cloning):
-    if ((not cloning and context.executing_eagerly() and
-         not K.is_tpu_strategy(distribution)) or
+    if ((not cloning and context.executing_eagerly()) or
         is_default_strategy(distribution)):
       self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                 cloning)
+    elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_correctness_test(distribution, use_numpy, use_validation_data,
+                                  cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -286,6 +292,12 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
          not K.is_tpu_strategy(distribution)) or
         is_default_strategy(distribution)):
       self.run_dynamic_lr_test(distribution, cloning)
+    elif K.is_tpu_strategy(distribution):
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_dynamic_lr_test(distribution, cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -301,9 +313,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
                                            use_validation_data):
     with self.assertRaisesRegexp(
         ValueError,
-        'We currently do not support distribution strategy with a '
-        '`Sequential` model that is created without `input_shape`/'
-        '`input_dim` set in its first layer or a subclassed model.'):
+        'Expected `model` argument to be a functional `Model` instance, '
+        'but got a subclass model instead.'):
       self.run_correctness_test(
           distribution,
           use_numpy,
...
@@ -103,8 +103,8 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
         validation_steps=validation_steps,
         callbacks=[counter])

-    if isinstance(distribution, (tpu_strategy.TPUStrategy,
-                                 tpu_strategy.TPUStrategyV1)):
+    if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
+        not context.executing_eagerly()):
       # TPU Strategy can have multi step training, from extended.steps_per_run
       # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
       steps_per_run = distribution.extended.steps_per_run
...
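The multi-step note in this hunk is simple arithmetic: with `steps_per_run` R, Keras observes ceil(steps_per_epoch / R) batch-level callback calls per epoch. Illustrative values, mirroring the comment above:

import math

steps_per_epoch, steps_per_run = 10, 4
num_batch_call_per_epoch = math.ceil(steps_per_epoch / steps_per_run)
assert num_batch_call_per_epoch == 3
# With steps_per_run == 1, num_batch_call_per_epoch == steps_per_epoch.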
@@ -26,6 +26,7 @@ from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -166,8 +167,6 @@ def experimental_tpu_fit_loop(model,
   # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps_per_epoch = training_utils.infer_steps_for_dataset(
-      dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')

   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=1)
@@ -185,12 +184,8 @@ def experimental_tpu_fit_loop(model,
       tensor = model._all_metrics_tensors[name]
       initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)

-  if steps_per_epoch is not None:
-    iteration_value = min(steps_per_epoch,
-                          current_strategy.extended.steps_per_run)
-  else:
-    raise ValueError('Number of steps could not be infered from the data, '
-                     'please pass the steps_per_epoch argument.')
+  iteration_value = min(steps_per_epoch,
+                        current_strategy.extended.steps_per_run)

   steps_per_run = K.variable(
       value=iteration_value,
@@ -320,8 +315,6 @@ def experimental_tpu_test_loop(model,
   mode = ModeKeys.TEST
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')

   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=0)
@@ -449,8 +442,6 @@ def experimental_tpu_predict_loop(model,
       (if the model has multiple outputs).
   """
   mode = ModeKeys.PREDICT
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')
   dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
   padding_handler = None
   if not dataset_fully_shaped:
@@ -653,6 +644,14 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
                        'distribution strategies.')

     if dist_utils.is_tpu_strategy(model._distribution_strategy):
+      steps_per_epoch = training_utils.infer_steps_for_dataset(
+          dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
+      if steps_per_epoch is None:
+        raise ValueError('Number of steps could not be inferred from the '
+                         'data; please pass the steps_per_epoch argument.')
+
+      if not context.executing_eagerly():
+        # Run TPU training in a custom loop in graph mode.
         return experimental_tpu_fit_loop(
             model,
             dataset,
@@ -664,7 +663,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
             steps_per_epoch=steps_per_epoch,
             validation_steps=validation_steps,
             validation_freq=validation_freq)
-    else:
+
     return training_arrays.fit_loop(
         model,
         dataset,
@@ -702,9 +701,17 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         allow_partial_batch=True)

     if dist_utils.is_tpu_strategy(model._distribution_strategy):
+      steps = training_utils.infer_steps_for_dataset(
+          dataset, steps, steps_name='steps')
+      if steps is None:
+        raise ValueError('Number of steps could not be inferred from the '
+                         'data; please pass the steps argument.')
+
+      if not context.executing_eagerly():
+        # Run TPU evaluation in a custom loop in graph mode.
         return experimental_tpu_test_loop(
             model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
-    else:
+
     return training_arrays.test_loop(
         model,
         inputs=dataset,
@@ -731,9 +738,14 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)

     if dist_utils.is_tpu_strategy(model._distribution_strategy):
+      steps = training_utils.infer_steps_for_dataset(
+          dataset, steps, steps_name='steps')
+      if steps is None:
+        raise ValueError('Number of steps could not be inferred from the '
+                         'data; please pass the steps argument.')
+
+      if not context.executing_eagerly():
         return experimental_tpu_predict_loop(
             model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
-    else:
+
     return training_arrays.predict_loop(
         model,
         dataset,
...
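Since the TPU paths above now require a concrete step count, callers with a finite dataset can derive one up front. A hedged sketch using `tf.data.experimental.cardinality` in eager mode; the helper name is mine, not part of this commit:

import tensorflow as tf

def explicit_steps(dataset):
  # cardinality returns negative sentinel values for infinite or
  # unknown-length datasets.
  n = int(tf.data.experimental.cardinality(dataset).numpy())
  if n < 0:
    raise ValueError('Steps could not be inferred; pass steps/steps_per_epoch.')
  return n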
@@ -90,6 +90,11 @@ def initialize_tpu_system(cluster_resolver=None):
     with ops.device(tpu_system_device):
       output = _tpu_init_fn()

+    # Clear out the eager context caches since the memory is invalid now.
+    logging.info("Clearing out eager caches")
+    context.context()._clear_caches()  # pylint: disable=protected-access
+
     serialized_topology = output.numpy()
   else:
     master = cluster_resolver.master()
...
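From user code, the function whose internals change above is reached via the public initializer; after (re)initialization the previous TPU memory is invalid, so cached eager state that captured it would be stale, which is what the cache clearing guards against. Placeholder TPU name:

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
topology = tf.tpu.experimental.initialize_tpu_system(resolver)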