Commit 5af29e07 authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Modify the rebatching logic to round up the result of dividing the original batch size by the number of workers.

PiperOrigin-RevId: 251267307
Parent: 88ed9779
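In short, each per-worker batch size is now computed with integer ceiling division instead of raising an InvalidArgument error when the division is not exact. A minimal sketch of the arithmetic, using the same values as the updated test (batch size 32, 5 workers); the helper name ceil_div is illustrative, not part of the commit:

    def ceil_div(batch_size, num_workers):
      # (a + b - 1) // b == ceil(a / b) for positive integers.
      return (batch_size + num_workers - 1) // num_workers

    assert ceil_div(32, 5) == 7  # previously rejected as not divisible
    assert ceil_div(32, 4) == 8  # evenly divisible sizes are unchanged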
@@ -40,8 +40,10 @@ Status RebatchOptimizer::Init(
 namespace {

+constexpr char kAddOp[] = "Add";
 constexpr char kConstOp[] = "Const";
 constexpr char kIdentityOp[] = "Identity";
+constexpr char kSubOp[] = "Sub";
 constexpr char kTruncateDivOp[] = "TruncateDiv";

 constexpr std::array<const char*, 5> kBatchDatasetOps = {
@@ -124,8 +126,9 @@ Status UpdateOutputShapes(const string& node_name, int64 num_workers,
   return Status::OK();
 }

-// Given a "batch" dataset node, modifies the batch_size input to divide the
-// current batch size by num_workers.
+// Given a "batch" dataset node, we replace the `batch_size` input with a new
+// input that corresponds to the original input divided by `num_workers`. If
+// `num_workers` does not divide `batch_size` evenly, the value is rounded up.
 Status MutateBatchSize(const NodeDef& node, int64 num_workers,
                        MutableGraphView* graph) {
   // For all the batching datasets the batch_size is input number 1 except for
@@ -146,23 +149,20 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
       return errors::Internal("Batch size node shape should be scalar");
     }
     int64 batch_size = batch_size_tensor.scalar<int64>()();
-    if (batch_size % num_workers != 0) {
-      return errors::InvalidArgument(
-          "Batch size: ", batch_size,
-          " is not divisible by num_workers: ", num_workers);
-    }
-    batch_size /= num_workers;
+    batch_size = (batch_size + num_workers - 1) / num_workers;
     new_batch_size_node =
         graph_utils::AddScalarConstNode<int64>(batch_size, graph);
   } else {
-    // TODO(jsimsa): To provide parity with the case where the batch size is a
-    // constant, consider generating a subgraph that would fail when number of
-    // workers does not divide the original batch size evenly (instead of
-    // using truncated division).
+    NodeDef* one_node = graph_utils::AddScalarConstNode<int64>(1, graph);
     NodeDef* num_workers_node =
         graph_utils::AddScalarConstNode<int64>(num_workers, graph);
+    NodeDef* numerator_node =
+        AddBinaryNode(batch_size_node->name(), num_workers_node->name(), kAddOp,
+                      DT_INT64, graph);
+    numerator_node = AddBinaryNode(numerator_node->name(), one_node->name(),
+                                   kSubOp, DT_INT64, graph);
     new_batch_size_node =
-        AddBinaryNode(batch_size_node->name(), num_workers_node->name(),
+        AddBinaryNode(numerator_node->name(), num_workers_node->name(),
                       kTruncateDivOp, DT_INT64, graph);
   }
   // We don't call UpdateFanouts here because CSE elimination might lead to
...
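When the batch size is not a compile-time constant, the optimizer can no longer check it while rewriting the graph, so it now emits Add, Sub, and TruncateDiv nodes that evaluate the same rounded-up quotient at runtime. A hedged Python sketch of what that subgraph computes, with plain integers standing in for the emitted scalar nodes (TruncateDiv truncates toward zero, which matches floor division for the non-negative values involved here):

    batch_size, num_workers = 32, 5              # example values only
    numerator = batch_size + num_workers         # the kAddOp node
    numerator = numerator - 1                    # the kSubOp node, using one_node
    new_batch_size = numerator // num_workers    # the kTruncateDivOp node
    assert new_batch_size == 7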
@@ -63,14 +63,14 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     with self.assertRaisesRegexp(ValueError, "at least one dimension"):
       distribute._RebatchDataset(dataset, num_workers=4)

-  def testNotDivisibleError(self, drop_remainder):
+  def testNotDivisible(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "not divisible by"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
+    expected_output = [[k for k in range(i, i + 7)] for i in range(0, 1022, 7)]  # pylint: disable=g-complex-comprehension
+    if not drop_remainder:
+      expected_output.append([1022, 1023])
+    self.assertDatasetProduces(rebatched_dataset, expected_output)

   def testTupleOutput(self, drop_remainder):
     dataset = (
@@ -371,10 +371,12 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])

-    # pylint: disable=g-complex-comprehension
-    x = [(2, 0), (2, 0), (2, 0), (2, 0), (2, 0), (5, 1), (5, 1), (2, 0), (2, 0),
-         (2, 0), (2, 0), (2, 0), (5, 1), (5, 1)]
-    expected_output = [[value] * batch_size for batch_size, value in x]
+    pairs = [(3, 0), (3, 0), (3, 0)]
+    if not drop_remainder:
+      pairs.extend([(1, 0)])
+    pairs.extend([(5, 1), (5, 1)])
+    pairs = pairs * 2
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
...
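For reference, the expected output of the rewritten testNotDivisible follows directly from that arithmetic: 1024 elements arrive in batches of 32, each worker's share is ceil(32/5) = 7, and 1024 = 146 * 7 + 2, leaving a trailing partial batch [1022, 1023] that is kept only when drop_remainder is False. A quick sanity check of those numbers (illustrative, not part of the test file):

    per_worker = (32 + 5 - 1) // 5                    # 7
    full_batches, leftover = divmod(1024, per_worker)
    assert (full_batches, leftover) == (146, 2)
    # range(0, 1022, 7) enumerates the starts of the 146 full batches of 7;
    # the final two elements form the optional partial batch [1022, 1023].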
@@ -20,7 +20,6 @@ from __future__ import print_function

 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -77,14 +76,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
         raise ValueError(
             "Input shape should have at least one dimension. "
             "Perhaps your input dataset is not batched?")
-      if (tensor_shape.dimension_value(output_shapes[0]) and
-          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
-        raise errors.InvalidArgumentError(
-            None, None,
-            "First dim of input shape: %d is not divisible by num_workers: %d" %
-            (output_shapes[0], num_workers))
       output_dims = [d for d in output_shapes.dims]
-      output_dims[0] = output_dims[0] // num_workers
+      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
       return tensor_shape.TensorShape(output_dims)

     input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
...
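The static shape change in distribute.py mirrors the runtime behavior: the leading (batch) dimension is ceiling-divided instead of triggering an error, and an unknown leading dimension stays unknown under tensor_shape.Dimension arithmetic. A small illustration with plain Python values standing in for dimensions (assumed example shapes, not taken from the commit):

    def new_leading_dim(dim0, num_workers):
      # Mimics (output_dims[0] + num_workers - 1) // num_workers; an unknown
      # (None) dimension propagates as unknown, as Dimension arithmetic does.
      return None if dim0 is None else (dim0 + num_workers - 1) // num_workers

    assert new_leading_dim(32, 5) == 7
    assert new_leading_dim(32, 4) == 8
    assert new_leading_dim(None, 5) is None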