Commit 5af29e07 authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Modify the rebatching logic to round up the result of dividing the original batch size by the number of workers.

PiperOrigin-RevId: 251267307
Parent 88ed9779
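
For context, the change replaces the hard divisibility requirement with ceiling division. A minimal sketch of the new per-worker batch size computation (illustrative, not part of the commit; `per_worker_batch_size` is a hypothetical helper name):

```python
def per_worker_batch_size(batch_size, num_workers):
  # Ceiling division in integer arithmetic, as used by the new rebatching
  # logic: round up when num_workers does not divide batch_size evenly.
  return (batch_size + num_workers - 1) // num_workers

assert per_worker_batch_size(32, 4) == 8  # evenly divisible: result unchanged
assert per_worker_batch_size(32, 5) == 7  # previously an InvalidArgumentError
```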
@@ -40,8 +40,10 @@ Status RebatchOptimizer::Init(
 namespace {

+constexpr char kAddOp[] = "Add";
 constexpr char kConstOp[] = "Const";
 constexpr char kIdentityOp[] = "Identity";
+constexpr char kSubOp[] = "Sub";
 constexpr char kTruncateDivOp[] = "TruncateDiv";

 constexpr std::array<const char*, 5> kBatchDatasetOps = {
@@ -124,8 +126,9 @@ Status UpdateOutputShapes(const string& node_name, int64 num_workers,
   return Status::OK();
 }

-// Given a "batch" dataset node, modifies the batch_size input to divide the
-// current batch size by num_workers.
+// Given a "batch" dataset node, we replace the `batch_size` input with a new
+// input that corresponds to the original input divided by `num_workers`. If
+// `num_workers` does not divide `batch_size` evenly, the value is rounded up.
 Status MutateBatchSize(const NodeDef& node, int64 num_workers,
                        MutableGraphView* graph) {
   // For all the batching datasets the batch_size is input number 1 except for
@@ -146,23 +149,20 @@ Status MutateBatchSize(const NodeDef& node, int64 num_workers,
       return errors::Internal("Batch size node shape should be scalar");
     }
     int64 batch_size = batch_size_tensor.scalar<int64>()();
-    if (batch_size % num_workers != 0) {
-      return errors::InvalidArgument(
-          "Batch size: ", batch_size,
-          " is not divisible by num_workers: ", num_workers);
-    }
-    batch_size /= num_workers;
+    batch_size = (batch_size + num_workers - 1) / num_workers;
     new_batch_size_node =
         graph_utils::AddScalarConstNode<int64>(batch_size, graph);
   } else {
+    // TODO(jsimsa): To provide parity with the case where the batch size is a
+    // constant, consider generating a subgraph that would fail when number of
+    // workers does not divide the original batch size evenly (instead of
+    // using truncated division).
+    NodeDef* one_node = graph_utils::AddScalarConstNode<int64>(1, graph);
     NodeDef* num_workers_node =
         graph_utils::AddScalarConstNode<int64>(num_workers, graph);
+    NodeDef* numerator_node =
+        AddBinaryNode(batch_size_node->name(), num_workers_node->name(), kAddOp,
+                      DT_INT64, graph);
+    numerator_node = AddBinaryNode(numerator_node->name(), one_node->name(),
+                                   kSubOp, DT_INT64, graph);
     new_batch_size_node =
-        AddBinaryNode(batch_size_node->name(), num_workers_node->name(),
+        AddBinaryNode(numerator_node->name(), num_workers_node->name(),
                       kTruncateDivOp, DT_INT64, graph);
   }
   // We don't call UpdateFanouts here because CSE elimination might lead to
......
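
When the batch size is not a compile-time constant, MutateBatchSize now builds the same ceiling division out of graph nodes: Add(batch_size, num_workers), then Sub(result, 1), then TruncateDiv by num_workers. A hedged sketch of the equivalent computation using public TensorFlow ops (illustrative; the optimizer itself rewrites GraphDef nodes directly, and `rebatched_batch_size` is a hypothetical name):

```python
import tensorflow as tf

def rebatched_batch_size(batch_size, num_workers):
  # Mirrors the Add -> Sub -> TruncateDiv subgraph emitted by the optimizer.
  numerator = tf.add(batch_size, num_workers)  # batch_size + num_workers
  numerator = tf.subtract(numerator, 1)        # batch_size + num_workers - 1
  # For positive int64 operands, truncated division of (b + w - 1) by w
  # equals ceil(b / w).
  return tf.truncatediv(numerator, num_workers)
```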
@@ -63,14 +63,14 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     with self.assertRaisesRegexp(ValueError, "at least one dimension"):
       distribute._RebatchDataset(dataset, num_workers=4)

-  def testNotDivisibleError(self, drop_remainder):
+  def testNotDivisible(self, drop_remainder):
     dataset = dataset_ops.Dataset.range(1024).batch(
         32, drop_remainder=drop_remainder)
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "not divisible by"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_workers=5)
+    expected_output = [[k for k in range(i, i + 7)] for i in range(0, 1022, 7)]  # pylint: disable=g-complex-comprehension
+    if not drop_remainder:
+      expected_output.append([1022, 1023])
+    self.assertDatasetProduces(rebatched_dataset, expected_output)

   def testTupleOutput(self, drop_remainder):
     dataset = (
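
The updated test encodes the new semantics rather than expecting an error: with an original batch size of 32 and 5 workers, each worker now sees batches of ceil(32 / 5) = 7, and the trailing partial batch survives only when drop_remainder is False. A quick check of the arithmetic behind expected_output (illustrative):

```python
# 1024 elements regrouped into per-worker batches of 7:
assert (32 + 5 - 1) // 5 == 7
# 146 full batches cover elements 0..1021 ...
assert 146 * 7 == 1022
# ... leaving [1022, 1023] as a trailing partial batch when remainders are kept.
```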
@@ -371,10 +371,12 @@ class RebatchDatasetTest(test_base.DatasetTestBase):
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(dataset)])
     # pylint: disable=g-complex-comprehension
-    x = [(2, 0), (2, 0), (2, 0), (2, 0), (2, 0), (5, 1), (5, 1), (2, 0), (2, 0),
-         (2, 0), (2, 0), (2, 0), (5, 1), (5, 1)]
-    expected_output = [[value] * batch_size for batch_size, value in x]
+    pairs = [(3, 0), (3, 0), (3, 0)]
+    if not drop_remainder:
+      pairs.extend([(1, 0)])
+    pairs.extend([(5, 1), (5, 1)])
+    pairs = pairs * 2
+    expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
......
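
The reworked expectation here reflects the same rounding: each run of ten 0-valued elements now comes out as batches of 3, 3, 3 plus a size-1 remainder batch that is kept only when drop_remainder is False, followed by the 1-valued batches of 5; `pairs * 2` covers the repetition of the input. A sketch of the size bookkeeping (illustrative):

```python
# With ceiling division, ten elements regroup as 3 + 3 + 3 + 1; the final
# size-1 batch is dropped when drop_remainder is True.
for drop_remainder in (False, True):
  sizes = [3, 3, 3]
  if not drop_remainder:
    sizes.append(1)
  assert sum(sizes) == (9 if drop_remainder else 10)
```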
@@ -20,7 +20,6 @@ from __future__ import print_function
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
@@ -77,14 +76,8 @@ class _RebatchDataset(dataset_ops.UnaryDataset):
         raise ValueError(
             "Input shape should have at least one dimension. "
             "Perhaps your input dataset is not batched?")
-      if (tensor_shape.dimension_value(output_shapes[0]) and
-          tensor_shape.dimension_value(output_shapes[0]) % num_workers != 0):
-        raise errors.InvalidArgumentError(
-            None, None,
-            "First dim of input shape: %d is not divisible by num_workers: %d" %
-            (output_shapes[0], num_workers))
       output_dims = [d for d in output_shapes.dims]
-      output_dims[0] = output_dims[0] // num_workers
+      output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
       return tensor_shape.TensorShape(output_dims)

     input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
......
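
On the Python side, _RebatchDataset's static shape inference now applies the same ceiling division to the leading (batch) dimension instead of raising for non-divisible sizes. A minimal sketch of the effect (illustrative; the element shape and worker count are made up):

```python
from tensorflow.python.framework import tensor_shape

# Leading dimension 32 split across 5 workers: ceil(32 / 5) == 7, no error.
output_dims = list(tensor_shape.TensorShape([32, 224, 224, 3]).dims)
num_workers = 5
output_dims[0] = (output_dims[0] + num_workers - 1) // num_workers
print(tensor_shape.TensorShape(output_dims))  # (7, 224, 224, 3)
```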