提交 13ebf40d 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[XLA:SPMD] Support partitioning DynamicUpdateSlice along the slice dimension in some special cases.

PiperOrigin-RevId: 339499575
Change-Id: I39af2243447a5c7228929337c78ae59148d6739c
上级 63a74434
......@@ -2400,15 +2400,141 @@ Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(HloInstruction* hlo) {
if (hlo->sharding().IsTileMaximal()) {
return DefaultAction(hlo);
}
std::vector<int64> partitioned_slice_dims;
std::vector<int64> slice_dims;
std::vector<int64> partitioned_non_slice_dims;
std::vector<int64> partitioned_slice_offsets;
for (int64 i = 0; i < hlo->shape().rank(); ++i) {
if (hlo->sharding().tile_assignment().dim(i) != 1 &&
(hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i) ||
!hlo->operand(i + 2)->IsConstant() ||
!hlo->operand(i + 2)->literal().IsZero({}))) {
// We currently do not partition the sliced dimensions.
return DefaultAction(hlo);
}
if (hlo->operand(1)->shape().dimensions(i) != hlo->shape().dimensions(i)) {
slice_dims.push_back(i);
if (hlo->sharding().tile_assignment().dim(i) != 1) {
if (!hlo->operand(i + 2)->IsConstant()) {
return DefaultAction(hlo);
}
partitioned_slice_dims.push_back(i);
partitioned_slice_offsets.push_back(
hlo->operand(i + 2)->literal().Get<int>({}));
}
} else if (hlo->sharding().tile_assignment().dim(i) != 1) {
if (!hlo->operand(i + 2)->IsConstant() ||
!hlo->operand(i + 2)->literal().IsZero({})) {
return DefaultAction(hlo);
}
partitioned_non_slice_dims.push_back(i);
}
}
// Handle when there is slice dim partitioned.
if (!partitioned_slice_dims.empty()) {
auto add_hlo = [&](std::unique_ptr<HloInstruction> to_add) {
return b_.AddInstruction(std::move(to_add));
};
std::vector<HloInstruction*> new_indices(hlo->shape().rank());
for (int64 i = 0; i < new_indices.size(); ++i) {
// Replicate the indices.
new_indices[i] = GetPartitionedHlo(hlo->operand(i + 2))
.Reshard(HloSharding::Replicate())
.hlo();
}
// Get partitioned input.
const auto& dus_sharding = hlo->sharding();
const auto& partitioned_input =
GetPartitionedHlo(hlo->operand(0)).Reshard(dus_sharding).hlo();
// Get replicate update.
auto update_sharding = HloSharding::Replicate();
if (!partitioned_non_slice_dims.empty()) {
// Do partial replicate for update if non slice dims are partitioned.
update_sharding =
hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(dus_sharding,
slice_dims);
}
HloInstruction* replicate_update =
GetPartitionedHlo(hlo->operand(1)).Reshard(update_sharding).hlo();
const auto& update_shape = replicate_update->shape();
const auto& partitioned_shape = partitioned_input->shape();
auto partition_ordinals =
MakeTiledPartitionOrdinals(hlo->sharding(), partition_id_, &b_);
HloInstruction* all_dims_within_partition = add_hlo(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
for (int i = 0; i < partitioned_slice_dims.size(); ++i) {
int dim = partitioned_slice_dims[i];
// Calculate per partition size.
const int64 per_partition_size = partitioned_shape.dimensions(dim);
// Only update within a single partition is supported.
if ((partitioned_slice_offsets[i] / per_partition_size) !=
((partitioned_slice_offsets[i] + update_shape.dimensions(dim) - 1) /
per_partition_size)) {
return DefaultAction(hlo);
}
// within_partition = (offset >= partition_id * per_partition_size) &&
// (offset < (partition_id + 1) * per_partition_size)
const Shape& compare_shape =
ShapeUtil::ChangeElementType(partition_id_->shape(), PRED);
auto per_partition_size_hlo = add_hlo(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int>(per_partition_size)));
const Shape& offset_shape = per_partition_size_hlo->shape();
auto partition_offset = add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kMultiply, partition_ordinals[dim],
per_partition_size_hlo));
// offset >= partition_id * per_partition_size
auto offset_ge = add_hlo(HloInstruction::CreateCompare(
compare_shape, new_indices[dim], partition_offset,
ComparisonDirection::kGe));
// offset < (partition_id + 1) * per_partition_size
auto offset_lt = add_hlo(HloInstruction::CreateCompare(
compare_shape, new_indices[dim],
add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kMultiply,
add_hlo(HloInstruction::CreateBinary(
offset_shape, HloOpcode::kAdd, partition_ordinals[dim],
add_hlo(HloInstruction::CreateConstant(
LiteralUtil::CreateR0<int>(1))))),
per_partition_size_hlo)),
ComparisonDirection::kLt));
auto update_within_partition = add_hlo(HloInstruction::CreateBinary(
compare_shape, HloOpcode::kAnd, offset_ge, offset_lt));
all_dims_within_partition = add_hlo(HloInstruction::CreateBinary(
compare_shape, HloOpcode::kAnd, all_dims_within_partition,
update_within_partition));
// Calculate offset.
// slice dim offset =
// within_partition ?
// offset - partition_id * per_partition_size : 0
new_indices[dim] = add_hlo(HloInstruction::CreateTernary(
new_indices[dim]->shape(), HloOpcode::kSelect,
update_within_partition,
add_hlo(HloInstruction::CreateBinary(
new_indices[dim]->shape(), HloOpcode::kSubtract, new_indices[dim],
partition_offset)),
add_hlo(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)))));
}
// Create dynamic update slice.
auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
partitioned_shape, partitioned_input, replicate_update, new_indices));
SetPartitionedHlo(hlo, [&]() {
// Select if update is needed.
return add_hlo(HloInstruction::CreateTernary(
dus->shape(), HloOpcode::kSelect,
add_hlo(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(dus->shape(), PRED),
all_dims_within_partition, {})),
dus, partitioned_input));
});
return Status::OK();
}
// Partition non slice dims only.
std::vector<HloInstruction*> new_indices(hlo->shape().rank());
auto new_input =
GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
......
......@@ -4196,6 +4196,84 @@ ENTRY entry {
op::Shape("s32[64,64]")));
}
TEST_F(SpmdPartitioningTest, DynamicUpdateSliceAlongPartitionedDimension) {
  // A DUS whose sliced dimension (dim 1, constant offset 60, extent 2) is
  // tiled across 2 devices. The partitioner should guard the update with a
  // per-shard predicate and rewrite the offset to a shard-local value.
  const char* const kHloText = R"(
HloModule module
ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[1,2]0,1}
%index = s32[] parameter(1)
%constant = s32[] constant(60)
%update = s32[128,2] parameter(2)
%update.copy = s32[128,2] copy(%update), sharding={devices=[1,2]0,1}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %index, %constant),
sharding={devices=[1,2]0,1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(kHloText, /*num_devices=*/2));
  VLOG(1) << partitioned_module->ToString();

  // Each shard holds a [128,32] half of the input.
  const auto sharded_input =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                      op::Reshape())),
            op::Shape("s32[128,32]"));
  // The update is re-assembled via all-reduce so every shard sees all of it.
  const auto full_update =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(),
                op::Copy(op::DynamicSlice(op::Parameter(2), op::Constant(),
                                          op::Reshape())),
                op::Constant(), op::Reshape())),
            op::Shape("s32[128,2]"));
  // Root: DUS with a select-rewritten offset, guarded by a predicate select.
  auto* root_instr =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root_instr,
              AllOf(op::Select(op::Broadcast(),
                               op::DynamicUpdateSlice(sharded_input,
                                                      full_update,
                                                      op::Parameter(1),
                                                      op::Select()),
                               sharded_input),
                    op::Shape("s32[128,32]")));
}
TEST_F(SpmdPartitioningTest, DynamicUpdateSlicePartitionSliceAndNonSliceDims) {
  // Both dims of the DUS output are tiled (2x2 over 4 devices): dim 0 is
  // fully covered by the update (zero offset), dim 1 is a sliced dim with a
  // constant offset. Exercises the mixed slice/non-slice partitioning path.
  const char* const kHloText = R"(
HloModule module
ENTRY entry {
%input = s32[128,64] parameter(0)
%input.copy = s32[128,64] copy(%input), sharding={devices=[2,2]0,1,2,3}
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(60)
%update = s32[128,2] parameter(1)
%update.copy = s32[128,2] copy(%update), sharding={devices=[2,2]0,1,2,3}
ROOT %dynamic-update-slice = s32[128,64]
dynamic-update-slice(%input.copy, %update.copy, %constant.0, %constant.1),
sharding={devices=[2,2]0,1,2,3}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto partitioned_module,
                          PartitionComputation(kHloText, /*num_devices=*/4));
  VLOG(1) << partitioned_module->ToString();

  // Each shard holds a [64,32] quarter of the input.
  const auto sharded_input =
      AllOf(op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(),
                                      op::Reshape())),
            op::Shape("s32[64,2]").operator op::Shape("s32[64,32]"));
  // The update stays tiled on dim 0 and is all-reduced along the slice dim,
  // so each shard sees its [64,2] stripe of the full update.
  const auto stripe_update =
      AllOf(op::AllReduce(op::DynamicUpdateSlice(
                op::Broadcast(),
                op::Copy(op::DynamicSlice(op::Parameter(1), op::Reshape(),
                                          op::Reshape())),
                op::Constant(), op::Reshape())),
            op::Shape("s32[64,2]"));
  auto* root_instr =
      partitioned_module->entry_computation()->root_instruction();
  EXPECT_THAT(root_instr,
              AllOf(op::Select(op::Broadcast(),
                               op::DynamicUpdateSlice(sharded_input,
                                                      stripe_update,
                                                      op::Constant(),
                                                      op::Select()),
                               sharded_input),
                    op::Shape("s32[64,32]")));
}
TEST_F(SpmdPartitioningTest, PassthroughGather) {
const char* const hlo_string = R"(
HloModule module
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册