From be0d1151a6a3352206a435c1c14afa2b1f4c5dc4 Mon Sep 17 00:00:00 2001
From: Marcello Maggioni
Date: Wed, 13 Sep 2023 10:11:54 -0700
Subject: [PATCH] [XLA] Rework dot() sharding propagation to look ahead at the
 users' shardings and choose a sharding for dot() that agrees with the users
 if possible.

PiperOrigin-RevId: 565086052
---
 .../xla/xla/service/sharding_propagation.cc   | 220 ++++++++++++++----
 .../xla/service/sharding_propagation_test.cc  |  35 +++
 2 files changed, 214 insertions(+), 41 deletions(-)

diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc
index 608421a434b..5737ed0c576 100644
--- a/third_party/xla/xla/service/sharding_propagation.cc
+++ b/third_party/xla/xla/service/sharding_propagation.cc
@@ -96,33 +96,52 @@ bool IsShardingStrictlyBetter(const HloSharding& lhs, const HloSharding& rhs) {
   return false;
 }
 
-// Updates the sharding of the specified instruction with the specified
-// sharding if it is better than the current one, and returns true if a new
-// sharding has been applied. If may_combine_partial_sharding is true, this
-// may combine the new and existing sharding if they are both partially tiled
-// with partial replication.
-bool MaybeImproveInstructionSharding(HloSharding sharding,
-                                     HloInstruction* instruction,
-                                     bool may_combine_partial_sharding,
-                                     bool allow_aggressive_resharding = false) {
+std::optional<HloSharding> ReturnImprovedSharding(
+    HloSharding sharding, HloInstruction* instruction,
+    bool may_combine_partial_sharding,
+    bool allow_aggressive_resharding = false) {
   // Always allow improving the sharding if it is strictly better.
   if (instruction->has_sharding() &&
       IsShardingStrictlyBetter(sharding, instruction->sharding())) {
-    instruction->set_sharding(sharding);
-    return true;
+    return sharding;
+  }
+  // Allow improving from tile-maximal shardings to manual shardings.
+  if (instruction->has_sharding()) {
+    bool no_worse = true;
+    bool changed = false;
+    const std::vector<HloSharding>& flattened_instruction_shardings =
+        instruction->sharding().tuple_elements();
+    const std::vector<HloSharding>& flatten_shardings =
+        sharding.tuple_elements();
+    CHECK_EQ(flattened_instruction_shardings.size(), flatten_shardings.size());
+    for (int i = 0; i != flattened_instruction_shardings.size(); ++i) {
+      if (flattened_instruction_shardings[i] != flatten_shardings[i]) {
+        changed = true;
+        if (!flattened_instruction_shardings[i].IsTileMaximal() ||
+            !flatten_shardings[i].IsManual()) {
+          no_worse = false;
+          break;
+        }
+      }
+    }
+    // Replace the sharding if we know that it strictly improves (i.e. the
+    // sharding has changed and is no worse than before) from tile-maximal
+    // (sub)shardings to manual shardings. Otherwise pass through.
+    if (no_worse && changed) {
+      return sharding;
+    }
   }
   // We don't want to propagate tile maximal shardings.
   if (!IsSpatiallyPartitioned(sharding)) {
-    return false;
+    return std::nullopt;
   }
   // Any sharding is better than no sharding.
   if (!instruction->has_sharding()) {
-    instruction->set_sharding(std::move(sharding));
-    return true;
+    return sharding;
   }
   // We don't want to propagate manual shardings.
   if (sharding.IsManual()) {
-    return false;
+    return std::nullopt;
   }
   int64_t sharding_tiles = sharding.NumTiles();
   if (hlo_sharding_util::MergeSharding(instruction->sharding(), &sharding,
@@ -138,10 +157,27 @@ bool MaybeImproveInstructionSharding(HloSharding sharding,
       VLOG(10) << "Not merging because of different device distribution";
       VLOG(10) << "Instr sharding: " << instruction->sharding().ToString();
       VLOG(10) << "New sharding " << sharding.ToString();
-      return false;
+      return std::nullopt;
     }
   }
-  instruction->set_sharding(std::move(sharding));
+    return sharding;
+  }
+  return std::nullopt;
+}
+
+// Updates the sharding of the specified instruction with the specified
+// sharding if it is better than the current one, and returns true if a new
+// sharding has been applied. If may_combine_partial_sharding is true, this
+// may combine the new and existing sharding if they are both partially tiled
+// with partial replication.
+bool MaybeImproveInstructionSharding(HloSharding sharding,
+                                     HloInstruction* instruction,
+                                     bool may_combine_partial_sharding,
+                                     bool allow_aggressive_resharding = false) {
+  if (auto new_sharding = ReturnImprovedSharding(sharding, instruction,
+                                                 may_combine_partial_sharding,
+                                                 allow_aggressive_resharding)) {
+    instruction->set_sharding(std::move(*new_sharding));
     return true;
   }
   return false;
@@ -415,10 +451,58 @@ bool SupportSpatialPartitioning(
   }
 }
 
+// Helper that looks ahead at the sharding of an instruction's user chain, to
+// be used as guidance for ambiguous cases.
+std::optional<HloSharding> LookaheadUserSharding(HloInstruction* instr,
+                                                 bool is_spmd,
+                                                 const CallGraph& call_graph) {
+  if (instr->user_count() != 1) {
+    return std::nullopt;
+  }
+  HloInstruction* current_user = instr->users()[0];
+  std::optional<HloSharding> sharding;
+  std::vector<HloInstruction*> users_chain = {instr, current_user};
+  // Collect single-user instructions along the way.
+  while (!current_user->has_sharding()) {
+    // Only consider single-user chains.
+    if (current_user->users().size() != 1) {
+      users_chain.clear();
+      break;
+    }
+    current_user = current_user->users()[0];
+    users_chain.push_back(current_user);
+  }
+  // Early exit for unsupported cases.
+  if (users_chain.empty()) {
+    return std::nullopt;
+  }
+  for (int i = users_chain.size() - 1; i >= 1; --i) {
+    HloInstruction* user = users_chain[i];
+    HloInstruction* current = users_chain[i - 1];
+    CHECK(user->has_sharding());
+    sharding = ShardingPropagation::GetShardingFromUser(
+        *current, *user, INT64_MAX, is_spmd, call_graph);
+    // We need to set the sharding on the instruction, because the
+    // GetShardingFromUser() interface reads the sharding from the instruction
+    // itself. It will be cleared out later.
+    if (sharding.has_value() && i != 1) {
+      current->set_sharding(*sharding);
+      continue;
+    }
+    break;
+  }
+  // Clear the shardings we temporarily set on the middle instructions,
+  // because they were originally unsharded.
+  for (int i = 1; i < users_chain.size() - 1; ++i) {
+    users_chain[i]->clear_sharding();
+  }
+  return sharding;
+}
+
 bool InferDotShardingFromOperands(
-    HloInstruction* instruction,
+    HloInstruction* instruction, const CallGraph& call_graph,
     const dot_as_convolution_util::DotConvolutionDimsInfo& dnums,
-    bool may_combine_partial_sharding) {
+    bool may_combine_partial_sharding, bool is_spmd) {
   auto from_operand = [&](int64_t operand_index) {
     auto operand = instruction->operand(operand_index);
     const HloSharding& operand_sharding = operand->sharding();
@@ -461,23 +545,73 @@ bool InferDotShardingFromOperands(
         replicate_contracting_dims, op_dims_to_output_perm,
         out_dims_to_op_perm);
   };
-  bool changed = false;
-  int64_t larger_operand =
-      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) >=
-              ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())
-          ? 0
-          : 1;
-  if (IsSpatiallyPartitioned(instruction->operand(larger_operand))) {
-    changed |= MaybeImproveInstructionSharding(from_operand(larger_operand),
-                                               instruction,
-                                               may_combine_partial_sharding);
+  std::optional<HloSharding> improved_operand_0;
+  std::optional<HloSharding> improved_operand_1;
+  if (IsSpatiallyPartitioned(instruction->operand(0))) {
+    improved_operand_0 = ReturnImprovedSharding(
+        from_operand(0), instruction, may_combine_partial_sharding,
+        /*allow_aggressive_resharding=*/false);
   }
-  if (IsSpatiallyPartitioned(instruction->operand(1 - larger_operand))) {
-    changed |= MaybeImproveInstructionSharding(from_operand(1 - larger_operand),
-                                               instruction,
-                                               may_combine_partial_sharding);
+  if (IsSpatiallyPartitioned(instruction->operand(1))) {
+    improved_operand_1 = ReturnImprovedSharding(
+        from_operand(1), instruction, may_combine_partial_sharding,
+        /*allow_aggressive_resharding=*/false);
   }
-  return changed;
+  // If no improved sharding was found, do not set any sharding.
+  if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) {
+    return false;
+  }
+  // Sharding found from operand 0 but not operand 1. Set the sharding from
+  // operand 0.
+  if (improved_operand_0.has_value() && !improved_operand_1.has_value()) {
+    instruction->set_sharding(*improved_operand_0);
+    return true;
+  }
+  // Sharding found from operand 1 but not operand 0. Set the sharding from
+  // operand 1.
+  if (!improved_operand_0.has_value() && improved_operand_1.has_value()) {
+    instruction->set_sharding(*improved_operand_1);
+    return true;
+  }
+  CHECK(improved_operand_0.has_value() && improved_operand_1.has_value());
+  std::optional<HloSharding> lookahead_sharding =
+      LookaheadUserSharding(instruction, is_spmd, call_graph);
+  std::array<HloSharding, 2> sharding_priority = {*improved_operand_0,
+                                                  *improved_operand_1};
+  bool priority_defined_with_lookahead = false;
+  // Found a sharding from the lookahead.
+  if (lookahead_sharding.has_value()) {
+    const bool operand_0_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_0);
+    const bool operand_1_is_lookahead_subtiling =
+        hlo_sharding_util::IsSubTilingOrEqualSharding(
+            instruction->shape(), *lookahead_sharding, *improved_operand_1);
+    // If the sharding from operand 0 is a subtiling of the user's sharding,
+    // but the one from operand 1 is not, prioritize operand 0's sharding.
+    if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) {
+      priority_defined_with_lookahead = true;
+    }
+    // If the sharding from operand 1 is a subtiling of the user's sharding,
+    // but the one from operand 0 is not, prioritize operand 1's sharding.
+    if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) {
+      instruction->set_sharding(*improved_operand_1);
+      std::swap(sharding_priority[0], sharding_priority[1]);
+      priority_defined_with_lookahead = true;
+    }
+  }
+  // If the lookahead didn't define a priority, then use operand size.
+  if (!priority_defined_with_lookahead &&
+      ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) <
+          ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) {
+    std::swap(sharding_priority[0], sharding_priority[1]);
+  }
+  // Set the primary sharding on the instruction and then try to improve it
+  // with the secondary sharding.
+  instruction->set_sharding(sharding_priority[0]);
+  MaybeImproveInstructionSharding(sharding_priority[1], instruction,
+                                  may_combine_partial_sharding);
+  return true;
 }
 
 // Infer output sharding on index parallel dimensions for gather/scatter from
@@ -610,8 +744,10 @@ bool InferScatterParallelShardingFromOperands(
 
 // Convolution handling for InferShardingFromOperands().
 bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
+                                          const CallGraph& call_graph,
                                           int64_t aggressiveness,
-                                          bool may_combine_partial_sharding) {
+                                          bool may_combine_partial_sharding,
+                                          bool is_spmd) {
   auto get_partitions_for_dims =
       [&](const HloInstruction* inst,
           absl::Span<
@@ -646,8 +782,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction,
       (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 &&
        instruction->batch_group_count() == 1 &&
        instruction->feature_group_count() == 1)) {
-    return InferDotShardingFromOperands(instruction, dot_dims,
-                                        may_combine_partial_sharding);
+    return InferDotShardingFromOperands(instruction, call_graph, dot_dims,
+                                        may_combine_partial_sharding, is_spmd);
   }
   const auto& dnums = instruction->convolution_dimension_numbers();
   const HloInstruction* lhs = instruction->operand(0);
@@ -2186,8 +2322,9 @@ bool ShardingPropagation::InferShardingFromOperands(
                 1);
       }
     case HloOpcode::kConvolution:
-      return InferConvolutionShardingFromOperands(instruction, aggressiveness,
-                                                  may_combine_partial_sharding);
+      return InferConvolutionShardingFromOperands(
+          instruction, call_graph, aggressiveness, may_combine_partial_sharding,
+          is_spmd_);
     case HloOpcode::kTranspose: {
       const HloInstruction* input = instruction->operand(0);
       if (!IsSpatiallyPartitioned(input)) {
@@ -2276,8 +2413,9 @@ bool ShardingPropagation::InferShardingFromOperands(
     case HloOpcode::kDot: {
      const auto& dnums =
          dot_as_convolution_util::ParseDotGeneralFromDot(instruction);
-      return InferDotShardingFromOperands(instruction, dnums,
-                                          may_combine_partial_sharding);
+      return InferDotShardingFromOperands(instruction, call_graph, dnums,
+                                          may_combine_partial_sharding,
+                                          is_spmd_);
     }
     case HloOpcode::kParameter: {
       auto parent_it = computation_map.find(instruction->parent());
diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc
index e46e4207337..7bea7ac1e48 100644
--- a/third_party/xla/xla/service/sharding_propagation_test.cc
+++ b/third_party/xla/xla/service/sharding_propagation_test.cc
@@ -9549,5 +9549,40 @@ ENTRY %entry {
   EXPECT_EQ(add_1->sharding(), output->sharding());
 }
 
+TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY %entry {
+  p0 = bf16[512,512,1024]{2,1,0} parameter(0), sharding={devices=[16,1,4]<=[64]}
+  p1 = bf16[512,512,16,128]{3,2,1,0} parameter(1), sharding={devices=[16,1,4,1]<=[64]}
+  p2 = bf16[16,1024,16,128]{3,2,1,0} parameter(2), sharding={devices=[1,4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}
+  p3 = s32[] parameter(3)
+  dot.1 = bf16[1024,16,128]{2,1,0} dot(p0, p1), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1}
+  reshape.1 = bf16[1,1024,16,128]{3,2,1,0} reshape(dot.1)
+  constant.1 = s32[] constant(0)
+  ROOT dynamic-update-slice.113 = bf16[16,1024,16,128]{3,2,1,0} dynamic-update-slice(p2, reshape.1, p3, constant.1, constant.1, /*index=5*/constant.1), sharding={devices=[1,4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      ShardingPropagation(
+          /*is_spmd=*/true, /*propagate_metadata=*/true,
+          /*allow_spmd_sharding_propagation_to_output=*/{true},
+          /*allow_spmd_sharding_propagation_to_parameters=*/{true})
+          .Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  XLA_VLOG_LINES(1, module->ToString());
+  auto* instruction = FindInstruction(module.get(), "dot.1");
+  // Check that the sharding was correctly propagated to the dot.
+  EXPECT_THAT(instruction,
+              op::Sharding(
+                  "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}"));
+}
+
 }  // namespace
 }  // namespace xla
-- 
GitLab
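
The first hunk splits the old mutate-and-report MaybeImproveInstructionSharding() into a pure "compute" step (ReturnImprovedSharding) plus a thin mutating wrapper. A minimal standalone sketch of that shape follows; Sharding, Instr, ReturnImproved, and MaybeImprove are toy stand-ins for the XLA types and functions, not the actual implementation:

#include <optional>

struct Sharding { int tiles = 1; };                // Toy stand-in for HloSharding.
struct Instr { std::optional<Sharding> sharding; };  // Toy HloInstruction.

// Pure step: returns the improved sharding, or nullopt if `candidate` is
// no better than what the instruction already has.
std::optional<Sharding> ReturnImproved(const Sharding& candidate,
                                       const Instr& instr) {
  if (!instr.sharding || candidate.tiles > instr.sharding->tiles) {
    return candidate;
  }
  return std::nullopt;
}

// Mutating wrapper, equivalent to the old "maybe improve" behavior.
bool MaybeImprove(const Sharding& candidate, Instr* instr) {
  if (auto improved = ReturnImproved(candidate, *instr)) {
    instr->sharding = *improved;
    return true;
  }
  return false;
}

Returning std::optional instead of mutating in place is what lets the reworked InferDotShardingFromOperands() compute both operands' candidate shardings side by side and only commit to one after consulting the lookahead.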
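
The new LookaheadUserSharding() walks a single-user chain forward until it finds a user with a known sharding, then maps that sharding backward toward the original instruction, temporarily annotating the intermediate instructions because GetShardingFromUser() reads shardings from the instructions themselves. A compilable toy model of the traversal, where Node and InferFromUser are hypothetical stand-ins for HloInstruction and GetShardingFromUser():

#include <cstddef>
#include <optional>
#include <string>
#include <vector>

struct Node {
  std::optional<std::string> sharding;  // Toy stand-in for HloSharding.
  std::vector<Node*> users;
};

// Toy stand-in: the real pass maps the user's sharding onto the current
// instruction's shape; here we just pass it through.
std::optional<std::string> InferFromUser(const Node& user) {
  return user.sharding;
}

std::optional<std::string> LookaheadUserSharding(Node* instr) {
  if (instr->users.size() != 1) return std::nullopt;
  Node* user = instr->users[0];
  std::vector<Node*> chain = {instr, user};
  // Walk forward along single-user links until a sharded node is found.
  while (!user->sharding.has_value()) {
    if (user->users.size() != 1) return std::nullopt;
    user = user->users[0];
    chain.push_back(user);
  }
  // Walk backward, mapping the user's sharding onto each intermediate node.
  std::optional<std::string> sharding;
  for (int i = static_cast<int>(chain.size()) - 1; i >= 1; --i) {
    sharding = InferFromUser(*chain[i]);
    // Temporarily record the result so the next step can read it.
    if (sharding.has_value() && i != 1) {
      chain[i - 1]->sharding = sharding;
      continue;
    }
    break;
  }
  // Clear the temporary shardings set on the middle nodes.
  for (std::size_t i = 1; i + 1 < chain.size(); ++i) chain[i]->sharding.reset();
  return sharding;
}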
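
The priority rule in the reworked InferDotShardingFromOperands() reduces to: if exactly one operand's candidate sharding agrees with the lookahead sharding, that candidate wins; otherwise the candidate coming from the larger operand wins, and the losing candidate is still merged in afterwards via MaybeImproveInstructionSharding(). A sketch of just that decision, with Candidate and AgreesWithUser as toy stand-ins (the real agreement test is hlo_sharding_util::IsSubTilingOrEqualSharding on HloShardings, and the final merge step is omitted here):

#include <array>
#include <cstdint>
#include <optional>
#include <utility>

struct Candidate {
  int sharding_id;        // Toy stand-in for an HloSharding.
  int64_t operand_bytes;  // ByteSizeOf(operand->shape()).
};

// Toy predicate standing in for the subtiling/equality check against the
// sharding found by the user lookahead.
bool AgreesWithUser(int sharding_id, int user_sharding_id) {
  return sharding_id == user_sharding_id;
}

// Returns the two candidates ordered by priority: lookahead agreement first,
// operand size as the tiebreak.
std::array<Candidate, 2> Prioritize(const Candidate& op0, const Candidate& op1,
                                    std::optional<int> lookahead) {
  std::array<Candidate, 2> priority = {op0, op1};
  if (lookahead.has_value()) {
    const bool agree0 = AgreesWithUser(op0.sharding_id, *lookahead);
    const bool agree1 = AgreesWithUser(op1.sharding_id, *lookahead);
    // Exactly one candidate agrees with the downstream user: it wins.
    if (agree0 != agree1) {
      if (agree1) std::swap(priority[0], priority[1]);
      return priority;
    }
  }
  // No unambiguous lookahead guidance: fall back to the larger operand,
  // which is what the pass did unconditionally before this patch.
  if (op0.operand_bytes < op1.operand_bytes) {
    std::swap(priority[0], priority[1]);
  }
  return priority;
}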