diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 290986504b324d6082231e18f5d15de05ff216fb..0d90c395163589f2aba9e4ac8cf3b9f3b9b778d9 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -668,6 +668,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     if kwargs is None:
       kwargs = {}
 
+    # Remove trailing `None`s from `args` as they are not replicatable.
+    # If there are `None`s in the middle we can't do anything about them, so
+    # let those cases fail.
+    # For example, Keras `Model.predict` passes the targets as `None`. We
+    # want to handle that here so client libraries don't have to, since other
+    # strategies can handle `None` values better.
+    while args and args[-1] is None:
+      args = args[:-1]
+
     # Used to re-structure flattened output tensors from `tpu.replicate()`
     # into a structured format.
     result = [[]]
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index bad86c1d1c4e708ec06cc5023dc141069bb5f115..8fe71268cb037f4904a3dc5d4ecbcaf31cae2664 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -84,7 +84,7 @@ distribute_py_test(
     srcs = ["distribute_strategy_test.py"],
     full_precision = True,
     main = "distribute_strategy_test.py",
-    shard_count = 4,
+    shard_count = 5,
     tags = [
         "multi_and_single_gpu",
         "no_oss",  # TODO(b/117919883): Fix python error.
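Note on the `tpu_strategy.py` hunk above: only *trailing* `None` arguments are dropped before replication; a `None` in the middle of `args` is left in place and is expected to fail. A minimal, self-contained sketch of that trimming rule (plain Python, not the TF implementation):

```python
# Illustrative only: mirrors the while-loop added above.
def drop_trailing_nones(args):
  """Drops None values from the end of `args`; interior Nones are kept."""
  args = list(args)
  while args and args[-1] is None:
    args = args[:-1]
  return tuple(args)

assert drop_trailing_nones((1, 2, None, None)) == (1, 2)   # the Keras predict case
assert drop_trailing_nones((1, None, 2)) == (1, None, 2)   # interior None left for the caller
```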
diff --git a/tensorflow/python/keras/distribute/distribute_strategy_test.py b/tensorflow/python/keras/distribute/distribute_strategy_test.py
index 307ec3df9700088e7c6b6b53f462a31413e07718..10cd2e8ee3d2b6c3a24fda2e4c9810385b4211d8 100644
--- a/tensorflow/python/keras/distribute/distribute_strategy_test.py
+++ b/tensorflow/python/keras/distribute/distribute_strategy_test.py
@@ -297,6 +297,11 @@ def strategy_minus_tpu_combinations():
 
 
 def tpu_strategy_combinations():
+  return combinations.combine(distribution=tpu_strategies,
+                              mode=['graph', 'eager'])
+
+
+def tpu_strategy_combinations_graph_only():
   return combinations.combine(distribution=tpu_strategies,
                               mode=['graph'])
 
@@ -313,8 +318,8 @@ def all_strategy_combinations_plus_cloning():
           cloning=[True, False]) +
       combinations.combine(
           distribution=tpu_strategies,
-          mode=['graph'],
-          cloning=[True, False]))
+          mode=['graph', 'eager'],
+          cloning=[False]))
 
 
 def all_strategy_minus_default_and_tpu_combinations():
@@ -334,8 +339,8 @@ def all_strategy_combinations_minus_default():
 
 
 def strategy_and_optimizer_combinations():
-  return combinations.times(
-      all_strategy_combinations(),
+  non_tpu_strategies = combinations.times(
+      strategy_minus_tpu_combinations(),
       # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
       combinations.combine(
           optimizer=[
@@ -353,6 +358,32 @@ def strategy_and_optimizer_combinations():
               strategy_combinations.rmsprop_optimizer_keras_v2_fn
           ],
           cloning=[True, False]))
+  # TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
+  tpu_strategies_graph = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['graph'],
+      cloning=[True],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_v1_fn,
+          strategy_combinations.adam_optimizer_v1_fn,
+          strategy_combinations.gradient_descent_optimizer_v1_fn,
+          strategy_combinations.rmsprop_optimizer_v1_fn,
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  tpu_strategies_eager = combinations.combine(
+      distribution=tpu_strategies,
+      mode=['eager'],
+      cloning=[False],
+      optimizer=[
+          strategy_combinations.adagrad_optimizer_keras_v2_fn,
+          strategy_combinations.adam_optimizer_keras_v2_fn,
+          strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
+          strategy_combinations.rmsprop_optimizer_keras_v2_fn
+      ])
+  return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
 
 
 class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
@@ -769,7 +800,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
     self.assertAllEqual([6, 7], outs[1].shape)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_partial_batch(self, distribution, batch_size):
     with self.cached_session():
@@ -812,7 +843,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
         rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_partial_batch(self, distribution, cloning):
     with self.cached_session():
@@ -846,7 +877,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
         atol=1e-5,
         rtol=1e-5)
 
-  @combinations.generate(tpu_strategy_combinations())
+  @combinations.generate(tpu_strategy_combinations_graph_only())
   def test_no_target_model(self, distribution):
     with self.cached_session():
       optimizer = gradient_descent.GradientDescentOptimizer(0.001)
@@ -872,7 +903,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
       model.evaluate(inputs, steps=1)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_partial_batch(
       self, distribution, cloning):
@@ -1192,6 +1223,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
   def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
                                                cloning):
     with self.cached_session():
+
       with distribution.scope():
         model = get_model()
@@ -1341,7 +1373,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
     self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(batch_size=[4, 6])))
   def test_evaluate_with_dataset_with_partial_batch(self, distribution,
                                                     batch_size):
     with self.cached_session():
@@ -1382,7 +1414,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
         rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_with_dataset_with_partial_batch(self,
                                                    distribution, cloning):
     with self.cached_session():
@@ -1411,7 +1443,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
         rtol=1e-5)
 
   @combinations.generate(
-      combinations.times(tpu_strategy_combinations(),
+      combinations.times(tpu_strategy_combinations_graph_only(),
                          combinations.combine(cloning=[True, False])))
   def test_predict_multi_output_model_with_dataset_with_partial_batch(
       self, distribution, cloning):
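The test-combination changes above encode the supported matrix: TPU in graph mode still runs with `cloning=[True]` and both v1 and Keras v2 optimizers, while TPU in eager mode runs only with `cloning=[False]` and Keras v2 optimizers. As an illustration of how such keyword grids multiply out into individual test cases, here is a hypothetical stand-in (not the real `combinations` module):

```python
# Hypothetical stand-in for combinations.combine, shown only to illustrate
# how the parameter grids above expand into one test case per value tuple.
import itertools

def combine(**grid):
  keys = sorted(grid)
  return [dict(zip(keys, values))
          for values in itertools.product(*(grid[k] for k in keys))]

tpu_eager_cases = combine(mode=['eager'], cloning=[False],
                          optimizer=['adam_keras_v2', 'rmsprop_keras_v2'])
assert len(tpu_eager_cases) == 2  # one case per optimizer
```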
diff --git a/tensorflow/python/keras/distribute/distributed_training_utils.py b/tensorflow/python/keras/distribute/distributed_training_utils.py
index b544e72208be77e43e64154954c25d1869391129..98cb6c8857e8287527bc4d96d097f795a46658da 100644
--- a/tensorflow/python/keras/distribute/distributed_training_utils.py
+++ b/tensorflow/python/keras/distribute/distributed_training_utils.py
@@ -164,6 +164,15 @@ def unwrap_outputs(distribution_strategy, grouped_outputs,
                                             grouped_outputs[0], axis=None)
   all_outputs = flatten_per_replica_values(distribution_strategy,
                                            grouped_outputs[1:])
+  if (is_tpu_strategy(distribution_strategy) and
+      ops.executing_eagerly_outside_functions()):
+    # Choose 1 value per replica in the TPU case, since all replicas produce
+    # the same output.
+    # We only do this in eager mode for now, since this function is used in
+    # both graph and eager mode and the graph code path currently doesn't use
+    # experimental_run; this slicing would need to be removed when we
+    # converge the graph code path as well.
+    all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
   return [loss] + all_outputs
@@ -578,6 +587,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   """
   strategy = model._distribution_strategy
   inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
+  if is_tpu_strategy(strategy):
+    if sample_weights is not None:
+      raise ValueError('TPUStrategy does not support sample weights.')
 
   # When the inputs are dict, then we want to flatten it in the same order as
   # the input layers, such that the data are fed into the input layers in the
@@ -611,8 +623,8 @@ def is_distributing_by_cloning(model):
   """Decide whether this model is going to be distributed via cloning.
 
   We are going to distribute the model by cloning if the user has signaled
-  that intent by not setting `cloning=False` in `Model.compile()` unless we
-  are in graph mode or running on TPU.
+  that intent by setting `cloning=True` in `Model.compile()` unless we are in
+  graph mode.
 
   Args:
     model: Keras model to distribute.
@@ -621,9 +633,15 @@ def is_distributing_by_cloning(model):
     True if the `model` is going to be distributed using cloning and False
     otherwise.
   """
+  if (is_tpu_strategy(model._distribution_strategy) and
+      context.executing_eagerly()):
+    if model._cloning:
+      logging.warning(
+          'Model cloning is not supported in TPUStrategy in eager mode. '
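A short sketch of the de-duplication added in `unwrap_outputs`, assuming (as the slice in the patch implies) that the flattened outputs interleave replicas per output value:

```python
# Illustrative values only: with every replica returning the same result on
# TPU, keeping every num_replicas_in_sync-th entry keeps exactly one copy.
num_replicas_in_sync = 2
flat_outputs = ['out0_replica0', 'out0_replica1', 'out1_replica0', 'out1_replica1']
assert flat_outputs[::num_replicas_in_sync] == ['out0_replica0', 'out1_replica0']
```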
+          'The cloning argument will be ignored.')
+    return False
   return (model._cloning or
           model._compile_distribution or
-          not ops.executing_eagerly_outside_functions() or
-          K.is_tpu_strategy(model._distribution_strategy))
+          not ops.executing_eagerly_outside_functions())
 
 
 def _custom_compile_for_predict(model):
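With the `is_distributing_by_cloning` change above, TPU in eager mode never distributes by cloning; the related test-base change further down only skips the TPU + eager combination when cloning is requested. A hedged sketch of that skip rule (plain predicate, not the test harness):

```python
# Sketch of the new skip rule: TPU in eager mode is skipped only when cloning
# is requested; cloning=False on TPU in eager mode now runs.
def should_skip(is_tpu_strategy_in_eager, cloning):
  return is_tpu_strategy_in_eager and cloning

assert should_skip(True, True) is True
assert should_skip(True, False) is False
```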
diff --git a/tensorflow/python/keras/distribute/keras_correctness_test_base.py b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
index 7c08abd96180f1362686896e642f51aa176a73f0..30c2fb96c94dea87f28bc95bfb1252a21b893a06 100644
--- a/tensorflow/python/keras/distribute/keras_correctness_test_base.py
+++ b/tensorflow/python/keras/distribute/keras_correctness_test_base.py
@@ -90,11 +90,20 @@ def strategies_for_embedding_models():
 
 
 def test_combinations_for_embedding_model():
+  # TODO(sourabhbajaj): Enable tests for eager mode
+  eager_mode_strategies = [s for s in strategies_for_embedding_models()
+                           if not s.required_tpu]
+
   return (combinations.times(
       combinations.combine(
           distribution=strategies_for_embedding_models(),
           cloning=[True, False]),
-      (graph_mode_test_configuration() + eager_mode_test_configuration())))
+      (graph_mode_test_configuration())) +
+          combinations.times(
+              combinations.combine(
+                  distribution=eager_mode_strategies,
+                  cloning=[False]),
+              (eager_mode_test_configuration())))
 
 
 def test_combinations_with_tpu_strategies():
@@ -322,7 +331,7 @@ def compare_results(results_with_ds,
 
     return default_tolerance
 
-  for key in results_with_ds:
+  for key in sorted(results_with_ds.keys()):
     if (key.startswith('training_history') and
         isinstance(distribution, (tpu_strategy.TPUStrategy,
                                   tpu_strategy.TPUStrategyV1)) and
@@ -420,9 +429,9 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def get_model(self, distribution=None, cloning=None, input_shapes=None):
     raise NotImplementedError
 
-  def skip_unsupported_test_configuration(self, distribution):
-    if should_skip_tpu_with_eager(distribution):
-      self.skipTest('TPUStrategy does not support eager mode now.')
+  def skip_unsupported_test_configuration(self, distribution, cloning):
+    if should_skip_tpu_with_eager(distribution) and cloning:
+      self.skipTest('TPUStrategy does not support eager mode with cloning.')
     return
 
   def run_correctness_test(self,
@@ -443,7 +452,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
       self.skipTest('Test broken; see b/129793413 and b/117920141')
     with self.cached_session():
      self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       if partial_last_batch == 'eval':
         x_train, y_train, x_eval, y_eval, x_predict = (
@@ -540,7 +549,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
   def run_dynamic_lr_test(self, distribution, cloning=None):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning=cloning,
                              input_shapes=get_shapes(x_train))
diff --git a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
index 12d966d9b30f2cfdcdb16fab0aab26d916f52a26..c8155262a82890132bea121b978bd4d0e272b051 100644
--- a/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
+++ b/tensorflow/python/keras/distribute/keras_dnn_correctness_test.py
@@ -153,7 +153,7 @@ class TestDistributionStrategyDnnMetricCorrectness(
   def run_metric_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       x_train, y_train, _ = self.get_data()
       model = self.get_model(cloning, distribution=distribution)
@@ -195,7 +195,7 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
   def run_eval_metrics_correctness_test(self, distribution, cloning):
     with self.cached_session():
       self.set_up_test_config()
-      self.skip_unsupported_test_configuration(distribution)
+      self.skip_unsupported_test_configuration(distribution, cloning)
 
       model = self.get_model(cloning, distribution=distribution)
@@ -266,11 +266,17 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
       keras_correctness_test_base.all_strategy_and_input_config_combinations())
   def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
                            cloning):
-    if ((not cloning and context.executing_eagerly() and
-         not K.is_tpu_strategy(distribution)) or
+    if ((not cloning and context.executing_eagerly()) or
         is_default_strategy(distribution)):
       self.run_correctness_test(distribution, use_numpy, use_validation_data,
                                 cloning)
+    elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_correctness_test(distribution, use_numpy, use_validation_data,
+                                  cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -286,6 +292,12 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
          not K.is_tpu_strategy(distribution)) or
         is_default_strategy(distribution)):
       self.run_dynamic_lr_test(distribution, cloning)
+    elif K.is_tpu_strategy(distribution):
+      with self.assertRaisesRegexp(
+          ValueError,
+          'Expected `model` argument to be a functional `Model` instance, '
+          'but got a subclass model instead.'):
+        self.run_dynamic_lr_test(distribution, cloning)
     else:
       with self.assertRaisesRegexp(
           ValueError,
@@ -301,9 +313,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
                              use_validation_data):
     with self.assertRaisesRegexp(
         ValueError,
-        'We currently do not support distribution strategy with a '
-        '`Sequential` model that is created without `input_shape`/'
-        '`input_dim` set in its first layer or a subclassed model.'):
+        'Expected `model` argument to be a functional `Model` instance, '
+        'but got a subclass model instead.'):
       self.run_correctness_test(
           distribution,
           use_numpy,
diff --git a/tensorflow/python/keras/distribute/keras_utils_test.py b/tensorflow/python/keras/distribute/keras_utils_test.py
index 3b324f1be2ac37cf0bd4f615c69e2bba92cf661d..85965da378952d581af7a5a25716fef7eba182c1 100644
--- a/tensorflow/python/keras/distribute/keras_utils_test.py
+++ b/tensorflow/python/keras/distribute/keras_utils_test.py
@@ -103,8 +103,8 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
         validation_steps=validation_steps,
         callbacks=[counter])
 
-    if isinstance(distribution, (tpu_strategy.TPUStrategy,
-                                 tpu_strategy.TPUStrategyV1)):
+    if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
+        not context.executing_eagerly()):
       # TPU Strategy can have multi step training, from extended.steps_per_run
       # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
       steps_per_run = distribution.extended.steps_per_run
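The `keras_utils_test.py` hunk above limits the multi-step callback accounting to graph-mode `TPUStrategyV1`. As a standalone illustration of the arithmetic that comment refers to (values here are hypothetical):

```python
# With multi-step training, on_batch_* callbacks fire once per host loop of
# `steps_per_run` steps rather than once per step.
steps_per_epoch = 10
steps_per_run = 3
num_batch_call_per_epoch = steps_per_epoch // steps_per_run
if steps_per_epoch % steps_per_run:
  num_batch_call_per_epoch += 1
assert num_batch_call_per_epoch == 4  # and 10 when steps_per_run == 1
```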
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index d26da1b9d17eee74ce6c711a82adf93b8c10a0b6..4c309284ab00b11013dedf6b010360a1cc86773f 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -26,6 +26,7 @@ from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import reduce_util as ds_reduce_util
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -166,8 +167,6 @@ def experimental_tpu_fit_loop(model,
   # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps_per_epoch = training_utils.infer_steps_for_dataset(
-      dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
 
   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=1)
@@ -185,12 +184,8 @@ def experimental_tpu_fit_loop(model,
     tensor = model._all_metrics_tensors[name]
     initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
 
-  if steps_per_epoch is not None:
-    iteration_value = min(steps_per_epoch,
-                          current_strategy.extended.steps_per_run)
-  else:
-    raise ValueError('Number of steps could not be infered from the data, '
-                     'please pass the steps_per_epoch argument.')
+  iteration_value = min(steps_per_epoch,
+                        current_strategy.extended.steps_per_run)
 
   steps_per_run = K.variable(
       value=iteration_value,
@@ -320,8 +315,6 @@ def experimental_tpu_test_loop(model,
   mode = ModeKeys.TEST
   current_strategy = model._distribution_strategy
   iterator = dist_utils.get_iterator(dataset, current_strategy)
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')
 
   scope = dist_utils.distributed_scope(
       strategy=current_strategy, learning_phase=0)
@@ -449,8 +442,6 @@ def experimental_tpu_predict_loop(model,
       (if the model has multiple outputs).
   """
   mode = ModeKeys.PREDICT
-  steps = training_utils.infer_steps_for_dataset(dataset, steps,
-                                                 steps_name='steps')
   dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
   padding_handler = None
   if not dataset_fully_shaped:
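The hunks above remove the per-loop `steps` inference from the three `experimental_tpu_*_loop` functions; the hunks that follow re-introduce it once, at the `fit`/`evaluate`/`predict` entry points, and fail fast when no step count can be determined. A hedged sketch of the idea (this is not `training_utils.infer_steps_for_dataset`, just an illustration of why an unknown or infinite cardinality must raise):

```python
# TF 2.x eager sketch, illustrative only.
import tensorflow as tf

def infer_steps(dataset, steps):
  if steps is not None:
    return steps
  cardinality = int(tf.data.experimental.cardinality(dataset).numpy())
  if cardinality in (tf.data.experimental.INFINITE_CARDINALITY,
                     tf.data.experimental.UNKNOWN_CARDINALITY):
    raise ValueError('Number of steps could not be inferred from the data, '
                     'please pass the steps argument.')
  return cardinality

assert infer_steps(tf.data.Dataset.range(8).batch(2), None) == 4
```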
""" mode = ModeKeys.PREDICT - steps = training_utils.infer_steps_for_dataset(dataset, steps, - steps_name='steps') dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset) padding_handler = None if not dataset_fully_shaped: @@ -653,32 +644,40 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): 'distribution strategies.') if dist_utils.is_tpu_strategy(model._distribution_strategy): - return experimental_tpu_fit_loop( - model, - dataset, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - val_dataset=val_dataset, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps, - validation_freq=validation_freq) - else: - return training_arrays.fit_loop( - model, - dataset, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - val_inputs=val_dataset, - shuffle=shuffle, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps, - validation_freq=validation_freq, - steps_name='steps_per_epoch') + steps_per_epoch = training_utils.infer_steps_for_dataset( + dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch') + if steps_per_epoch is None: + raise ValueError('Number of steps could not be infered from the data, ' + 'please pass the steps_per_epoch argument.') + + if not context.executing_eagerly(): + # Run TPU training in a custom loop in graph mode. + return experimental_tpu_fit_loop( + model, + dataset, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_dataset=val_dataset, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + validation_freq=validation_freq) + + return training_arrays.fit_loop( + model, + dataset, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + val_inputs=val_dataset, + shuffle=shuffle, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps, + validation_freq=validation_freq, + steps_name='steps_per_epoch') def evaluate(self, model, @@ -702,16 +701,24 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop): allow_partial_batch=True) if dist_utils.is_tpu_strategy(model._distribution_strategy): - return experimental_tpu_test_loop( - model, dataset, verbose=verbose, steps=steps, callbacks=callbacks) - else: - return training_arrays.test_loop( - model, - inputs=dataset, - batch_size=batch_size, - verbose=verbose, - steps=steps, - callbacks=callbacks) + steps = training_utils.infer_steps_for_dataset( + dataset, steps, steps_name='steps') + if steps is None: + raise ValueError('Number of steps could not be infered from the data, ' + 'please pass the steps argument.') + + if not context.executing_eagerly(): + # Run TPU evaluation in a custom loop in graph mode. 
@@ -731,16 +738,21 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
         batch_size=batch_size,
         allow_partial_batch=True)
     if dist_utils.is_tpu_strategy(model._distribution_strategy):
-      return experimental_tpu_predict_loop(
-          model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
-    else:
-      return training_arrays.predict_loop(
-          model,
-          dataset,
-          batch_size=batch_size,
-          verbose=verbose,
-          steps=steps,
-          callbacks=callbacks)
+      steps = training_utils.infer_steps_for_dataset(
+          dataset, steps, steps_name='steps')
+      if steps is None:
+        raise ValueError('Number of steps could not be inferred from the data, '
+                         'please pass the steps argument.')
+      if not context.executing_eagerly():
+        return experimental_tpu_predict_loop(
+            model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
+    return training_arrays.predict_loop(
+        model,
+        dataset,
+        batch_size=batch_size,
+        verbose=verbose,
+        steps=steps,
+        callbacks=callbacks)
 
   def _process_batch_and_step_size(
       self, model, inputs, batch_size, steps_per_epoch, mode):
diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py
index efd26e314c40c3ef102bf4bfa35b04ee59653a09..068a323146590538b86fe0cd6abdcf380c9a90fc 100644
--- a/tensorflow/python/tpu/tpu_strategy_util.py
+++ b/tensorflow/python/tpu/tpu_strategy_util.py
@@ -90,6 +90,11 @@ def initialize_tpu_system(cluster_resolver=None):
     with ops.device(tpu_system_device):
       output = _tpu_init_fn()
+
+    # Clear out the eager context caches since the memory is invalid now.
+    logging.info("Clearing out eager caches")
+    context.context()._clear_caches()  # pylint: disable=protected-access
+
     serialized_topology = output.numpy()
   else:
     master = cluster_resolver.master()
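For context, a typical call sequence that exercises `initialize_tpu_system` as changed above; the cache clearing happens inside the initializer, so callers need no extra step. The resolver address is a placeholder and exact symbol paths vary across TF 2.x versions:

```python
import tensorflow as tf

# Placeholder address; on Cloud TPU the resolver can often be built with
# its defaults instead of an explicit grpc target.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://<tpu-address>')
tf.tpu.experimental.initialize_tpu_system(resolver)  # (re)initialization clears eager caches
strategy = tf.distribute.experimental.TPUStrategy(resolver)
```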