提交 184223ec 编写于 作者: D Dimitris Vardoulakis 提交者: TensorFlower Gardener

[TF:XLA] Handle more patterns in ArCrsCombiner, and handle sequences of patterns.

Now, we optimize any sequence of the form:
AR [Bitcast|Transpose|Reshape|Convert|Multiply|Add|Subtract]* CRS

PiperOrigin-RevId: 225090998
上级 9748092a
......@@ -36,24 +36,40 @@ namespace {
namespace m = match;
// If the argument instruction is a CRS in the sequence
// AR -> Convert -> Add -> CRS
// then return the AR in the sequence.
// TODO(b/117554291): Rewrite this to recognize more general patterns,
// not just the specific one of AR -> Add -> Convert -> CRS.
absl::optional<HloInstruction*> MatchesArCrsPattern(
HloInstruction* instruction) {
HloInstruction *ar, *convert, *add, *crs;
if (Match(instruction,
m::CrossReplicaSum(
&crs, m::Add(&add, m::Op(),
m::Convert(&convert,
m::CrossReplicaSum(&ar, m::Op()))))) &&
ar->users().size() == 1 && ar->shape().element_type() == BF16 &&
convert->shape().element_type() == F32 && !crs->all_reduce_id()) {
return ar;
// Returns true iff the argument instruction is an AllReduce, followed by a
// certain sequence of instructions and then a CRS. It must be possible to move
// the AR past each instruction in the sequence.
bool MatchesArCrsPattern(HloInstruction* instruction) {
auto can_ar_move_past_instruction = [](HloInstruction* instruction) -> bool {
if (instruction->user_count() != 1) {
return false;
}
auto opcode = instruction->opcode();
return opcode == HloOpcode::kBitcast || opcode == HloOpcode::kTranspose ||
opcode == HloOpcode::kReshape || opcode == HloOpcode::kConvert ||
opcode == HloOpcode::kAdd || opcode == HloOpcode::kSubtract ||
opcode == HloOpcode::kMultiply;
};
auto computation_is_addition = [](HloComputation* c) {
return c->instruction_count() == 3 &&
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
};
if (!instruction->IsCrossModuleAllReduce() ||
!computation_is_addition(instruction->called_computations()[0]) ||
instruction->user_count() != 1) {
return false;
}
return absl::optional<HloInstruction*>();
auto next = instruction->users()[0];
while (!next->IsCrossReplicaAllReduce()) {
if (can_ar_move_past_instruction(next)) {
next = next->users()[0];
} else {
return false;
}
}
return computation_is_addition(next->called_computations()[0]);
}
} // namespace
......@@ -195,9 +211,8 @@ bool ArCrsCombiner::InstructionsComputeSameValue(
void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) {
auto ar = MatchesArCrsPattern(instruction);
if (ar) {
all_reduce_map_[*((*ar)->all_reduce_id())].push_back(*ar);
if (MatchesArCrsPattern(instruction)) {
all_reduce_map_[*(instruction->all_reduce_id())].push_back(instruction);
}
}
}
......@@ -205,21 +220,23 @@ void ArCrsCombiner::GroupAllReducesById(HloModule* module) {
void ArCrsCombiner::KeepProvablyEqualInstructionGroups() {
for (auto it : all_reduce_map_) {
auto all_reduce_id = it.first;
auto instruction_vec = it.second;
CHECK_EQ(instruction_vec.size(), num_spatial_partitions_);
auto instr_0 = instruction_vec[0];
auto add_0 = instr_0->users()[0]->users()[0];
CHECK_EQ(HloOpcode::kAdd, add_0->opcode());
for (int i = 1; i < instruction_vec.size(); ++i) {
auto instr_i = instruction_vec[i];
auto add_i = instr_i->users()[0]->users()[0];
CHECK_EQ(HloOpcode::kAdd, add_i->opcode());
auto next_0 = instr_0->users()[0];
auto next_i = instr_i->users()[0];
absl::flat_hash_map<int64, int64> visited_pairs;
if (!InstructionsComputeSameValue(add_0, add_i, &visited_pairs)) {
all_reduce_map_.erase(it.first);
}
do {
if (!InstructionsComputeSameValue(next_0, next_i, &visited_pairs)) {
all_reduce_map_.erase(all_reduce_id);
break;
}
next_0 = next_0->users()[0];
next_i = next_i->users()[0];
} while (!next_0->IsCrossReplicaAllReduce());
}
}
}
......@@ -228,47 +245,51 @@ StatusOr<bool> ArCrsCombiner::RewriteGraph() {
if (all_reduce_map_.empty()) {
return false;
}
auto computation_is_addition = [](HloComputation* c) {
return c->instruction_count() == 3 &&
Match(c->root_instruction(), m::Add(m::Parameter(), m::Parameter()));
};
for (auto it : all_reduce_map_) {
auto instruction_vec = it.second;
for (auto all_reduce : instruction_vec) {
auto parent_computation = all_reduce->parent();
auto convert = all_reduce->users()[0];
auto add = convert->users()[0];
auto crs = add->users()[0];
if (!computation_is_addition(all_reduce->called_computations()[0]) ||
!computation_is_addition(crs->called_computations()[0])) {
continue;
auto all_reduce_id = all_reduce->all_reduce_id();
auto prev = all_reduce->mutable_operand(0);
auto next = all_reduce->users()[0];
TF_CHECK_OK(all_reduce->ReplaceUseWith(next, prev));
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
while (!next->IsCrossReplicaAllReduce()) {
switch (next->opcode()) {
case HloOpcode::kBitcast:
case HloOpcode::kTranspose:
case HloOpcode::kReshape:
case HloOpcode::kConvert:
case HloOpcode::kMultiply:
break;
case HloOpcode::kAdd:
case HloOpcode::kSubtract: {
auto other_operand = (next->operands()[0] == prev)
? next->operands()[1]
: next->operands()[0];
// To move the AR past the addition/subtraction, we need to divide
// other_operand by the number of spatial partitions.
auto shape = other_operand->shape();
Literal lit(shape);
lit.PopulateWithValue<float>(num_spatial_partitions_);
auto divisor = parent_computation->AddInstruction(
HloInstruction::CreateConstant(lit.Clone()));
auto division =
parent_computation->AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kDivide, other_operand, divisor));
TF_CHECK_OK(other_operand->ReplaceUseWith(next, division));
break;
}
default:
LOG(FATAL) << "Unexpected instruction: " << next->ToShortString();
}
prev = next;
next = next->users()[0];
}
HloInstruction* other_summand = (add->operands()[0] == convert)
? add->operands()[1]
: add->operands()[0];
// To move the AR past the addition, we need to divide other_summand by
// the number of spatial partitions.
CHECK_EQ(all_reduce->user_count(), 1);
TF_CHECK_OK(
all_reduce->ReplaceAllUsesWith(all_reduce->mutable_operand(0)));
auto shape = other_summand->shape();
Literal lit(shape);
lit.PopulateWithValue<float>(num_spatial_partitions_);
auto divisor = parent_computation->AddInstruction(
HloInstruction::CreateConstant(lit.Clone()));
auto division =
parent_computation->AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kDivide, other_summand, divisor));
TF_CHECK_OK(other_summand->ReplaceUseWith(add, division));
// The AllReduce and the CRS are combined to an all-core AllReduce.
crs->set_all_reduce_id(all_reduce->all_reduce_id());
TF_CHECK_OK(parent_computation->RemoveInstruction(all_reduce));
next->set_all_reduce_id(all_reduce_id);
}
}
return true;
}
......
......@@ -25,9 +25,12 @@ limitations under the License.
namespace xla {
// Combine an AllReduce and a CrossReplicaSum when they are close to each other
// in the graph, to use an efficient CrossReplicaSum implementation that
// fully utilizes the interconnect bandwidth.
// When the HLO graph contains an AllReduce, followed by some simple linear
// operations, followed by a CrossReplicaSum, we can combine the AR and the CRS,
// to use an efficient CrossReplicaSum implementation that fully utilizes the
// interconnect bandwidth.
// Such sequences appear in spatially partitioned models.
// This pass must run right after spatial partitioning.
class ArCrsCombiner : public HloModulePass {
public:
ArCrsCombiner(int num_spatial_partitions)
......
......@@ -326,11 +326,27 @@ ENTRY %WhileLoop () -> (f32[2,2], f32[2,2]) {
EXPECT_FALSE(ArCrsCombiner::TestInstructionsComputeSameValue(i1, i2));
}
TEST_F(ArCrsCombinerTest, RewritePatternArConvertAddCrs) {
void CompareReplicaGroups(const std::vector<ReplicaGroup>& groups_before,
const std::vector<ReplicaGroup>& groups_after) {
ASSERT_EQ(groups_before.size(), groups_after.size());
for (int i = 0; i < groups_before.size(); ++i) {
// Somewhat verbose way to compare the replica_ids, because EqualsProto
// is not available in the open-source build.
auto group_before = groups_before[i];
std::vector<int64> ids_before(group_before.replica_ids().begin(),
group_before.replica_ids().end());
auto group_after = groups_after[i];
std::vector<int64> ids_after(group_after.replica_ids().begin(),
group_after.replica_ids().end());
EXPECT_EQ(ids_before, ids_after);
}
}
TEST_F(ArCrsCombinerTest, RewriteArConvertCrs) {
const char* module_str = R"(
HloModule foobar
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
......@@ -342,48 +358,257 @@ HloModule foobar
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
%constant.f32 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
ENTRY %entrycomp (p: bf16[]) -> (f32[], f32[]) {
%p = bf16[] parameter(0)
%cross-replica-sum.ar.1 = bf16[]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[]
convert(%cross-replica-sum.ar.1),
sharding={maximal device=0}
%cross-replica-sum.1 = f32[]
cross-replica-sum(%convert.1),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=0}
%cross-replica-sum.ar.2 = bf16[]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[]
convert(%cross-replica-sum.ar.2),
sharding={maximal device=1}
%cross-replica-sum.2 = f32[]
cross-replica-sum(%convert.2),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=1}
ROOT %tuple = (f32[], f32[])
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
sharding={{maximal device=0}, {maximal device=1}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(2);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::CrossReplicaSum(op::Convert(op::Parameter())),
op::CrossReplicaSum(op::Convert(op::Parameter()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArBitcastCrs) {
const char* module_str = R"(
HloModule foobar
%sum.1 (a: f32[2,1], b: f32[2,1]) -> f32[2,1] {
%a = f32[2,1] parameter(0)
%b = f32[2,1] parameter(1)
ROOT %add = f32[2,1] add(%a, %b)
}
%sum.2 (x: f32[2], y: f32[2]) -> f32[2] {
%x = f32[2] parameter(0)
%y = f32[2] parameter(1)
ROOT %add = f32[2] add(%x, %y)
}
ENTRY %entrycomp (p: f32[2,1]) -> (f32[2], f32[2]) {
%p = f32[2,1] parameter(0)
%cross-replica-sum.ar.1 = f32[2,1]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.1,
sharding={maximal device=0}
%bitcast.1 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.1)
%cross-replica-sum.1 = f32[2]
cross-replica-sum(%bitcast.1),
replica_groups={{0,1}},
to_apply=%sum.2,
sharding={maximal device=0}
%cross-replica-sum.ar.2 = f32[2,1]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.1,
sharding={maximal device=1}
%bitcast.2 = f32[2]{0} bitcast(f32[2,1]{1,0} %cross-replica-sum.ar.2)
%cross-replica-sum.2 = f32[2]
cross-replica-sum(%bitcast.2),
replica_groups={{0,1}},
to_apply=%sum.2,
sharding={maximal device=1}
ROOT %tuple = (f32[], f32[])
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
sharding={{maximal device=0}, {maximal device=1}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(2);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::CrossReplicaSum(op::Bitcast(op::Parameter())),
op::CrossReplicaSum(op::Bitcast(op::Parameter()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
%cross-replica-sum.ar.1 = bf16[2,2]
TEST_F(ArCrsCombinerTest, RewriteArMultiplyCrs) {
const char* module_str = R"(
HloModule foobar
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%p = f32[] parameter(0)
%constant.f32 = f32[] constant(123)
%cross-replica-sum.ar.1 = f32[]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.f32,
sharding={maximal device=0}
%multiply.1 = f32[]
multiply(%cross-replica-sum.ar.1, %constant.f32),
sharding={maximal device=0}
%cross-replica-sum.1 = f32[]
cross-replica-sum(%multiply.1),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=0}
%cross-replica-sum.ar.2 = f32[]
cross-replica-sum(%p),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%sum.f32,
sharding={maximal device=1}
%multiply.2 = f32[]
multiply(%cross-replica-sum.ar.2, %constant.f32),
sharding={maximal device=1}
%cross-replica-sum.2 = f32[]
cross-replica-sum(%multiply.2),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=1}
ROOT %tuple = (f32[], f32[])
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
sharding={{maximal device=0}, {maximal device=1}}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
auto crs_before =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_before = crs_before->replica_groups();
ArCrsCombiner combiner(2);
auto changed = combiner.Run(module.get()).ValueOrDie();
EXPECT_TRUE(changed);
EXPECT_THAT(
module->entry_computation()->root_instruction(),
op::Tuple(
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant())),
op::CrossReplicaSum(op::Multiply(op::Parameter(), op::Constant()))));
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, RewriteArConvertAddCrs) {
const char* module_str = R"(
HloModule foobar
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
}
%sum.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%p = f32[] parameter(0)
%constant.bf16 = bf16[] constant(1)
%constant.f32 = f32[] constant(2)
%cross-replica-sum.ar.1 = bf16[]
cross-replica-sum(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%binary_add,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[2,2]
%convert.1 = f32[]
convert(%cross-replica-sum.ar.1),
sharding={maximal device=0}
%add.1 = f32[2,2]
%add.1 = f32[]
add(%constant.f32, %convert.1),
sharding={maximal device=0}
%cross-replica-sum.1 = f32[2,2]
%cross-replica-sum.1 = f32[]
cross-replica-sum(%add.1),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=0}
%cross-replica-sum.ar.2 = bf16[2,2]
%cross-replica-sum.ar.2 = bf16[]
cross-replica-sum(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%binary_add,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[2,2]
%convert.2 = f32[]
convert(%cross-replica-sum.ar.2),
sharding={maximal device=1}
%add.2 = f32[2,2]
%add.2 = f32[]
add(%constant.f32, %convert.2),
sharding={maximal device=1}
%cross-replica-sum.2 = f32[2,2]
%cross-replica-sum.2 = f32[]
cross-replica-sum(%add.2),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=1}
ROOT %tuple = (f32[2,2], f32[2,2])
ROOT %tuple = (f32[], f32[])
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
sharding={{maximal device=0}, {maximal device=1}}
}
......@@ -407,25 +632,14 @@ ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
auto crs_after =
module->entry_computation()->root_instruction()->operands()[0];
auto replica_groups_after = crs_after->replica_groups();
ASSERT_EQ(replica_groups_before.size(), replica_groups_after.size());
for (int i = 0; i < replica_groups_before.size(); ++i) {
// Somewhat verbose way to compare the replica_ids, because EqualsProto
// is not available in the open-source build.
auto group_before = replica_groups_before[i];
std::vector<int64> ids_before(group_before.replica_ids().begin(),
group_before.replica_ids().end());
auto group_after = replica_groups_after[i];
std::vector<int64> ids_after(group_after.replica_ids().begin(),
group_after.replica_ids().end());
EXPECT_EQ(ids_before, ids_after);
}
CompareReplicaGroups(replica_groups_before, replica_groups_after);
}
TEST_F(ArCrsCombinerTest, OtherSummandNotTheSameDontRewrite) {
const char* module_str = R"(
HloModule foobar
%binary_add (a: bf16[], b: bf16[]) -> bf16[] {
%sum.bf16 (a: bf16[], b: bf16[]) -> bf16[] {
%a = bf16[] parameter(0)
%b = bf16[] parameter(1)
ROOT %add = bf16[] add(%a, %b)
......@@ -437,49 +651,49 @@ HloModule foobar
ROOT %add = f32[] add(%x, %y)
}
ENTRY %entrycomp (p: f32[2,2]) -> (f32[2,2], f32[2,2]) {
%p = f32[2,2] parameter(0)
%constant.bf16 = bf16[2,2] constant(bf16[2,2] {{1, 2}, {3, 4}})
%constant.f32.1 = f32[2,2] constant(f32[2,2] {{1, 2}, {3, 4}})
%constant.f32.2 = f32[2,2] constant(f32[2,2] {{3, 4}, {5, 6}})
ENTRY %entrycomp (p: f32[]) -> (f32[], f32[]) {
%p = f32[] parameter(0)
%constant.bf16 = bf16[] constant(1)
%constant.f32.1 = f32[] constant(2)
%constant.f32.2 = f32[] constant(3)
%cross-replica-sum.ar.1 = bf16[2,2]
%cross-replica-sum.ar.1 = bf16[]
cross-replica-sum(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%binary_add,
to_apply=%sum.bf16,
sharding={maximal device=0}
%convert.1 = f32[2,2]
%convert.1 = f32[]
convert(%cross-replica-sum.ar.1),
sharding={maximal device=0}
%add.1 = f32[2,2]
%add.1 = f32[]
add(%constant.f32.1, %convert.1),
sharding={maximal device=0}
%cross-replica-sum.1 = f32[2,2]
%cross-replica-sum.1 = f32[]
cross-replica-sum(%add.1),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=0}
%cross-replica-sum.ar.2 = bf16[2,2]
%cross-replica-sum.ar.2 = bf16[]
cross-replica-sum(%constant.bf16),
replica_groups={{0},{1}},
all_reduce_id=1,
to_apply=%binary_add,
to_apply=%sum.bf16,
sharding={maximal device=1}
%convert.2 = f32[2,2]
%convert.2 = f32[]
convert(%cross-replica-sum.ar.2),
sharding={maximal device=1}
%add.2 = f32[2,2]
%add.2 = f32[]
add(%constant.f32.2, %convert.2),
sharding={maximal device=1}
%cross-replica-sum.2 = f32[2,2]
%cross-replica-sum.2 = f32[]
cross-replica-sum(%add.2),
replica_groups={{0,1}},
to_apply=%sum.f32,
sharding={maximal device=1}
ROOT %tuple = (f32[2,2], f32[2,2])
ROOT %tuple = (f32[], f32[])
tuple(%cross-replica-sum.1, %cross-replica-sum.2),
sharding={{maximal device=0}, {maximal device=1}}
}
......
......@@ -2060,6 +2060,10 @@ bool HloInstruction::IsCrossModuleAllReduce() const {
return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id();
}
bool HloInstruction::IsCrossReplicaAllReduce() const {
return opcode() == HloOpcode::kCrossReplicaSum && !all_reduce_id();
}
string HloInstruction::ToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
......
......@@ -1174,9 +1174,12 @@ class HloInstruction {
// Returns true if this instruction is elementwise on all its operands.
bool IsElementwise() const;
// Returns true if this is an cross module all-reduce instrucion.
// Returns true if this is a cross module all-reduce instruction.
bool IsCrossModuleAllReduce() const;
// Returns true if this is a cross-replica all-reduce instruction.
bool IsCrossReplicaAllReduce() const;
// Returns true if this elementwise instruction implicitly broadcasts operand
// `operand_idx`.
//
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册