Commit d2b35a79 authored by Xinyi Wang, committed by TensorFlower Gardener

Enable last partial batch for MWMS in TF2.x

PiperOrigin-RevId: 317760674
Change-Id: Ib7e0adbf4f8f013f21faef07ed4961c078806093
Parent 7c384680
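
The behavior this commit enables can be sketched as follows. This is an illustrative example, not code from the commit; it assumes a TF_CONFIG-based multi-worker setup and the TF 2.x distribution APIs of that era (tf.distribute.experimental.MultiWorkerMirroredStrategy, experimental_distribute_dataset, strategy.run). With the change, a dataset whose size is not evenly divisible by the global batch size delivers its final, smaller batch to the replicas as well.

import tensorflow as tf

# Assumes TF_CONFIG is set appropriately on each worker in the cluster.
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

# 14 elements with a global batch size of 8 leave a final partial batch of 6.
dataset = tf.data.Dataset.range(14).batch(8, drop_remainder=False)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def step(dist_batch):
  return strategy.run(lambda x: x, args=(dist_batch,))

for dist_batch in dist_dataset:
  # The last iteration receives the partial batch (elements 8..13), split
  # across the workers' replicas.
  step(dist_batch)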
@@ -178,6 +178,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
self._communication = communication
self._initialize_strategy(self._cluster_resolver)
self._cfer_fn_cache = weakref.WeakKeyDictionary()
self.experimental_enable_get_next_as_optional = True
assert isinstance(self._cross_device_ops,
cross_device_ops_lib.CollectiveAllReduce)
......
@@ -370,7 +370,8 @@ class CollectiveAllReduceStrategyTestBase(
else:
self.assertEqual(list(expected_value), list(computed_value))
with self.assertRaises(errors.OutOfRangeError):
# error raised by calling optional_get_value on an Optional of None
with self.assertRaises(errors.InvalidArgumentError):
next_element = iterator.get_next()
sess.run([distribute_utils.select_replica(r, next_element)
for r in range(len(devices))])
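# Illustrative sketch (not from this diff) of the assertion change above: the
# error now comes from calling get_value() on an Optional that holds no value,
# which raises errors.InvalidArgumentError. Assumes the tf.experimental.Optional
# API (tf.data.experimental.Optional in earlier 2.x releases) and eager mode.
import tensorflow as tf

empty = tf.experimental.Optional.empty(
    tf.TensorSpec(shape=[], dtype=tf.int64))
try:
  empty.get_value()
except tf.errors.InvalidArgumentError:
  pass  # Expected: the Optional holds no value.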
@@ -449,31 +450,35 @@ class DistributedCollectiveAllReduceStrategyTest(
combinations.combine(
mode=['graph'], required_gpus=[0, 1, 2], use_dataset=[True, False]))
def testMakeInputFnIterator(self, required_gpus, use_dataset):
if use_dataset:
fn = lambda: dataset_ops.Dataset.range(100)
else:
def fn():
dataset = dataset_ops.Dataset.range(100)
it = dataset_ops.make_one_shot_iterator(dataset)
return it.get_next
# We use CPU as the device when required_gpus = 0
devices_per_worker = max(1, required_gpus)
expected_values = [[i+j for j in range(devices_per_worker)]
for i in range(0, 100, devices_per_worker)]
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=3*devices_per_worker,
expected_num_input_pipelines=3,
expected_input_pipeline_id=1) # because task_id = 1
self._test_input_fn_iterator(
'worker',
1,
required_gpus,
input_fn,
expected_values,
test_reinitialize=use_dataset,
ignore_order=not use_dataset)
def _worker_fn(task_type, task_id, required_gpus):
if use_dataset:
fn = lambda: dataset_ops.Dataset.range(100)
else:
def fn():
dataset = dataset_ops.Dataset.range(100)
it = dataset_ops.make_one_shot_iterator(dataset)
return it.get_next
# We use CPU as the device when required_gpus = 0
devices_per_worker = max(1, required_gpus)
expected_values = [[i+j for j in range(devices_per_worker)]
for i in range(0, 100, devices_per_worker)]
input_fn = self._input_fn_to_test_input_context(
fn,
expected_num_replicas_in_sync=3*devices_per_worker,
expected_num_input_pipelines=3,
expected_input_pipeline_id=task_id)
self._test_input_fn_iterator(
task_type,
task_id,
required_gpus,
input_fn,
expected_values,
test_reinitialize=use_dataset,
ignore_order=not use_dataset)
self._run_between_graph_clients(_worker_fn, self._cluster_spec,
required_gpus)
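# Illustrative sketch (not from this diff): with MultiWorkerMirroredStrategy,
# each worker hosts one input pipeline, so the InputContext seen by an input_fn
# carries input_pipeline_id equal to the worker's task_id. This is why the
# refactor above replaces the hardcoded expected_input_pipeline_id=1 with
# task_id. Assumes a 3-worker cluster as in the test; the helper name is
# hypothetical.
import tensorflow as tf

def expected_input_context(task_id, devices_per_worker, num_workers=3):
  return tf.distribute.InputContext(
      num_input_pipelines=num_workers,
      input_pipeline_id=task_id,
      num_replicas_in_sync=num_workers * devices_per_worker)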
@combinations.generate(combinations.combine(mode=['graph']))
def testUpdateConfigProto(self):
......
@@ -549,7 +549,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
# Collective all-reduce requires explicit devices for inputs.
with ops.device("/cpu:0"):
# Converting to integers for all-reduce.
worker_has_value = math_ops.cast(worker_has_value, dtypes.int32)
worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
worker_devices.append(worker_has_value.device)
worker_has_values.append(worker_has_value)
# Make `replicas` a flat list of values across all replicas.
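# Illustrative sketch (not from this diff) of the idea behind the cast above:
# each worker's boolean "has value" flag is cast to an integer so the flags can
# be summed with a collective all-reduce; a positive global sum means at least
# one worker still has data. The names below are hypothetical stand-ins for the
# values gathered in _get_next_as_optional.
import tensorflow as tf

def any_worker_has_value(worker_flags):
  # worker_flags: one scalar bool tensor per worker (illustrative).
  counts = [tf.cast(flag, tf.int64) for flag in worker_flags]
  # In the real code this sum is computed by a cross-worker all-reduce.
  return tf.reduce_sum(tf.stack(counts)) > 0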
@@ -624,16 +624,12 @@ class DistributedIteratorBase(DistributedIteratorInterface):
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(yuefengz): Currently `experimental_enable_get_next_as_optional` is
# always set to False in CollectiveAllReduceStrategy. We want to have a way
# to distinguish multi workers/single worker between graph, so we can enable
# the behavior in single worker case.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = not static_shape
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
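# Illustrative restatement (not from this diff) of the condition introduced
# above: get_next_as_optional is enabled only when the strategy opts in, and
# then whenever the element shapes are not fully static or the strategy runs in
# multi-worker mode.
def should_enable_get_next_as_optional(opted_in, static_shape, multi_worker):
  return opted_in and ((not static_shape) or multi_worker)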
@@ -906,7 +902,8 @@ class DistributedIterator(DistributedIteratorBase,
self._strategy = strategy
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = not static_shape
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
else:
......
@@ -1144,7 +1144,6 @@ class DistributedIteratorMultiWorkerTest(
expected_values = [[[0, 1]], [[2, 3]], [[4]]]
input_context = None
strategy.extended.experimental_enable_get_next_as_optional = True
self._test_input_iteration(
input_type,
api_type,
......
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
@@ -29,7 +30,6 @@ from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -81,7 +81,7 @@ class DistributedCollectiveAllReduceStrategyTest(
return d.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
expected_sum_on_workers = [10, 35]
expected_data_on_worker = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
input_iterator = iter(
strategy.experimental_distribute_datasets_from_function(dataset_fn))
@@ -90,10 +90,59 @@ class DistributedCollectiveAllReduceStrategyTest(
return strategy.experimental_local_results(iterator.get_next())
result = run(input_iterator)
sum_value = math_ops.reduce_sum(result)
self.assertEqual(
sum_value.numpy(),
expected_sum_on_workers[multi_worker_test_base.get_task_index()])
self.assertTrue(
np.array_equal(
result[0].numpy(),
expected_data_on_worker[multi_worker_test_base.get_task_index()]))
def testSimpleInputFromDatasetLastPartialBatch(self, strategy):
global_batch_size = 8
dataset = dataset_ops.DatasetV2.range(14).batch(
global_batch_size, drop_remainder=False)
input_iterator = iter(strategy.experimental_distribute_dataset(dataset))
@def_function.function
def run(input_iterator):
return strategy.run(lambda x: x, args=(next(input_iterator),))
# Let the complete batch go.
run(input_iterator)
# `result` is an incomplete batch
result = run(input_iterator)
expected_data_on_worker = [[8, 9, 10], [11, 12, 13]]
self.assertTrue(
np.array_equal(
result.numpy(),
expected_data_on_worker[multi_worker_test_base.get_task_index()]))
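# Worked example (not from this diff) of the expected split above, assuming
# 2 workers with one replica each: 14 elements at global batch size 8 leave a
# final partial batch of 6 elements (8..13), which
# experimental_distribute_dataset splits evenly, 3 per worker.
remainder = list(range(14))[8:]             # [8, 9, 10, 11, 12, 13]
per_worker = len(remainder) // 2            # 3 elements per worker
worker_0, worker_1 = remainder[:per_worker], remainder[per_worker:]
assert worker_0 == [8, 9, 10] and worker_1 == [11, 12, 13]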
def testSimpleInputFromFnLastPartialBatch(self, strategy):
def dataset_fn(input_context):
global_batch_size = 8
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
dataset = dataset_ops.DatasetV2.range(14).batch(
batch_size, drop_remainder=False)
return dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
input_iterator = iter(
strategy.experimental_distribute_datasets_from_function(dataset_fn))
@def_function.function
def run(input_iterator):
return strategy.run(lambda x: x, args=(next(input_iterator),))
# Let the complete batch go.
run(input_iterator)
# `result` is an incomplete batch
result = run(input_iterator)
expected_data_on_worker = [[8, 9, 10, 11], [12, 13]]
self.assertTrue(
np.array_equal(
result.numpy(), expected_data_on_worker[
multi_worker_test_base.get_task_index()]))
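# Worked example (not from this diff) for the from-function case above,
# assuming 2 workers with one replica each: the per-replica batch size is
# 8 // 2 = 4, and the dataset is batched before being sharded by
# input_pipeline_id, so each worker sees every other batch. The second step
# therefore yields [8, 9, 10, 11] on worker 0 and the partial batch [12, 13]
# on worker 1.
batches = [list(range(14))[i:i + 4] for i in range(0, 14, 4)]
# batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13]]
worker_batches = {wid: batches[wid::2] for wid in (0, 1)}  # shard(2, wid)
assert worker_batches[0][1] == [8, 9, 10, 11]
assert worker_batches[1][1] == [12, 13]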
def testReduceHostTensor(self, strategy):
reduced = strategy.reduce(
......