Commit be0d1151 authored by Marcello Maggioni, committed by TensorFlower Gardener

[XLA] Rework dot() sharding propagation to look ahead at user shardings and choose a sharding for dot() that agrees with the users when possible.

PiperOrigin-RevId: 565086052
Parent d11423a4
@@ -96,33 +96,52 @@ bool IsShardingStrictlyBetter(const HloSharding& lhs, const HloSharding& rhs) {
return false;
}
// Updates the sharding of the specified instruction with the specified sharding
// if it is better than the current one, and returns true if a new sharding has
// been applied. If may_combine_partial_sharding is true, this may combine the
// new and existing shardings if they are both partially tiled with partial
// replication.
bool MaybeImproveInstructionSharding(HloSharding sharding,
HloInstruction* instruction,
bool may_combine_partial_sharding,
bool allow_aggressive_resharding = false) {
std::optional<HloSharding> ReturnImprovedSharding(
HloSharding sharding, HloInstruction* instruction,
bool may_combine_partial_sharding,
bool allow_aggressive_resharding = false) {
// Always allow improving the sharding if it is strictly better.
if (instruction->has_sharding() &&
IsShardingStrictlyBetter(sharding, instruction->sharding())) {
instruction->set_sharding(sharding);
return true;
return sharding;
}
// Allow improving from tile-maximal shardings to manual shardings.
if (instruction->has_sharding()) {
bool no_worse = true;
bool changed = false;
const std::vector<HloSharding>& flattened_instruction_shardings =
instruction->sharding().tuple_elements();
const std::vector<HloSharding>& flatten_shardings =
sharding.tuple_elements();
CHECK_EQ(flattened_instruction_shardings.size(), flatten_shardings.size());
for (int i = 0; i != flattened_instruction_shardings.size(); ++i) {
if (flattened_instruction_shardings[i] != flatten_shardings[i]) {
changed = true;
if (!flattened_instruction_shardings[i].IsTileMaximal() ||
!flatten_shardings[i].IsManual()) {
no_worse = false;
break;
}
}
}
// Replace the sharding if we know that it strictly improves (i.e. the
// sharding is changed and no worse than before) from tile-maximal
// (sub)shardings to manual shardings. Otherwise pass through.
if (no_worse && changed) {
return sharding;
}
}
// We don't want to propagate tile maximal shardings.
if (!IsSpatiallyPartitioned(sharding)) {
return false;
return std::nullopt;
}
// Any sharding is better than no sharding.
if (!instruction->has_sharding()) {
instruction->set_sharding(std::move(sharding));
return true;
return sharding;
}
// We don't want to propagate manual shardings.
if (sharding.IsManual()) {
return false;
return std::nullopt;
}
int64_t sharding_tiles = sharding.NumTiles();
if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
@@ -138,10 +157,27 @@ bool MaybeImproveInstructionSharding(HloSharding sharding,
VLOG(10) << "Not merging because of different device distribution";
VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
VLOG(10) << "New sharding " << sharding.ToString();
return false;
return std::nullopt;
}
}
instruction->set_sharding(std::move(sharding));
return sharding;
}
return std::nullopt;
}
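To make the tuple-element rule above concrete, here is a minimal, self-contained sketch (Kind is an illustrative toy enum, not an XLA type) of accepting a new sharding only when every changed element moves from tile-maximal to manual:

#include <cstddef>
#include <vector>

enum class Kind { kTileMaximal, kManual, kTiled };

// Accept the proposed per-element kinds only if at least one element changed
// and every changed element goes from tile-maximal to manual.
bool IsStrictManualUpgrade(const std::vector<Kind>& current,
                           const std::vector<Kind>& proposed) {
  bool changed = false;
  for (std::size_t i = 0; i < current.size(); ++i) {
    if (current[i] == proposed[i]) continue;
    changed = true;
    if (current[i] != Kind::kTileMaximal || proposed[i] != Kind::kManual) {
      return false;  // a changed element is not a tile-maximal -> manual move
    }
  }
  return changed;  // mirrors the `no_worse && changed` condition above
}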
// Updates the sharding of the specified instruction with the specified sharding
// if it is better than the current one, and returns true if a new sharding has
// been applied. If may_combine_partial_sharding is true, this may combine the
// new and existing shardings if they are both partially tiled with partial
// replication.
bool MaybeImproveInstructionSharding(HloSharding sharding,
HloInstruction* instruction,
bool may_combine_partial_sharding,
bool allow_aggressive_resharding = false) {
if (auto new_sharding = ReturnImprovedSharding(sharding, instruction,
may_combine_partial_sharding,
allow_aggressive_resharding)) {
instruction->set_sharding(std::move(*new_sharding));
return true;
}
return false;
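The shape of this refactor, a compute step returning std::optional plus a thin bool wrapper that commits the result, is sketched below with a hypothetical Widget type (illustrative only, not XLA code):

#include <iostream>
#include <optional>
#include <utility>

struct Widget {
  int quality = 0;
};

// Analogue of ReturnImprovedSharding: propose a replacement only if it is an
// improvement; never mutate the current value.
std::optional<Widget> ReturnImprovedWidget(Widget candidate,
                                           const Widget& current) {
  if (candidate.quality > current.quality) {
    return candidate;
  }
  return std::nullopt;
}

// Analogue of MaybeImproveInstructionSharding: commit the proposal if there
// is one and report whether anything changed.
bool MaybeImproveWidget(Widget candidate, Widget* current) {
  if (auto improved = ReturnImprovedWidget(std::move(candidate), *current)) {
    *current = std::move(*improved);
    return true;
  }
  return false;
}

int main() {
  Widget w{1};
  std::cout << MaybeImproveWidget(Widget{2}, &w) << " " << w.quality << "\n";  // 1 2
  std::cout << MaybeImproveWidget(Widget{0}, &w) << " " << w.quality << "\n";  // 0 2
}

Splitting the API this way is what lets InferDotShardingFromOperands below inspect both operand-derived candidates before deciding which one to commit.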
@@ -415,10 +451,58 @@ bool SupportSpatialPartitioning(
}
}
// Helper that looks ahead at the shardings of an instruction's users, to be
// used as guidance for ambiguous cases.
std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
bool is_spmd,
const CallGraph& call_graph) {
if (instr->user_count() != 1) {
return std::nullopt;
}
HloInstruction* current_user = instr->users()[0];
std::optional<HloSharding> sharding;
std::vector<HloInstruction*> users_chain = {instr, current_user};
// Collect the chain of single-user instructions along the way.
while (!current_user->has_sharding()) {
// Only consider single-user chains.
if (current_user->users().size() != 1) {
users_chain.clear();
break;
}
current_user = current_user->users()[0];
users_chain.push_back(current_user);
}
// Early exit for unsupported cases.
if (users_chain.empty()) {
return std::nullopt;
}
for (int i = users_chain.size() - 1; i >= 1; --i) {
HloInstruction* user = users_chain[i];
HloInstruction* current = users_chain[i - 1];
CHECK(user->has_sharding());
sharding = ShardingPropagation::GetShardingFromUser(
*current, *user, INT64_MAX, is_spmd, call_graph);
// We need to set the sharding on the instruction, because the
// GetShardingFromUser() interface reads the sharding from the instruction
// itself. It will be cleared out later.
if (sharding.has_value() && i != 1) {
current->set_sharding(*sharding);
continue;
}
break;
}
// Clear the shardings we set on the middle instructions, since they were
// originally unsharded.
for (int i = 1; i < users_chain.size() - 1; ++i) {
users_chain[i]->clear_sharding();
}
return sharding;
}
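For intuition, the chain walk this helper performs reduces to the sketch below (Node is a toy stand-in for HloInstruction; the real helper additionally maps the sharding found back through each intermediate instruction with GetShardingFromUser):

#include <optional>
#include <string>
#include <vector>

struct Node {
  std::vector<Node*> users;
  std::optional<std::string> sharding;  // stand-in for HloSharding
};

// Walk forward while the user chain stays single-user; return the sharding of
// the first already-sharded user, or nullopt if the chain branches (or ends)
// before reaching one.
std::optional<std::string> LookaheadChainSharding(Node* node) {
  if (node->users.size() != 1) {
    return std::nullopt;
  }
  Node* current = node->users[0];
  while (!current->sharding.has_value()) {
    if (current->users.size() != 1) {
      return std::nullopt;  // covers both branching and dead ends
    }
    current = current->users[0];
  }
  return current->sharding;
}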
bool InferDotShardingFromOperands(
HloInstruction* instruction,
HloInstruction* instruction, const CallGraph& call_graph,
const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
bool may_combine_partial_sharding) {
bool may_combine_partial_sharding, bool is_spmd) {
auto from_operand = [&](int64_t operand_index) {
auto operand = instruction->operand(operand_index);
const HloSharding& operand_sharding = operand->sharding();
@@ -461,23 +545,73 @@ bool InferDotShardingFromOperands(
replicate_contracting_dims, op_dims_to_output_perm,
out_dims_to_op_perm);
};
bool changed = false;
int64_t larger_operand =
ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >=
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())
? 0
: 1;
if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) {
changed |= MaybeImproveInstructionSharding(from_operand(larger_operand),
instruction,
may_combine_partial_sharding);
std::optional<HloSharding> improved_operand_0;
std::optional<HloSharding> improved_operand_1;
if (IsSpatiallyPartitioned(instruction->operand(0))) {
improved_operand_0 = ReturnImprovedSharding(
from_operand(0), instruction, may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false);
}
if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) {
changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand),
instruction,
may_combine_partial_sharding);
if (IsSpatiallyPartitioned(instruction->operand(1))) {
improved_operand_1 = ReturnImprovedSharding(
from_operand(1), instruction, may_combine_partial_sharding,
/*allow_aggressive_resharding=*/false);
}
return changed;
// If no improved sharding was found, do not set any sharding.
if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
return false;
}
// Sharding found from operand 0 but not from operand 1. Set the sharding from
// operand 0.
if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_0);
return true;
}
// Sharding found from operand 1 but not from operand 0. Set the sharding from
// operand 1.
if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
instruction->set_sharding(*improved_operand_1);
return true;
}
CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
std::optional<HloSharding> lookahead_sharding =
LookaheadUserSharding(instruction, is_spmd, call_graph);
std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
*improved_operand_1};
bool priority_defined_with_lookahead = false;
// Found sharding from lookahead.
if (lookahead_sharding.has_value()) {
const bool operand_0_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_0);
const bool operand_1_is_lookahead_subtiling =
hlo_sharding_util::IsSubTilingOrEqualSharding(
instruction->shape(), *lookahead_sharding, *improved_operand_1);
// If the sharding from operand 0 is a subtiling of the user's sharding, but
// the one from operand 1 is not, prioritize operand 0's sharding.
if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
priority_defined_with_lookahead = true;
}
// If the sharding from operand 1 is a subtiling of the user's sharding, but
// the one from operand 0 is not, prioritize operand 1's sharding.
if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
instruction->set_sharding(*improved_operand_1);
std::swap(sharding_priority[0], sharding_priority[1]);
priority_defined_with_lookahead = true;
}
}
// If the lookahead didn't define a priority, fall back to operand size.
if (!priority_defined_with_lookahead &&
ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
std::swap(sharding_priority[0], sharding_priority[1]);
}
// Set the primary sharding on the instruction and then try to improve it
// with the secondary sharding.
instruction->set_sharding(sharding_priority[0]);
MaybeImproveInstructionSharding(sharding_priority[1], instruction,
may_combine_partial_sharding);
return true;
}
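The tie-breaking at the end of the function is a small policy on its own: when the user lookahead makes exactly one operand-derived sharding compatible, that one wins; otherwise the candidate from the larger operand wins, which was the previous behavior. A toy sketch of just that policy (Candidate and the boolean are illustrative stand-ins for HloSharding and IsSubTilingOrEqualSharding):

#include <array>
#include <cstdint>
#include <utility>

struct Candidate {
  int64_t operand_bytes;   // size of the operand the sharding came from
  bool matches_lookahead;  // stand-in for IsSubTilingOrEqualSharding
};

// Order the two candidates as {primary, secondary}: lookahead compatibility
// decides when exactly one candidate matches; otherwise the larger operand's
// candidate goes first.
std::array<Candidate, 2> Prioritize(Candidate c0, Candidate c1,
                                    bool have_lookahead) {
  const bool lookahead_decides =
      have_lookahead && (c0.matches_lookahead != c1.matches_lookahead);
  if (lookahead_decides) {
    if (c1.matches_lookahead) std::swap(c0, c1);
  } else if (c0.operand_bytes < c1.operand_bytes) {
    std::swap(c0, c1);
  }
  return {c0, c1};
}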
// Infer output sharding on index parallel dimensions for gather/scatter from
@@ -610,8 +744,10 @@ bool InferScatterParallelShardingFromOperands(
// Convolution handling for InferShardingFromOperands().
bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
const CallGraph& call_graph,
int64_t aggressiveness,
bool may_combine_partial_sharding) {
bool may_combine_partial_sharding,
bool is_spmd) {
auto get_partitions_for_dims =
[&](const HloInstruction* inst,
absl::Span<
@@ -646,8 +782,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
(lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
instruction->batch_group_count() == 1 &&
instruction->feature_group_count() == 1)) {
return InferDotShardingFromOperands(instruction, dot_dims,
may_combine_partial_sharding);
return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
may_combine_partial_sharding, is_spmd);
}
const auto& dnums = instruction->convolution_dimension_numbers();
const HloInstruction* lhs = instruction->operand(0);
@@ -2186,8 +2322,9 @@ bool ShardingPropagation::InferShardingFromOperands(
1);
}
case HloOpcode::kConvolution:
return InferConvolutionShardingFromOperands(instruction, aggressiveness,
may_combine_partial_sharding);
return InferConvolutionShardingFromOperands(
instruction, call_graph, aggressiveness, may_combine_partial_sharding,
is_spmd_);
case HloOpcode::kTranspose: {
const HloInstruction* input = instruction->operand(0);
if (!IsSpatiallyPartitioned(input)) {
@@ -2276,8 +2413,9 @@ bool ShardingPropagation::InferShardingFromOperands(
case HloOpcode::kDot: {
const auto& dnums =
dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
return InferDotShardingFromOperands(instruction, dnums,
may_combine_partial_sharding);
return InferDotShardingFromOperands(instruction, call_graph, dnums,
may_combine_partial_sharding,
is_spmd_);
}
case HloOpcode::kParameter: {
auto parent_it = computation_map.find(instruction->parent());
@@ -9549,5 +9549,40 @@ ENTRY %entry {
EXPECT_EQ(add_1->sharding(), output->sharding());
}
TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) {
const char* const hlo_string = R"(
HloModule module
ENTRY %entry {
p0 = bf16[512,512,1024]{2,1,0} parameter(0), sharding={devices=[16,1,4]<=[64]}
p1 = bf16[512,512,16,128]{3,2,1,0} parameter(1), sharding={devices=[16,1,4,1]<=[64]}
p2 = bf16[16,1024,16,128]{3,2,1,0} parameter(2), sharding={devices=[1,4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}
p3 = s32[] parameter(3)
dot.1 = bf16[1024,16,128]{2,1,0} dot(p0, p1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
reshape.1 = bf16[1,1024,16,128]{3,2,1,0} reshape(dot.1)
constant.1 = s32[] constant(0)
ROOT dynamic-update-slice.113 = bf16[16,1024,16,128]{3,2,1,0} dynamic-update-slice(p2, reshape.1, p3, constant.1, constant.1, /*index=5*/constant.1), sharding={devices=[1,4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}
})";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
ShardingPropagation(
/*is_spmd=*/true, /*propagate_metadata=*/true,
/*allow_spmd_sharding_propagation_to_output=*/{true},
/*allow_spmd_sharding_propagation_to_parameters=*/{true})
.Run(module.get()));
EXPECT_TRUE(changed);
XLA_VLOG_LINES(1, module->ToString());
auto* instruction = FindInstruction(module.get(), "dot.1");
// Check that the sharding is correctly propagated to the dot via the user
// lookahead.
EXPECT_THAT(instruction,
op::Sharding(
"{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}"));
}
} // namespace
} // namespace xla