提交 f28e7353 编写于 作者: M Marcello Maggioni 提交者: TensorFlower Gardener

[XLA] Add support to CollectivePipeliner to sink collectives.

Small collectives might be better off when sunk, and there are other potential use cases.
Also fix a bug where we were accepting reuse of the data that we were storing, and change the tests using that pattern to match the fix.

PiperOrigin-RevId: 565080772
上级 65bd6912
......@@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/comparison_util.h"
......@@ -57,6 +58,8 @@ namespace xla {
const char* const CollectivePipeliner::kInsertedByPreviousStep =
"InsertedByPreviousStep";
const char* const CollectivePipeliner::kSunkByPreviousStep =
"SunkByPreviousStep";
namespace {
......@@ -128,7 +131,7 @@ bool CheckParameterUsageIsCompatible(const HloInstruction* gte,
int64_t sliced_index) {
for (auto* user : gte->users()) {
// Expected all users are dynamic-slices
if (user->opcode() != HloOpcode::kDynamicSlice && dus != user) {
if (dus != user) {
VLOG(5) << "CheckParameterUsageIsCompatible(): User not a dynamic slice "
"or the dynamic-update-slice for the output."
<< user->ToString();
......@@ -771,7 +774,8 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
if (!should_process(instr)) {
continue;
}
if (direction == CollectivePipeliner::PipeliningDirection::kForward) {
if (direction == CollectivePipeliner::PipeliningDirection::kForward ||
direction == CollectivePipeliner::PipeliningDirection::kForwardSink) {
auto [dyn_update, formatting_ops] = CheckStoreIntoSliceIsCompatible(
instr, while_body, level_to_operate_on, pipeline_use_tree_);
if (dyn_update == nullptr) {
......@@ -781,6 +785,21 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
"computation";
continue;
}
std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
if (!sliced_dim.has_value()) {
VLOG(5) << "Skipping " << instr->name()
<< " because couldn't find sliced dimension";
continue;
}
if (direction == CollectivePipeliner::PipeliningDirection::kForwardSink &&
(*sliced_dim != 0 || dyn_update->shape().dimensions(0) !=
loop_iteration_count_->GetUnsignedValue())) {
VLOG(5) << "Skipping " << instr->name()
<< " because number of iteration of the loop doesn't match "
"slices being inserted or slice dim is not 0. slice_dim = "
<< *sliced_dim << " loop count = "
<< loop_iteration_count_->GetUnsignedValue();
}
if (!process_different_sized_options_) {
if (!formatting_ops.empty()) {
if (instr->operand(0)->shape() != formatting_ops.back()->shape()) {
......@@ -804,12 +823,6 @@ void WhileLoopAnalysis::CollectCollectivesToMove(
continue;
}
}
std::optional<int64_t> sliced_dim = GetSlicedDimension(dyn_update);
if (!sliced_dim.has_value()) {
VLOG(5) << "Skipping " << instr->name()
<< " because couldn't find sliced dimension";
continue;
}
const HloInstruction* to_insert_into = dyn_update->operand(0);
if (level_to_operate_on == 0 &&
(to_insert_into->opcode() != HloOpcode::kGetTupleElement ||
......@@ -917,6 +930,88 @@ const std::vector<WhileMoveInfo>& WhileLoopAnalysis::GetMoveInfos() const {
return move_infos_;
}
// Simple loop invariant check. If the data doesn't depend in any way from the
// input tuple consider it loop invariant.
// TODO: Extract something more complete in a separate file. This is current
// quite custom to the transformation here.
bool IsLoopInvariant(
const HloInstruction* instr,
absl::flat_hash_map<const HloInstruction*, bool>& invariant_cache) {
auto it = invariant_cache.find(instr);
if (it != invariant_cache.end()) {
return it->second;
}
// This performs a post order iteration of the graph. First element is the
// current HLO in the stack and the second parameter is the number of operands
// to still visit before visiting the HLO itself.
std::vector<std::pair<const HloInstruction*, int>> stack(
1, std::make_pair(instr, 0));
absl::flat_hash_set<const HloInstruction*> visited;
while (!stack.empty()) {
auto& current = stack.back();
if (std::get<0>(current)->HasSideEffect() ||
std::get<0>(current)->opcode() == HloOpcode::kParameter) {
invariant_cache[std::get<0>(current)] = false;
}
if (std::get<0>(current)->operands().empty()) {
invariant_cache[std::get<0>(current)] = true;
}
if (std::get<1>(current) > 0) {
auto* current_operand =
std::get<0>(current)->operand(std::get<1>(current) - 1);
auto cop_it = invariant_cache.find(current_operand);
CHECK(cop_it != invariant_cache.end())
<< "Entry expected to be populated";
if (!cop_it->second) {
invariant_cache[std::get<0>(current)] = false;
stack.pop_back();
continue;
}
}
if (std::get<0>(current)->operand_count() == std::get<1>(current)) {
stack.pop_back();
continue;
}
auto* next_operand = std::get<0>(current)->operand(std::get<1>(current)++);
auto op_it = invariant_cache.find(next_operand);
if (op_it == invariant_cache.end()) {
stack.push_back(std::make_pair(next_operand, 0));
} else if (!op_it->second) {
invariant_cache[next_operand] &= op_it->second;
}
}
it = invariant_cache.find(instr);
CHECK(it != invariant_cache.end())
<< "We should have computed \"instr\" value";
return it->second;
}
// Compute a shape that can hold a concatenation of tensors of shape base_shape.
Shape ComputeFullOutputShape(const WhileMoveInfo& move_info,
const Shape& base_shape) {
return ShapeUtil::PrependMajorDimension(
move_info.dynamic_update_slice->operand(0)
->shape()
.dimensions()[move_info.sliced_idx],
base_shape);
}
// Create zero of base type ptype and broadcast it to shape.
HloInstruction* CreateZero(HloComputation* comp, const Shape& shape,
PrimitiveType ptype) {
if (shape.dimensions_size() == 0) {
return comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(ptype)));
}
HloInstruction* zero_constant =
comp->AddInstruction(HloInstruction::CreateBroadcast(
shape,
comp->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(ptype))),
{}));
return zero_constant;
}
} // namespace
// Function that does the work of pushing forward instructions that have been
......@@ -957,7 +1052,6 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis,
absl::flat_hash_set<HloInstruction*> to_skip_set;
absl::flat_hash_map<HloInstruction*, HloInstruction*> formatting_map;
absl::flat_hash_map<HloInstruction*, int64_t> is_output_instruction;
absl::flat_hash_map<HloInstruction*, int64_t> pipelined_instruction_index;
std::vector<int64_t> moves_requiring_special_output;
int64_t count = 0;
// Add all-reduces to duplicate into a set.
......@@ -1318,6 +1412,421 @@ Status TransformLoopForward(const WhileLoopAnalysis& loop_analysis,
return OkStatus();
}
// Function that does the work of sinking all-reduces the output of which are
// concatenated after the loop. Rough transformation: while (i < LAYERS) {
// p0 = param(0)
// p1 = param(1)
// x = computation(p0)
// xg = all-reduce(x)
// y = computation(p1)
// yg = all-reduce(y)
// }
//
// to
//
// x_prev = computation(p0)
// y_prev = computation(p1)
// i = i + 1
// while (i < LAYERS, x_all, y_all) {
// p0 = param(0)
// p1 = param(1)
// x = computation(p0)
// y = computation(p1)
// x_all = append(x)
// y_all = append(y)
// }
// xg_all = all-reduce(x_all)
// yg_all = all-reduce(y_all)
Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis,
bool insert_non_alias_custom_call,
int64_t level_to_operate_on,
bool pipeline_use_tree,
bool process_different_sized_ops,
HloPredicate should_process,
int64_t& next_channel_id) {
// Defining some maps/sets to keep track of instructions duplicated.
absl::flat_hash_map<HloInstruction*, int64_t> is_output_instruction;
absl::flat_hash_map<const HloInstruction*, bool> invariant_cache;
// Map get-tuple-elements() inside of the loop with elements passed to the
// tuple that is the "init" of the loop.
HloInstruction* while_loop = loop_analysis.while_loop_instruction();
HloComputation* while_body = while_loop->while_body();
CHECK_EQ(while_body->parameter_instructions().size(), 1)
<< "Expected only one parameter";
HloInstruction* loop_parameter = while_body->parameter_instructions()[0];
HloInstruction* loop_init = while_loop->mutable_operand(0);
CHECK_EQ(while_body->root_instruction()->opcode(), HloOpcode::kTuple);
for (int i = 0; i < while_body->root_instruction()->operand_count(); ++i) {
is_output_instruction[while_body->root_instruction()->mutable_operand(i)] =
i;
}
// Collect the new parameter shapes with the additional state for the indices
// and construct new operand vectors for the init of the new loop and its root
// instruction.
HloComputation* loop_computation = while_loop->parent();
HloComputation* body_computation = while_loop->while_body();
std::vector<HloInstruction*> new_init_operands;
std::vector<Shape> new_parameter_shapes;
std::vector<HloInstruction*> new_root_operands;
absl::flat_hash_set<int64_t> indices_to_insert;
const int64_t operands_indices_count = loop_init->operand_count();
const int64_t new_loop_tuple_operand_count = operands_indices_count;
absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements;
new_parameter_shapes.resize(new_loop_tuple_operand_count);
new_root_operands.resize(new_loop_tuple_operand_count);
new_init_operands.resize(new_loop_tuple_operand_count);
absl::flat_hash_set<int64_t> original_to_move_indices;
// Initialize data structures with information about the outputs that need to
// be sunk.
for (auto& to_move : loop_analysis.GetMoveInfos()) {
HloInstruction* collective = to_move.collective_to_move;
Shape shape =
ComputeFullOutputShape(to_move, collective->operand(0)->shape());
new_init_operands[to_move.output_idx] =
CreateZero(loop_computation, shape, shape.element_type());
new_parameter_shapes[to_move.output_idx] = shape;
original_to_move_indices.insert(to_move.output_idx);
indices_to_insert.insert(to_move.output_idx);
new_root_operands[to_move.output_idx] = collective->mutable_operand(0);
}
// Initialize the data structures for output indices that aren't modified.
for (int i = 0; i < loop_parameter->shape().tuple_shapes().size(); ++i) {
if (original_to_move_indices.contains(i)) {
continue;
}
new_parameter_shapes[i] = loop_parameter->shape().tuple_shapes(i);
new_init_operands[i] = loop_init->mutable_operand(i);
new_root_operands[i] = while_body->root_instruction()->mutable_operand(i);
}
// Collect instructions that are necessary for the execution of the sunk
// instructions. If they are loop invariant they are stored as is, otherwise
// the version for each iteration is accumulated in a buffer.
for (auto& move_info : loop_analysis.GetMoveInfos()) {
auto pipelined_instrs = CollectDependenciesToPipeline(
move_info.collective_to_move, absl::MakeSpan(move_info.formatting_ops));
for (auto* pipelined : pipelined_instrs) {
const bool is_loop_invariant =
IsLoopInvariant(pipelined, invariant_cache);
is_output_instruction[pipelined] = new_init_operands.size();
if (is_loop_invariant) {
new_parameter_shapes.push_back(pipelined->shape());
new_init_operands.push_back(
CreateZero(loop_computation, pipelined->shape(),
pipelined->shape().element_type()));
new_root_operands.push_back(pipelined);
continue;
}
Shape expanded_shape =
ComputeFullOutputShape(move_info, pipelined->shape());
new_parameter_shapes.push_back(expanded_shape);
new_init_operands.push_back(CreateZero(loop_computation, expanded_shape,
expanded_shape.element_type()));
indices_to_insert.insert(new_root_operands.size());
HloInstruction* reshaped = body_computation->AddInstruction(
HloInstruction::CreateReshape(expanded_shape, pipelined));
new_root_operands.push_back(reshaped);
}
}
std::unique_ptr<HloInstruction> new_parameter =
HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape(new_parameter_shapes),
absl::StrCat("sink_", loop_parameter->name()));
// Insert inputs to the collective we are sinking in slices for the loop.
for (auto& to_move : loop_analysis.GetMoveInfos()) {
if (!indices_to_insert.contains(to_move.output_idx)) {
continue;
}
HloInstruction* to_insert =
body_computation->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::PrependMajorDimension(
1, new_root_operands[to_move.output_idx]->shape()),
new_root_operands[to_move.output_idx]));
Shape expanded_shape = ComputeFullOutputShape(
to_move, new_root_operands[to_move.output_idx]->shape());
HloInstruction* input =
body_computation->AddInstruction(HloInstruction::CreateCustomCall(
expanded_shape,
{body_computation->AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR0((int32_t)to_move.output_idx)))},
"PlaceHolder"));
std::vector<HloInstruction*> indices(
expanded_shape.dimensions_size(),
CreateZero(
body_computation, to_move.dynamic_update_slice->index_shapes()[0],
to_move.dynamic_update_slice->index_shapes()[0].element_type()));
indices[0] = to_move.dynamic_update_slice->index_operands()[0];
to_insert = body_computation->AddInstruction(
HloInstruction::CreateDynamicUpdateSlice(expanded_shape, input,
to_insert, indices));
new_root_operands[to_move.output_idx] = to_insert;
}
std::unique_ptr<HloInstruction> new_root_instr =
HloInstruction::CreateTuple(new_root_operands);
// Mark for removal (by setting replacement entry to nullptr) the users of the
// old parameters we are replacing for the loops. All the computation tree
// for those should be not used in the new loop.
for (auto* p_user : body_computation->parameter_instructions()[0]->users()) {
CHECK_EQ(p_user->opcode(), HloOpcode::kGetTupleElement);
const int64_t tuple_idx = p_user->tuple_index();
if (!indices_to_insert.contains(tuple_idx)) {
continue;
}
replacements[p_user] =
HloInstruction::CreateGetTupleElement(new_parameter.get(), tuple_idx);
std::vector<HloInstruction*> stack(p_user->users().begin(),
p_user->users().end());
while (!stack.empty()) {
auto* u = stack.back();
stack.pop_back();
replacements[u] = nullptr;
for (auto* user : u->users()) {
if (user == body_computation->root_instruction()) {
continue;
}
stack.push_back(user);
}
}
}
replacements[body_computation->parameter_instruction(0)] =
std::move(new_parameter);
replacements[body_computation->root_instruction()] =
std::move(new_root_instr);
replacements[while_loop->while_condition()->parameter_instruction(0)] =
HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape(new_parameter_shapes),
absl::StrCat(
"sink_",
while_loop->while_condition()->parameter_instruction(0)->name()));
// Clone and create new loop.
HloInstruction* new_init = loop_computation->AddInstruction(
HloInstruction::CreateTuple(new_init_operands));
HloComputation* cloned_body =
body_computation->parent()->AddEmbeddedComputation(
body_computation->CloneWithReplacements(&replacements));
HloComputation* cloned_cond =
body_computation->parent()->AddEmbeddedComputation(
while_loop->while_condition()->CloneWithReplacements(&replacements));
for (int64_t i = 0; i < cloned_body->root_instruction()->operand_count();
++i) {
HloInstruction* output =
cloned_body->root_instruction()->mutable_operand(i);
if (output->opcode() != HloOpcode::kDynamicUpdateSlice) {
continue;
}
if (!output->operand(0)->IsCustomCall("PlaceHolder")) {
continue;
}
auto idx = Cast<HloConstantInstruction>(output->operand(0)->operand(0))
->literal()
.GetFirstInteger();
auto* new_param =
cloned_body->AddInstruction(HloInstruction::CreateGetTupleElement(
output->shape(), cloned_body->parameter_instruction(0), *idx));
HloInstruction* old_operand_param = output->mutable_operand(0);
TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param));
TF_RETURN_IF_ERROR(
old_operand_param->parent()->RemoveInstruction(old_operand_param));
if (insert_non_alias_custom_call) {
auto* old_operand = output->mutable_operand(1);
auto* custom_call =
cloned_body->AddInstruction(HloInstruction::CreateCustomCall(
old_operand->shape(), {old_operand},
/*custom_call_target=*/CollectivePipeliner::kSunkByPreviousStep));
TF_RETURN_IF_ERROR(output->ReplaceOperandWith(1, custom_call));
}
}
HloInstruction* new_while =
loop_computation->AddInstruction(HloInstruction::CreateWhile(
new_init->shape(), cloned_cond, cloned_body, new_init));
std::vector<HloInstruction*> new_output_tuple;
new_output_tuple.resize(new_root_operands.size(), nullptr);
// Reproduce computation to the output after the loop on the full shape.
for (auto& to_move : loop_analysis.GetMoveInfos()) {
absl::flat_hash_map<HloInstruction*, HloInstruction*> pipelined_map;
HloInstruction* to_sink = loop_computation->AddInstruction(
HloInstruction::CreateGetTupleElement(new_while, to_move.output_idx));
const int64_t new_dim_limit =
to_move.dynamic_update_slice->shape().dimensions(0);
pipelined_map[to_move.collective_to_move->mutable_operand(0)] = to_sink;
auto pipelined_instrs = CollectDependenciesToPipeline(
to_move.collective_to_move, absl::MakeSpan(to_move.formatting_ops));
for (auto* original_pipelined : pipelined_instrs) {
const bool is_loop_invariant =
IsLoopInvariant(original_pipelined, invariant_cache);
CHECK(is_output_instruction.contains(original_pipelined));
int64_t pipelined_idx = is_output_instruction[original_pipelined];
HloInstruction* pipelined = loop_computation->AddInstruction(
HloInstruction::CreateGetTupleElement(new_while, pipelined_idx));
// Broadcast loop invariant instructions.
if (is_loop_invariant) {
Shape full_shape = ComputeFullOutputShape(to_move, pipelined->shape());
absl::InlinedVector<int64_t, 4> operand_dims;
operand_dims.resize(pipelined->shape().dimensions_size());
absl::c_iota(operand_dims, 1);
HloInstruction* broadcasted =
loop_computation->AddInstruction(HloInstruction::CreateBroadcast(
full_shape, pipelined, operand_dims));
pipelined_map[original_pipelined] = broadcasted;
} else {
pipelined_map[original_pipelined] = pipelined;
}
}
// Cloning the main instruction
HloInstruction* pipelined_instr_cloned = loop_computation->AddInstruction(
to_move.collective_to_move->CloneWithNewOperands(
ComputeFullOutputShape(to_move,
to_move.collective_to_move->shape()),
{to_sink}));
UpdateInstructionChannelId(pipelined_instr_cloned, next_channel_id);
pipelined_map[to_move.collective_to_move] = pipelined_instr_cloned;
auto collect_operands = [&pipelined_map](HloInstruction* instr) {
std::vector<HloInstruction*> operands;
for (auto* operand : instr->mutable_operands()) {
auto it = pipelined_map.find(operand);
CHECK(it != pipelined_map.end());
operands.push_back(it->second);
}
return operands;
};
// We are adding a batch dimension to the formatting ops, so we need to
// specially rewrite each instruction potentially if adding dimensions has
// an effect on the instruction itself (like say broadcast, slices ... etc).
for (HloInstruction* formatting_op : to_move.formatting_ops) {
if (formatting_op->IsElementwise() ||
formatting_op->opcode() == HloOpcode::kReshape ||
formatting_op->opcode() == HloOpcode::kConvert ||
formatting_op->opcode() == HloOpcode::kCollectivePermute) {
HloInstruction* cloned_elementwise = loop_computation->AddInstruction(
formatting_op->CloneWithNewOperands(
ComputeFullOutputShape(to_move, formatting_op->shape()),
collect_operands(formatting_op)));
pipelined_map[formatting_op] = cloned_elementwise;
continue;
}
if (formatting_op->opcode() == HloOpcode::kBroadcast) {
CHECK(formatting_op->dimensions().empty());
auto operands = collect_operands(formatting_op);
std::vector<int64_t> dimensions(1, 0);
// Constant scalars don't get expanded ahead of time and are kept
// scalar.
if (operands[0]->shape().dimensions_size() == 0) {
dimensions.clear();
}
HloInstruction* expanded_broadcast =
loop_computation->AddInstruction(HloInstruction::CreateBroadcast(
ComputeFullOutputShape(to_move, formatting_op->shape()),
operands[0], dimensions));
pipelined_map[formatting_op] = expanded_broadcast;
continue;
}
if (formatting_op->opcode() == HloOpcode::kSlice) {
std::vector<int64_t> slice_start = formatting_op->slice_starts();
std::vector<int64_t> slice_limits = formatting_op->slice_limits();
std::vector<int64_t> slice_strides = formatting_op->slice_strides();
slice_start.insert(slice_start.begin(), 0);
slice_limits.insert(slice_limits.begin(), new_dim_limit);
slice_strides.insert(slice_strides.begin(), 1);
HloInstruction* expanded_slice =
loop_computation->AddInstruction(HloInstruction::CreateSlice(
ComputeFullOutputShape(to_move, formatting_op->shape()),
collect_operands(formatting_op)[0], slice_start, slice_limits,
slice_strides));
pipelined_map[formatting_op] = expanded_slice;
continue;
}
if (formatting_op->opcode() == HloOpcode::kDynamicSlice) {
std::vector<int64_t> dynamic_slice_sizes =
formatting_op->dynamic_slice_sizes();
dynamic_slice_sizes.insert(dynamic_slice_sizes.begin(), new_dim_limit);
HloDynamicSliceInstruction* dynslice =
Cast<HloDynamicSliceInstruction>(formatting_op);
HloInstruction* zero = loop_computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(
formatting_op->operand(dynslice->first_index_operand_number())
->shape()
.element_type())));
std::vector<HloInstruction*> indices(1, zero);
indices.insert(indices.end(), dynslice->index_operands().begin(),
dynslice->index_operands().end());
HloInstruction* expanded_dynslice =
loop_computation->AddInstruction(HloInstruction::CreateDynamicSlice(
ComputeFullOutputShape(to_move, formatting_op->shape()),
collect_operands(formatting_op)[0], indices,
dynamic_slice_sizes));
pipelined_map[formatting_op] = expanded_dynslice;
continue;
}
if (formatting_op->opcode() == HloOpcode::kPad) {
HloPadInstruction* pad_instruction =
Cast<HloPadInstruction>(formatting_op);
PaddingConfig p_config = pad_instruction->padding_config();
PaddingConfig new_p_config;
new_p_config.add_dimensions();
for (auto& dim : p_config.dimensions()) {
auto* new_dim = new_p_config.add_dimensions();
*new_dim = dim;
}
auto new_operands = collect_operands(formatting_op);
HloInstruction* expanded_pad =
loop_computation->AddInstruction(HloInstruction::CreatePad(
ComputeFullOutputShape(to_move, formatting_op->shape()),
new_operands[0], new_operands[1], new_p_config));
pipelined_map[formatting_op] = expanded_pad;
continue;
}
if (formatting_op->opcode() == HloOpcode::kTranspose) {
HloTransposeInstruction* transpose_instruction =
Cast<HloTransposeInstruction>(formatting_op);
std::vector<int64_t> new_dims(
transpose_instruction->dimensions().begin(),
transpose_instruction->dimensions().end());
new_dims.insert(new_dims.begin(), 0);
for (int64_t& dim : new_dims) {
++dim;
}
HloInstruction* expanded_transpose =
loop_computation->AddInstruction(HloInstruction::CreateTranspose(
ComputeFullOutputShape(to_move, formatting_op->shape()),
collect_operands(formatting_op)[0], new_dims));
pipelined_map[formatting_op] = expanded_transpose;
continue;
}
CHECK(false) << "Unsupported instruction";
}
HloInstruction* inserted_operand =
to_move.dynamic_update_slice->mutable_operand(1);
CHECK(pipelined_map.contains(inserted_operand))
<< "Expected to be processed";
HloInstruction* expanded_inserted = pipelined_map[inserted_operand];
if (!ShapeUtil::Compatible(expanded_inserted->shape(),
to_move.dynamic_update_slice->shape())) {
expanded_inserted =
loop_computation->AddInstruction(HloInstruction::CreateReshape(
to_move.dynamic_update_slice->shape(), expanded_inserted));
}
new_output_tuple[to_move.output_idx] = expanded_inserted;
}
// Create new loop tuple replacement.
for (int i = 0; i < new_while->shape().tuple_shapes_size(); ++i) {
if (new_output_tuple[i] != nullptr) {
continue;
}
new_output_tuple[i] = loop_computation->AddInstruction(
HloInstruction::CreateGetTupleElement(new_while, i));
}
HloInstruction* new_tuple = loop_computation->AddInstruction(
HloInstruction::CreateTuple(new_output_tuple));
TF_RETURN_IF_ERROR(while_loop->ReplaceAllUsesWithDifferentShape(new_tuple));
TF_RETURN_IF_ERROR(
loop_computation->RemoveInstructionAndUnusedOperands(while_loop));
TF_RETURN_IF_ERROR(loop_computation->parent()->RemoveUnusedComputations());
return OkStatus();
}
// Function that does the work of pushing backward instructions that have been
// determined that can be pipelined. Rough transformation:
// while (i < LAYERS) {
......@@ -1601,8 +2110,13 @@ StatusOr<bool> CollectivePipeliner::Run(
}
}
}
int64_t transformed_loops = 0;
int64_t transformed_instructions = 0;
int64_t next_channel_id = hlo_query::NextChannelId(*module);
VLOG(1) << "Pipelining on direction: "
<< GetPipelineDirectionString(config_.pipelining_direction);
for (HloInstruction* instruction : while_loop_instructions) {
VLOG(1) << "While: " << instruction->ToString();
WhileLoopAnalysis loop_analysis(
instruction, config_.max_pipelining_per_loop, config_.pipeline_use_tree,
config_.process_different_sized_ops);
......@@ -1611,7 +2125,6 @@ StatusOr<bool> CollectivePipeliner::Run(
loop_analysis.GetLoopIterationCount()->GetUnsignedValue() == 0) {
continue;
}
VLOG(1) << "While: " << instruction->ToString();
VLOG(1) << "While iterations: "
<< loop_analysis.GetLoopIterationCount()->ToString();
loop_analysis.CollectCollectivesToMove(config_.level_to_operate_on,
......@@ -1620,6 +2133,7 @@ StatusOr<bool> CollectivePipeliner::Run(
if (loop_analysis.GetMoveInfos().empty()) {
continue;
}
transformed_instructions += loop_analysis.GetMoveInfos().size();
VLOG(1) << "Found Collectives to optimize";
if (VLOG_IS_ON(1)) {
for (auto& to_move : loop_analysis.GetMoveInfos()) {
......@@ -1635,6 +2149,12 @@ StatusOr<bool> CollectivePipeliner::Run(
loop_analysis, !config_.last_run, config_.level_to_operate_on,
config_.pipeline_use_tree, config_.process_different_sized_ops,
config_.should_process, next_channel_id));
} else if (config_.pipelining_direction ==
PipeliningDirection::kForwardSink) {
TF_RETURN_IF_ERROR(TransformLoopForwardSink(
loop_analysis, !config_.last_run, config_.level_to_operate_on,
config_.pipeline_use_tree, config_.process_different_sized_ops,
config_.should_process, next_channel_id));
} else {
CHECK_EQ(config_.pipelining_direction, PipeliningDirection::kBackward);
TF_RETURN_IF_ERROR(TransformLoopBackward(
......@@ -1642,6 +2162,7 @@ StatusOr<bool> CollectivePipeliner::Run(
config_.process_different_sized_ops, config_.should_process,
next_channel_id));
}
++transformed_loops;
changed = true;
}
// If this is the last expected run then remove all the custom-calls that we
......@@ -1665,6 +2186,10 @@ StatusOr<bool> CollectivePipeliner::Run(
instruction));
}
}
VLOG(1) << "Transformed loops: " << transformed_loops
<< " and transformed instructions: " << transformed_instructions
<< " for pipelining direction: "
<< GetPipelineDirectionString(config_.pipelining_direction);
return changed;
}
......
......@@ -59,6 +59,7 @@ class CollectivePipeliner : public HloModulePass {
enum PipeliningDirection {
kBackward,
kForward,
kForwardSink,
};
struct Config {
int64_t level_to_operate_on = 0;
......@@ -76,15 +77,31 @@ class CollectivePipeliner : public HloModulePass {
HloPredicate should_process;
};
static const char* const kInsertedByPreviousStep;
static const char* const kSunkByPreviousStep;
explicit CollectivePipeliner(const Config& config) : config_(config) {}
CollectivePipeliner(CollectivePipeliner&& other) = default;
CollectivePipeliner& operator=(CollectivePipeliner&& other) = default;
absl::string_view GetPipelineDirectionString(PipeliningDirection direction) {
switch (direction) {
case PipeliningDirection::kForward: {
return "forward";
}
case PipeliningDirection::kBackward: {
return "backward";
}
case PipeliningDirection::kForwardSink: {
return "forwardsink";
}
}
}
absl::string_view name() const override {
if (config_.pipelining_direction == kForward) {
return "collective-pipeliner-forward";
} else {
} else if (config_.pipelining_direction == kBackward) {
return "collective-pipeliner-backward";
} else {
return "collective-pipeliner-forwardsink";
}
}
......
......@@ -87,16 +87,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -108,18 +109,18 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -132,7 +133,7 @@ ENTRY entry {
EXPECT_EQ(sliced->opcode(), HloOpcode::kDynamicSlice);
const HloInstruction* index = sliced->operand(1);
EXPECT_EQ(index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(index->tuple_index(), 2);
EXPECT_EQ(index->tuple_index(), 3);
const HloInstruction* while_inst = index->operand(0);
EXPECT_EQ(while_inst->opcode(), HloOpcode::kWhile);
const HloInstruction* while_root =
......@@ -151,7 +152,7 @@ ENTRY entry {
EXPECT_EQ(get_tuple_value->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_value->tuple_index(), 1);
EXPECT_EQ(get_tuple_index->tuple_index(), 2);
EXPECT_EQ(get_tuple_index->tuple_index(), 3);
}
TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNotFirstIdx) {
......@@ -165,16 +166,17 @@ add {
}
while_cond {
param = (s32[], bf16[8,3,128]) parameter(0)
param = (s32[], bf16[8,3,128], bf16[8,3,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[8,3,128]) parameter(0)
param = (s32[], bf16[8,3,128], bf16[8,3,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[8,3,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[8,3,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -186,18 +188,18 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[8,1,128] dynamic-slice(get-tuple-element.395, constant.2561, select.1348, constant.2561), dynamic_slice_sizes={8,1,128}
dynamic-slice.99 = bf16[8,1,128] dynamic-slice(get-tuple-element.5, constant.2561, select.1348, constant.2561), dynamic_slice_sizes={8,1,128}
mul = bf16[8,1,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[8,1,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[8,3,128] dynamic-update-slice(get-tuple-element.395, ar.1, constant.2561, select.1348, constant.2561)
ROOT tuple = (s32[], bf16[8,3,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[8,3,128], bf16[8,3,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[8,3,128] parameter(0)
tuple = (s32[], bf16[8,3,128]) tuple(c0, p0)
while = (s32[], bf16[8,3,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[8,3,128], bf16[8,3,128]) tuple(c0, p0, p0)
while = (s32[], bf16[8,3,128], bf16[8,3,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[8,3,128] get-tuple-element(while), index=1
}
)";
......@@ -210,7 +212,7 @@ ENTRY entry {
EXPECT_EQ(sliced->opcode(), HloOpcode::kDynamicSlice);
const HloInstruction* index = sliced->operand(2);
EXPECT_EQ(index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(index->tuple_index(), 2);
EXPECT_EQ(index->tuple_index(), 3);
const HloInstruction* while_inst = index->operand(0);
EXPECT_EQ(while_inst->opcode(), HloOpcode::kWhile);
const HloInstruction* while_root =
......@@ -229,7 +231,7 @@ ENTRY entry {
EXPECT_EQ(get_tuple_value->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_value->tuple_index(), 1);
EXPECT_EQ(get_tuple_index->tuple_index(), 2);
EXPECT_EQ(get_tuple_index->tuple_index(), 3);
}
TEST_F(CollectivePipelinerTest, TransformIncrementByTwo) {
......@@ -243,16 +245,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(2)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -264,18 +267,19 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -288,7 +292,7 @@ ENTRY entry {
EXPECT_EQ(sliced->opcode(), HloOpcode::kDynamicSlice);
const HloInstruction* index = sliced->operand(1);
EXPECT_EQ(index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(index->tuple_index(), 2);
EXPECT_EQ(index->tuple_index(), 3);
const HloInstruction* while_inst = index->operand(0);
EXPECT_EQ(while_inst->opcode(), HloOpcode::kWhile);
const HloInstruction* while_root =
......@@ -307,7 +311,7 @@ ENTRY entry {
EXPECT_EQ(get_tuple_value->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_index->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(get_tuple_value->tuple_index(), 1);
EXPECT_EQ(get_tuple_index->tuple_index(), 2);
EXPECT_EQ(get_tuple_index->tuple_index(), 3);
}
TEST_F(CollectivePipelinerTest, NoTransformCantProveIndexDoesntWrap) {
......@@ -321,16 +325,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(4)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -342,18 +347,18 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-1)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -373,16 +378,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -394,18 +400,18 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -434,16 +440,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[1,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[1,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[1,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[1,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=3
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -460,7 +467,7 @@ while_body {
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[1,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-slice.911)
ROOT tuple = (s32[], bf16[3,8,128], bf16[1,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-slice.911, get-tuple-element.5)
}
ENTRY entry {
......@@ -468,8 +475,9 @@ ENTRY entry {
p0 = bf16[3,8,128] parameter(0)
cc = bf16[] constant(0)
c1 = bf16[1,8,128] broadcast(cc), dimensions={}
tuple = (s32[], bf16[3,8,128], bf16[1,8,128]) tuple(c0, p0, c1)
while = (s32[], bf16[3,8,128], bf16[1,8,128]) while(tuple), condition=while_cond, body=while_body
c2 = bf16[3,8,128] broadcast(cc), dimensions={}
tuple = (s32[], bf16[3,8,128], bf16[1,8,128], bf16[3,8,128]) tuple(c0, p0, c1, c2)
while = (s32[], bf16[3,8,128], bf16[1,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -489,16 +497,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -510,20 +519,20 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
rs.1 = bf16[1,1,128] reduce-scatter(mul), replica_groups={}, to_apply=add, channel_id=1, dimensions={1}
ag.1 = bf16[1,8,128] all-gather(rs.1), replica_groups={}, channel_id=2, dimensions={1}
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ag.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,8,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -552,16 +561,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,9,128]) parameter(0)
param = (s32[], bf16[3,9,128], bf16[3,9,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,9,128]) parameter(0)
param = (s32[], bf16[3,9,128], bf16[3,9,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,9,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,9,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -573,7 +583,7 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,9,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,9,128}
dynamic-slice.99 = bf16[1,9,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,9,128}
mul = bf16[1,9,128] multiply(dynamic-slice.99, dynamic-slice.99)
cpd = bf16[] constant(0)
%pd = bf16[1,16,128] pad(mul, cpd), padding=0_0x0_7x0_0
......@@ -581,15 +591,15 @@ while_body {
ag.1 = bf16[1,16,128] all-gather(rs.1), replica_groups={}, channel_id=2, dimensions={1}
slc = bf16[1,9,128] slice(ag.1), slice={[0:1], [0:9], [0:128]}
dynamic-update-slice.35 = bf16[3,9,128] dynamic-update-slice(get-tuple-element.395, slc, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,9,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,9,128], bf16[3,9,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,9,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,9,128]) tuple(c0, p0)
while = (s32[], bf16[3,9,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,9,128], bf16[3,9,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,9,128], bf16[3,9,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,9,128] get-tuple-element(while), index=1
}
)";
......@@ -619,33 +629,34 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
constant.2561 = s32[] constant(0)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, get-tuple-element.394, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, get-tuple-element.394, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
rs.1 = bf16[1,1,128] reduce-scatter(mul), replica_groups={}, to_apply=add, channel_id=1, dimensions={1}
ag.1 = bf16[1,8,128] all-gather(rs.1), replica_groups={}, channel_id=2, dimensions={1}
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ag.1, get-tuple-element.394, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-8)
p0 = bf16[3,8,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -1093,17 +1104,18 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,1,2,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=3
constant.2561 = s32[] constant(0)
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
......@@ -1119,19 +1131,19 @@ while_body {
r = bf16[1,2,128] reshape(dynamic-slice.k)
a = bf16[1,2,128] add(r, r)
ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, ag)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
p1 = bf16[3,1,2,128] parameter(1)
tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128]) tuple(c0, p0, p1)
while = (s32[], bf16[3,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) tuple(c0, p0, p1, p0)
while = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -1163,17 +1175,18 @@ add {
}
while_cond {
param = (s32[], f32[3,8,128], bf16[3,1,2,128]) parameter(0)
param = (s32[], f32[3,8,128], bf16[3,1,2,128], f32[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], f32[3,8,128], bf16[3,1,2,128]) parameter(0)
param = (s32[], f32[3,8,128], bf16[3,1,2,128], f32[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = f32[3,8,128] get-tuple-element(param), index=1
get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2
get-tuple-element.5 = f32[3,8,128] get-tuple-element(param), index=3
constant.2561 = s32[] constant(0)
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
......@@ -1189,21 +1202,21 @@ while_body {
r = bf16[1,2,128] reshape(dynamic-slice.k)
a = bf16[1,2,128] add(r, r)
ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={}
dynamic-slice.99 = f32[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = f32[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
cvt.0 = bf16[1,8,128] convert(dynamic-slice.99)
mul = bf16[1,8,128] multiply(cvt.0, ag)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
cvt.1 = f32[1,8,128] convert(ar.1)
dynamic-update-slice.35 = f32[3,8,128] dynamic-update-slice(get-tuple-element.395, cvt.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], f32[3,8,128], bf16[3,1,2,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k)
ROOT tuple = (s32[], f32[3,8,128], bf16[3,1,2,128], f32[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = f32[3,8,128] parameter(0)
p1 = bf16[3,1,2,128] parameter(1)
tuple = (s32[], f32[3,8,128], bf16[3,1,2,128]) tuple(c0, p0, p1)
while = (s32[], f32[3,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], f32[3,8,128], bf16[3,1,2,128], f32[3,8,128]) tuple(c0, p0, p1, p0)
while = (s32[], f32[3,8,128], bf16[3,1,2,128], f32[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = f32[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -1234,16 +1247,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -1255,7 +1269,7 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
c2 = bf16[] constant(2.0)
bc = bf16[1,8,128] broadcast(c2)
......@@ -1264,14 +1278,14 @@ while_body {
mul3 = bf16[1,8,128] multiply(mul2, ar.1)
mul4 = bf16[1,8,128] multiply(mul3, mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -1295,16 +1309,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -1316,19 +1331,19 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
mul2 = bf16[1,8,128] multiply(ar.1, mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul2, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -1341,5 +1356,136 @@ ENTRY entry {
XLA_VLOG_LINES(1, module->ToString());
}
// Runs the pipeliner in kForwardSink mode on a loop whose all-reduce result is
// added to a broadcast custom-call value before being written back by the
// dynamic-update-slice. The only expectation is that the pass reports a change.
TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNotFirstIdxSink) {
  constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
%c = bf16[] custom-call(), custom_call_target="Boh"
%b = bf16[1,8,128] broadcast(c), dimensions={}
%a = bf16[1,8,128] add(ar.1, b)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
  auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value();
  // The sliced data is read from tuple element 2 and stored into element 1, so
  // the loop does not reuse the buffer it is writing to.
  EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true,
                           /*level_to_operate_on=*/0,
                           /*pipeline_use_tree=*/true,
                           /*process_different_sized_ops=*/true,
                           CollectivePipeliner::kForwardSink)
                  .value());
  // Dump at verbosity 1, matching the other tests in this file, so the module
  // text is not printed on every default run.
  XLA_VLOG_LINES(1, module->ToString());
}
// Sinks the all-reduce with kForwardSink (last_run=false) and then checks that
// the instruction reached from the entry root through operands 0, 1, 0, 0 is
// an all-reduce whose leading dimension is 3.
// NOTE(review): presumably this is the sunk all-reduce now operating on the
// full stacked buffer after the loop -- confirm against the pass output.
TEST_F(CollectivePipelinerTest,
       TransformIncrementIndexByOneNotFirstIdxSinkCustomCall) {
  constexpr absl::string_view kHloText = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1
%c = bf16[] custom-call(), custom_call_target="Boh"
%b = bf16[1,8,128] broadcast(c), dimensions={}
%a = bf16[1,8,128] add(ar.1, b)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
  auto parsed = ParseAndReturnUnverifiedModule(kHloText, config_).value();

  // The pass must report that it modified the module.
  const bool changed = RunOptimizer(parsed.get(), /*last_run=*/false,
                                    /*level_to_operate_on=*/0,
                                    /*pipeline_use_tree=*/true,
                                    /*process_different_sized_ops=*/true,
                                    CollectivePipeliner::kForwardSink)
                           .value();
  EXPECT_TRUE(changed);
  XLA_VLOG_LINES(1, parsed->ToString());

  // Follow the operand path 0 -> 1 -> 0 -> 0 starting at the entry root.
  const HloInstruction* entry_root =
      parsed->entry_computation()->root_instruction();
  const HloInstruction* first_hop = entry_root->operand(0);
  const HloInstruction* second_hop = first_hop->operand(1);
  const HloInstruction* third_hop = second_hop->operand(0);
  const HloInstruction* all_reduce = third_hop->operand(0);

  EXPECT_EQ(all_reduce->opcode(), HloOpcode::kAllReduce);
  const auto& reduced_shape = all_reduce->shape();
  EXPECT_EQ(reduced_shape.dimensions(0), 3);
}
} // namespace
} // namespace xla
......@@ -57,8 +57,10 @@ StatusOr<bool> RunOptimizer(
pass.AddPass<CollectivePipeliner>(config);
pass.AddPass<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pass.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
return pass.Run(module);
TF_ASSIGN_OR_RETURN(const bool modified, pass.Run(module));
HloPassPipeline pass_dce("dce");
pass_dce.AddPass<HloDCE>(/*remove_cross_partition_collective_ops=*/true);
return modified;
}
TEST_F(CollectivePipelinerExecutionTest, TransformIncrementIndexByOne) {
......@@ -72,16 +74,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -93,25 +96,25 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -192,59 +195,56 @@ ENTRY %entry (p0: bf16[3,8,128]) -> bf16[3,8,128] {
TEST_F(CollectivePipelinerExecutionTest,
TransformIncrementIndexByOneNotFirstIdx) {
constexpr absl::string_view hlo_string = R"(
HloModule module
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[8,3,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_cond {
param = (s32[], bf16[8,3,128], bf16[8,3,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[8,3,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[8,3,128] get-tuple-element(param), index=1
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[8,1,128] dynamic-slice(get-tuple-element.395,
constant.2561, select.1348, constant.2561), dynamic_slice_sizes={8,1,128}
mul = bf16[8,1,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[8,1,128] negate(mul)
dynamic-update-slice.35 = bf16[8,3,128]
dynamic-update-slice(get-tuple-element.395, ar.1, constant.2561,
select.1348, constant.2561) ROOT tuple = (s32[], bf16[8,3,128])
tuple(add.230, dynamic-update-slice.35)
}
while_body {
param = (s32[], bf16[8,3,128], bf16[8,3,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[8,3,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[8,3,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[8,1,128] dynamic-slice(get-tuple-element.5, constant.2561, select.1348, constant.2561), dynamic_slice_sizes={8,1,128}
mul = bf16[8,1,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[8,1,128] negate(mul)
dynamic-update-slice.35 = bf16[8,3,128] dynamic-update-slice(get-tuple-element.395, ar.1, constant.2561, select.1348, constant.2561)
ROOT tuple = (s32[], bf16[8,3,128], bf16[8,3,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[8,3,128] parameter(0)
tuple = (s32[], bf16[8,3,128]) tuple(c0, p0)
while = (s32[], bf16[8,3,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[8,3,128] get-tuple-element(while),
index=1
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[8,3,128] parameter(0)
tuple = (s32[], bf16[8,3,128], bf16[8,3,128]) tuple(c0, p0, p0)
while = (s32[], bf16[8,3,128], bf16[8,3,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[8,3,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -253,59 +253,56 @@ TEST_F(CollectivePipelinerExecutionTest,
TEST_F(CollectivePipelinerExecutionTest, TransformIncrementByTwo) {
constexpr absl::string_view hlo_string = R"(
HloModule module
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
constant.2557 = s32[] constant(2)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395,
select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128]
dynamic-update-slice(get-tuple-element.395, ar.1, select.1348,
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35)
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(2)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
index=1
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -314,59 +311,56 @@ TEST_F(CollectivePipelinerExecutionTest, TransformIncrementByTwo) {
TEST_F(CollectivePipelinerExecutionTest, NoTransformCantProveIndexDoesntWrap) {
constexpr absl::string_view hlo_string = R"(
HloModule module
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(4)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(4)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395,
select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128]
dynamic-update-slice(get-tuple-element.395, ar.1, select.1348,
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35)
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-1)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
index=1
}
ENTRY entry {
c0 = s32[] constant(-1)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -385,16 +379,17 @@ TEST_F(CollectivePipelinerExecutionTest,
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -406,21 +401,21 @@ TEST_F(CollectivePipelinerExecutionTest,
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395,
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5,
select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
dynamic-update-slice.35 = bf16[3,8,128]
dynamic-update-slice(get-tuple-element.395, ar.1, select.1348,
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35)
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond,
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
index=1
}
......@@ -428,7 +423,7 @@ TEST_F(CollectivePipelinerExecutionTest,
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -493,8 +488,8 @@ TEST_F(CollectivePipelinerExecutionTest, EscapedInputNoTransform) {
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -512,16 +507,17 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAg) {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -533,7 +529,7 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAg) {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395,
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5,
select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
rs.1 = bf16[1,8,128] negate(mul)
......@@ -541,23 +537,24 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAg) {
dynamic-update-slice.35 =
bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ag.1,
select.1348, constant.2561, constant.2561) ROOT tuple = (s32[],
bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,8,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -575,16 +572,17 @@ add {
}
while_cond {
param = (s32[], bf16[3,9,128]) parameter(0)
param = (s32[], bf16[3,9,128], bf16[3,9,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,9,128]) parameter(0)
param = (s32[], bf16[3,9,128], bf16[3,9,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,9,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,9,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -596,7 +594,7 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,9,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,9,128}
dynamic-slice.99 = bf16[1,9,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,9,128}
mul = bf16[1,9,128] multiply(dynamic-slice.99, dynamic-slice.99)
cpd = bf16[] constant(0)
%pd = bf16[1,16,128] pad(mul, cpd), padding=0_0x0_7x0_0
......@@ -604,22 +602,22 @@ while_body {
ag.1 = bf16[1,16,128] negate(rs.1)
slc = bf16[1,9,128] slice(ag.1), slice={[0:1], [0:9], [0:128]}
dynamic-update-slice.35 = bf16[3,9,128] dynamic-update-slice(get-tuple-element.395, slc, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,9,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,9,128], bf16[3,9,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-3)
p0 = bf16[3,9,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,9,128]) tuple(c0, p0)
while = (s32[], bf16[3,9,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,9,128], bf16[3,9,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,9,128], bf16[3,9,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,9,128] get-tuple-element(while), index=1
}
)";
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -637,20 +635,21 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAgInsertCustomCall) {
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(0)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
constant.2561 = s32[] constant(0)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395,
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5,
get-tuple-element.394, constant.2561, constant.2561),
dynamic_slice_sizes={1,8,128} mul = bf16[1,8,128]
multiply(dynamic-slice.99, dynamic-slice.99) rs.1 = bf16[1,8,128]
......@@ -658,16 +657,16 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAgInsertCustomCall) {
ag.1 = bf16[1,8,128] negate(rs.1)
dynamic-update-slice.35 = bf16[3,8,128]
dynamic-update-slice(get-tuple-element.395, ag.1, get-tuple-element.394,
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35)
constant.2561, constant.2561) ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128])
tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(-8)
p0 = bf16[3,8,128] parameter(0)
cc = bf16[] constant(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond,
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond,
body=while_body ROOT gte1 = bf16[3,8,128] get-tuple-element(while),
index=1
}
......@@ -675,7 +674,7 @@ TEST_F(CollectivePipelinerExecutionTest, TransformWithAgInsertCustomCall) {
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0).value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 200).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -744,7 +743,7 @@ ENTRY entry {
HloPredicateIsOp<HloOpcode::kConcatenate>,
CollectivePipeliner::PipeliningDirection::kBackward)
.value());
EXPECT_TRUE(RunOptimizer(module2.get(), /*last_run=*/true, 0).value());
EXPECT_FALSE(RunOptimizer(module2.get(), /*last_run=*/true, 0).value());
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
......@@ -755,17 +754,24 @@ TEST_F(CollectivePipelinerExecutionTest, MultiUsesElementwise) {
constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -777,7 +783,7 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
c2 = bf16[] constant(2.0)
bc = bf16[1,8,128] broadcast(c2)
......@@ -786,14 +792,14 @@ while_body {
mul3 = bf16[1,8,128] multiply(mul2, ar.1)
mul4 = bf16[1,8,128] multiply(mul3, mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -816,17 +822,24 @@ TEST_F(CollectivePipelinerExecutionTest, ElementWiseUser) {
constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128]) parameter(0)
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
......@@ -838,19 +851,19 @@ while_body {
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
mul2 = bf16[1,8,128] multiply(ar.1, mul)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul2, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128]) tuple(c0, p0)
while = (s32[], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
......@@ -869,5 +882,72 @@ ENTRY entry {
ErrorSpec{0.1, 0.1}));
}
// Execution test for the kForwardSink pipelining direction.
//
// The HLO below is a while loop over a 3-element tuple
// (s32 counter, bf16[3,8,128], bf16[3,8,128]) that runs while counter < 3.
// Each iteration:
//   * derives a slice index (select.1348) from the loop counter,
//   * reads a [1,8,128] slice from tuple element 2 (get-tuple-element.35),
//   * squares it, negates it (ar.1), and adds a broadcast constant 5.0,
//   * writes the result into tuple element 1 at the same index with a
//     dynamic-update-slice.
// Tuple element 2 is forwarded unchanged, so the data being read is never the
// data being written. The `add` between the negate and the
// dynamic-update-slice means the matched op is not stored directly; the test
// passes pipeline_use_tree=true for that case.
//
// NOTE(review): despite "NotFirstIdx" in the test name, select.1348 is the
// first index operand of both the dynamic-slice and the dynamic-update-slice
// and the slice sizes are {1,8,128}, i.e. the sliced dimension here is
// dimension 0 -- confirm whether the name or the HLO is what was intended.
TEST_F(CollectivePipelinerExecutionTest,
TransformIncrementIndexByOneNotFirstIdxSink) {
constexpr absl::string_view hlo_string = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(3)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
constant.2557 = s32[] constant(1)
add.230 = s32[] add(get-tuple-element.394, constant.2557)
constant.2559 = s32[] constant(3)
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
constant.2560 = s32[] constant(-1)
add.231 = s32[] add(subtract.139, constant.2560)
constant.2561 = s32[] constant(0)
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
constant.2562 = s32[] constant(2)
add.232 = s32[] add(subtract.139, constant.2562)
select.1348 = s32[] select(compare.747, add.232, add.231)
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
ar.1 = bf16[1,8,128] negate(mul)
%c = bf16[] constant(5.0)
%b = bf16[1,8,128] broadcast(c), dimensions={}
%a = bf16[1,8,128] add(ar.1, b)
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561)
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[3,8,128] parameter(0)
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
}
)";
// Parse the same HLO text twice so there are two independent modules:
// `module` is handed to the optimizer, `module2` is never passed to it and
// serves as the reference for the comparison at the end.
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
// Run the pass on `module` only, selecting negate ops (ar.1 in the loop body)
// and requesting the kForwardSink direction. The call is expected to yield
// true -- presumably meaning the module was modified; confirm against
// RunOptimizer's contract.
EXPECT_TRUE(
RunOptimizer(module.get(), /*last_run=*/true, 0,
/*should_process=*/HloPredicateIsOp<HloOpcode::kNegate>,
CollectivePipeliner::PipeliningDirection::kForwardSink,
/*pipeline_use_tree=*/true)
.value());
// Emit both modules to the verbose log (level 1) to aid debugging.
XLA_VLOG_LINES(1, module->ToString());
XLA_VLOG_LINES(1, module2->ToString());
// Run both modules and require that the optimized one matches the untouched
// reference under ErrorSpec{0.1, 0.1}.
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
ErrorSpec{0.1, 0.1}));
}
} // namespace
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册