提交 f8d80a78 编写于 作者: Y Yunxing Dai 提交者: TensorFlower Gardener

Fix segment_reduction to support dynamic dims correctly.

Previously it just ignores dynamic dimension.

PiperOrigin-RevId: 327861140
Change-Id: Icfe9a6293cc28ca2b811b1810e790f4c62e1e4a3
上级 e98f54f4
......@@ -74,12 +74,44 @@ class UnsortedSegmentReduce : public XlaOpKernel {
" vs. ", indices_shape.dim_size(d)));
}
xla::XlaBuilder* builder = ctx->builder();
// data shape = [indices_shape, segment_shape]
// buffer shape = [num_segment, segment_shape]
// We now create the buffer shape by reverse enginerring data shape into
// indices shape and segment shape.
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer =
xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
// Build dynamic dim sizes for buffer, as well as whether each dimension
// size is dynamic or static. We build two parts: num_sgement part and
// segment_shape part.
std::vector<xla::XlaOp> buffer_dims;
std::vector<bool> buffer_dims_are_dynamic;
// Build the "num_segment" part.
bool num_segments_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));
buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
num_segments_is_dynamic);
// Build the segment shape part.
for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
buffer_dims.push_back(xla::GetDimensionSize(data, i));
buffer_dims_are_dynamic.push_back(
ctx->InputXlaShape(0)->is_dynamic_dimension(i));
}
for (int64 i = 0; i < buffer_dims.size(); ++i) {
if (buffer_dims_are_dynamic[i]) {
// For each dynamic dimension, call set-dimension-size on it.
buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
}
}
auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
xla::XlaBuilder* builder) { return Combine(a, b); };
......
......@@ -116,12 +116,44 @@ class UnsortedSegmentSum : public XlaOpKernel {
indices_shape.dim_size(d)));
}
xla::XlaBuilder* builder = ctx->builder();
// data shape = [indices_shape, segment_shape]
// buffer shape = [num_segment, segment_shape]
// We now create the buffer shape by reverse enginerring data shape into
// indices shape and segment shape.
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype_),
buffer_shape.dim_sizes());
// Build dynamic dim sizes for buffer, as well as whether each dimension
// size is dynamic or static. We build two parts: num_sgement part and
// segment_shape part.
std::vector<xla::XlaOp> buffer_dims;
std::vector<bool> buffer_dims_are_dynamic;
// Build the "num_segment" part.
bool num_segments_is_dynamic;
OP_REQUIRES_OK(
ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic));
buffer_dims.insert(buffer_dims.begin(), ctx->Input(2));
buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(),
num_segments_is_dynamic);
// Build the segment shape part.
for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) {
buffer_dims.push_back(xla::GetDimensionSize(data, i));
buffer_dims_are_dynamic.push_back(
ctx->InputXlaShape(0)->is_dynamic_dimension(i));
}
for (int64 i = 0; i < buffer_dims.size(); ++i) {
if (buffer_dims_are_dynamic[i]) {
// For each dynamic dimension, call set-dimension-size on it.
buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i);
}
}
auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
return a + b;
};
......
......@@ -632,6 +632,34 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
# This assumes that there are exactly 2 replicas
self.assertAllEqual([2, 1], run(next(input_iterator)))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
mode=["eager"]))
def testSegmentSumWithDynamicNumberOfSegments(self, distribution):
def dataset_fn(_):
data = array_ops.zeros(5, dtype=dtypes.int32)
dataset = get_dataset_from_tensor_slices(data)
dataset = dataset.batch(3)
return dataset
input_iterator = iter(
distribution.experimental_distribute_datasets_from_function(dataset_fn))
@def_function.function
def step_fn(example):
segment_ids = array_ops.zeros_like_v2(example)
num_segment = array_ops.shape(example)[0]
# If number of segments is dynamic, output should be a dynamic shape.
return math_ops.unsorted_segment_sum(example, segment_ids, num_segment)
# This assumes that there are exactly 2 replicas
outputs = distribution.experimental_local_results(
distribution.run(step_fn, args=(next(input_iterator),)))
self.assertAllEqual((3,), outputs[0].shape)
self.assertAllEqual((2,), outputs[1].shape)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.multidevice_strategies,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册