提交 a1699237 编写于 作者: S Sourabh Bajaj 提交者: TensorFlower Gardener

Support model.fit and evaluate in 2.0 with TPUStrategy using the...

Support model.fit and evaluate in 2.0 with TPUStrategy using the experimental_run + train_on_batch API.

PiperOrigin-RevId: 251570029
上级 0ba31190
......@@ -668,6 +668,15 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
if kwargs is None:
kwargs = {}
# Remove None at the end of args as they are not replicatable
# If there are None in the middle we can't do anything about it
# so let those cases fail.
# For example when Keras model predict is used they pass the targets as
# None. We want to handle it here so all client libraries don't have to
# do this as other strategies can handle None values better.
while args and args[-1] is None:
args = args[:-1]
# Used to re-structure flattened output tensors from `tpu.replicate()`
# into a structured format.
result = [[]]
......
......@@ -84,7 +84,7 @@ distribute_py_test(
srcs = ["distribute_strategy_test.py"],
full_precision = True,
main = "distribute_strategy_test.py",
shard_count = 4,
shard_count = 5,
tags = [
"multi_and_single_gpu",
"no_oss", # TODO(b/117919883): Fix python error.
......
......@@ -297,6 +297,11 @@ def strategy_minus_tpu_combinations():
def tpu_strategy_combinations():
return combinations.combine(distribution=tpu_strategies,
mode=['graph', 'eager'])
def tpu_strategy_combinations_graph_only():
return combinations.combine(distribution=tpu_strategies,
mode=['graph'])
......@@ -313,8 +318,8 @@ def all_strategy_combinations_plus_cloning():
cloning=[True, False]) +
combinations.combine(
distribution=tpu_strategies,
mode=['graph'],
cloning=[True, False]))
mode=['graph', 'eager'],
cloning=[False]))
def all_strategy_minus_default_and_tpu_combinations():
......@@ -334,8 +339,8 @@ def all_strategy_combinations_minus_default():
def strategy_and_optimizer_combinations():
return combinations.times(
all_strategy_combinations(),
non_tpu_strategies = combinations.times(
strategy_minus_tpu_combinations(),
# TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
combinations.combine(
optimizer=[
......@@ -353,6 +358,32 @@ def strategy_and_optimizer_combinations():
strategy_combinations.rmsprop_optimizer_keras_v2_fn
],
cloning=[True, False]))
# TODO(b/130808953): Simplify when optimizers v1 work with cloning=False.
tpu_strategies_graph = combinations.combine(
distribution=tpu_strategies,
mode=['graph'],
cloning=[True],
optimizer=[
strategy_combinations.adagrad_optimizer_v1_fn,
strategy_combinations.adam_optimizer_v1_fn,
strategy_combinations.gradient_descent_optimizer_v1_fn,
strategy_combinations.rmsprop_optimizer_v1_fn,
strategy_combinations.adagrad_optimizer_keras_v2_fn,
strategy_combinations.adam_optimizer_keras_v2_fn,
strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
strategy_combinations.rmsprop_optimizer_keras_v2_fn
])
tpu_strategies_eager = combinations.combine(
distribution=tpu_strategies,
mode=['eager'],
cloning=[False],
optimizer=[
strategy_combinations.adagrad_optimizer_keras_v2_fn,
strategy_combinations.adam_optimizer_keras_v2_fn,
strategy_combinations.gradient_descent_optimizer_keras_v2_fn,
strategy_combinations.rmsprop_optimizer_keras_v2_fn
])
return non_tpu_strategies + tpu_strategies_eager + tpu_strategies_graph
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
......@@ -769,7 +800,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
self.assertAllEqual([6, 7], outs[1].shape)
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(batch_size=[4, 6])))
def test_evaluate_with_partial_batch(self, distribution, batch_size):
with self.cached_session():
......@@ -812,7 +843,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
rtol=1e-5)
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(cloning=[True, False])))
def test_predict_with_partial_batch(self, distribution, cloning):
with self.cached_session():
......@@ -846,7 +877,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
atol=1e-5,
rtol=1e-5)
@combinations.generate(tpu_strategy_combinations())
@combinations.generate(tpu_strategy_combinations_graph_only())
def test_no_target_model(self, distribution):
with self.cached_session():
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
......@@ -872,7 +903,7 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase,
model.evaluate(inputs, steps=1)
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(cloning=[True, False])))
def test_predict_multi_output_model_with_partial_batch(
self, distribution, cloning):
......@@ -1192,6 +1223,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer,
cloning):
with self.cached_session():
with distribution.scope():
model = get_model()
......@@ -1341,7 +1373,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
self.assertAllClose(0.001, keras.backend.get_value(model.optimizer.lr))
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(batch_size=[4, 6])))
def test_evaluate_with_dataset_with_partial_batch(self, distribution,
batch_size):
......@@ -1382,7 +1414,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
rtol=1e-5)
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(cloning=[True, False])))
def test_predict_with_dataset_with_partial_batch(self, distribution, cloning):
with self.cached_session():
......@@ -1411,7 +1443,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
rtol=1e-5)
@combinations.generate(
combinations.times(tpu_strategy_combinations(),
combinations.times(tpu_strategy_combinations_graph_only(),
combinations.combine(cloning=[True, False])))
def test_predict_multi_output_model_with_dataset_with_partial_batch(
self, distribution, cloning):
......
......@@ -164,6 +164,15 @@ def unwrap_outputs(distribution_strategy, grouped_outputs,
grouped_outputs[0], axis=None)
all_outputs = flatten_per_replica_values(distribution_strategy,
grouped_outputs[1:])
if (is_tpu_strategy(distribution_strategy) and
ops.executing_eagerly_outside_functions()):
# Choose 1 value per replica in the TPU case since all replicas produce the
# same output.
# We only do this in eager mode for now since this function is used in
# both graph and eager mode and in the graph case we currently don't use
# experimental_run so would need to be removed when we converge the graph
# code path as well.
all_outputs = all_outputs[::distribution_strategy.num_replicas_in_sync]
return [loss] + all_outputs
......@@ -578,6 +587,9 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
"""
strategy = model._distribution_strategy
inputs, targets, sample_weights = _get_input_from_iterator(inputs, model)
if is_tpu_strategy(strategy):
if sample_weights is not None:
raise ValueError('TPUStrategy does not support sample weights.')
# When the inputs are dict, then we want to flatten it in the same order as
# the input layers, such that the data are fed into the input layers in the
......@@ -611,8 +623,8 @@ def is_distributing_by_cloning(model):
"""Decide whether this model is going to be distributed via cloning.
We are going to distribute the model by cloning if the user has signaled
that intent by not setting `cloning=False` in `Model.compile()` unless we
are in graph mode or running on TPU.
that intent by setting `cloning=True` in `Model.compile()` unless we are in
graph mode.
Args:
model: Keras model to distribute.
......@@ -621,9 +633,15 @@ def is_distributing_by_cloning(model):
True if the `model` is going to be distributed using cloning and False
otherwise.
"""
if (is_tpu_strategy(model._distribution_strategy) and
context.executing_eagerly):
if model._cloning:
logging.warning(
'Model cloning is not supported in TPU Strategy in Eager mode.'
'cloning argument will be ignored.')
return False
return (model._cloning or model._compile_distribution or
not ops.executing_eagerly_outside_functions() or
K.is_tpu_strategy(model._distribution_strategy))
not ops.executing_eagerly_outside_functions())
def _custom_compile_for_predict(model):
......
......@@ -90,11 +90,20 @@ def strategies_for_embedding_models():
def test_combinations_for_embedding_model():
# TODO(sourabhbajaj): Enable tests for eager mode
eager_mode_strategies = [s for s in strategies_for_embedding_models()
if not s.required_tpu]
return (combinations.times(
combinations.combine(
distribution=strategies_for_embedding_models(),
cloning=[True, False]),
(graph_mode_test_configuration() + eager_mode_test_configuration())))
(graph_mode_test_configuration())) +
combinations.times(
combinations.combine(
distribution=eager_mode_strategies,
cloning=[False]),
(eager_mode_test_configuration())))
def test_combinations_with_tpu_strategies():
......@@ -322,7 +331,7 @@ def compare_results(results_with_ds,
return default_tolerance
for key in results_with_ds:
for key in sorted(results_with_ds.keys()):
if (key.startswith('training_history') and
isinstance(distribution, (tpu_strategy.TPUStrategy,
tpu_strategy.TPUStrategyV1)) and
......@@ -420,9 +429,9 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
def get_model(self, distribution=None, cloning=None, input_shapes=None):
raise NotImplementedError
def skip_unsupported_test_configuration(self, distribution):
if should_skip_tpu_with_eager(distribution):
self.skipTest('TPUStrategy does not support eager mode now.')
def skip_unsupported_test_configuration(self, distribution, cloning):
if should_skip_tpu_with_eager(distribution) and cloning:
self.skipTest('TPUStrategy does not support eager mode with cloning.')
return
def run_correctness_test(self,
......@@ -443,7 +452,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
self.skipTest('Test broken; see b/129793413 and b/117920141')
with self.cached_session():
self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)
self.skip_unsupported_test_configuration(distribution)
self.skip_unsupported_test_configuration(distribution, cloning)
if partial_last_batch == 'eval':
x_train, y_train, x_eval, y_eval, x_predict = (
......@@ -540,7 +549,7 @@ class TestDistributionStrategyCorrectnessBase(test.TestCase,
def run_dynamic_lr_test(self, distribution, cloning=None):
with self.cached_session():
self.set_up_test_config()
self.skip_unsupported_test_configuration(distribution)
self.skip_unsupported_test_configuration(distribution, cloning)
x_train, y_train, _ = self.get_data()
model = self.get_model(cloning=cloning, input_shapes=get_shapes(x_train))
......
......@@ -153,7 +153,7 @@ class TestDistributionStrategyDnnMetricCorrectness(
def run_metric_correctness_test(self, distribution, cloning):
with self.cached_session():
self.set_up_test_config()
self.skip_unsupported_test_configuration(distribution)
self.skip_unsupported_test_configuration(distribution, cloning)
x_train, y_train, _ = self.get_data()
model = self.get_model(cloning, distribution=distribution)
......@@ -195,7 +195,7 @@ class TestDistributionStrategyDnnMetricEvalCorrectness(
def run_eval_metrics_correctness_test(self, distribution, cloning):
with self.cached_session():
self.set_up_test_config()
self.skip_unsupported_test_configuration(distribution)
self.skip_unsupported_test_configuration(distribution, cloning)
model = self.get_model(cloning, distribution=distribution)
......@@ -266,11 +266,17 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
keras_correctness_test_base.all_strategy_and_input_config_combinations())
def test_dnn_correctness(self, distribution, use_numpy, use_validation_data,
cloning):
if ((not cloning and context.executing_eagerly() and
not K.is_tpu_strategy(distribution)) or
if ((not cloning and context.executing_eagerly()) or
is_default_strategy(distribution)):
self.run_correctness_test(distribution, use_numpy, use_validation_data,
cloning)
elif K.is_tpu_strategy(distribution) and not context.executing_eagerly():
with self.assertRaisesRegexp(
ValueError,
'Expected `model` argument to be a functional `Model` instance, '
'but got a subclass model instead.'):
self.run_correctness_test(distribution, use_numpy, use_validation_data,
cloning)
else:
with self.assertRaisesRegexp(
ValueError,
......@@ -286,6 +292,12 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
not K.is_tpu_strategy(distribution)) or
is_default_strategy(distribution)):
self.run_dynamic_lr_test(distribution, cloning)
elif K.is_tpu_strategy(distribution):
with self.assertRaisesRegexp(
ValueError,
'Expected `model` argument to be a functional `Model` instance, '
'but got a subclass model instead.'):
self.run_dynamic_lr_test(distribution, cloning)
else:
with self.assertRaisesRegexp(
ValueError,
......@@ -301,9 +313,8 @@ class TestDistributionStrategyDnnCorrectnessWithSubclassedModel(
use_validation_data):
with self.assertRaisesRegexp(
ValueError,
'We currently do not support distribution strategy with a '
'`Sequential` model that is created without `input_shape`/'
'`input_dim` set in its first layer or a subclassed model.'):
'Expected `model` argument to be a functional `Model` instance, '
'but got a subclass model instead.'):
self.run_correctness_test(
distribution,
use_numpy,
......
......@@ -103,8 +103,8 @@ class TestDistributionStrategyWithCallbacks(test.TestCase,
validation_steps=validation_steps,
callbacks=[counter])
if isinstance(distribution, (tpu_strategy.TPUStrategy,
tpu_strategy.TPUStrategyV1)):
if (isinstance(distribution, tpu_strategy.TPUStrategyV1) and
not context.executing_eagerly()):
# TPU Strategy can have multi step training, from extended.steps_per_run
# if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
steps_per_run = distribution.extended.steps_per_run
......
......@@ -26,6 +26,7 @@ from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
......@@ -166,8 +167,6 @@ def experimental_tpu_fit_loop(model,
# TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
current_strategy = model._distribution_strategy
iterator = dist_utils.get_iterator(dataset, current_strategy)
steps_per_epoch = training_utils.infer_steps_for_dataset(
dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
scope = dist_utils.distributed_scope(
strategy=current_strategy, learning_phase=1)
......@@ -185,12 +184,8 @@ def experimental_tpu_fit_loop(model,
tensor = model._all_metrics_tensors[name]
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
if steps_per_epoch is not None:
iteration_value = min(steps_per_epoch,
current_strategy.extended.steps_per_run)
else:
raise ValueError('Number of steps could not be infered from the data, '
'please pass the steps_per_epoch argument.')
steps_per_run = K.variable(
value=iteration_value,
......@@ -320,8 +315,6 @@ def experimental_tpu_test_loop(model,
mode = ModeKeys.TEST
current_strategy = model._distribution_strategy
iterator = dist_utils.get_iterator(dataset, current_strategy)
steps = training_utils.infer_steps_for_dataset(dataset, steps,
steps_name='steps')
scope = dist_utils.distributed_scope(
strategy=current_strategy, learning_phase=0)
......@@ -449,8 +442,6 @@ def experimental_tpu_predict_loop(model,
(if the model has multiple outputs).
"""
mode = ModeKeys.PREDICT
steps = training_utils.infer_steps_for_dataset(dataset, steps,
steps_name='steps')
dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
padding_handler = None
if not dataset_fully_shaped:
......@@ -653,6 +644,14 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
'distribution strategies.')
if dist_utils.is_tpu_strategy(model._distribution_strategy):
steps_per_epoch = training_utils.infer_steps_for_dataset(
dataset, steps_per_epoch, epochs, steps_name='steps_per_epoch')
if steps_per_epoch is None:
raise ValueError('Number of steps could not be infered from the data, '
'please pass the steps_per_epoch argument.')
if not context.executing_eagerly():
# Run TPU training in a custom loop in graph mode.
return experimental_tpu_fit_loop(
model,
dataset,
......@@ -664,7 +663,7 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq)
else:
return training_arrays.fit_loop(
model,
dataset,
......@@ -702,9 +701,17 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
allow_partial_batch=True)
if dist_utils.is_tpu_strategy(model._distribution_strategy):
steps = training_utils.infer_steps_for_dataset(
dataset, steps, steps_name='steps')
if steps is None:
raise ValueError('Number of steps could not be infered from the data, '
'please pass the steps argument.')
if not context.executing_eagerly():
# Run TPU evaluation in a custom loop in graph mode.
return experimental_tpu_test_loop(
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
else:
return training_arrays.test_loop(
model,
inputs=dataset,
......@@ -731,9 +738,14 @@ class DistributionSingleWorkerTrainingLoop(training_utils.TrainingLoop):
batch_size=batch_size,
allow_partial_batch=True)
if dist_utils.is_tpu_strategy(model._distribution_strategy):
steps = training_utils.infer_steps_for_dataset(
dataset, steps, steps_name='steps')
if steps is None:
raise ValueError('Number of steps could not be infered from the data, '
'please pass the steps argument.')
if not context.executing_eagerly():
return experimental_tpu_predict_loop(
model, dataset, verbose=verbose, steps=steps, callbacks=callbacks)
else:
return training_arrays.predict_loop(
model,
dataset,
......
......@@ -90,6 +90,11 @@ def initialize_tpu_system(cluster_resolver=None):
with ops.device(tpu_system_device):
output = _tpu_init_fn()
# Clear out the eager context caches since the memory is invalid now.
logging.info("Clearing out eager caches")
context.context()._clear_caches() # pylint: disable=protected-access
serialized_topology = output.numpy()
else:
master = cluster_resolver.master()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册