提交 fcbb00ff 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Clean up fused_instructions method in HloInstruction

PiperOrigin-RevId: 164879220
上级 471c506b
......@@ -521,13 +521,18 @@ void HloInstruction::MergeFusionInstruction(
// fusion.
// Add all non-parameter fused instructions to 'unfused_instructions' to be
// merged into 'this'.
// This is done in reverse post order.
std::vector<HloInstruction*> unfused_instructions;
for (auto& fused_instruction : clone->fused_instructions()) {
auto fused_instructions =
clone->fused_instructions_computation()->MakeInstructionPostOrder();
for (auto fused_it = fused_instructions.rbegin();
fused_it != fused_instructions.rend(); ++fused_it) {
auto fused_instruction = *fused_it;
if (fused_instruction->opcode() == HloOpcode::kParameter) {
TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith(
clone->mutable_operand(fused_instruction->parameter_number())));
} else {
unfused_instructions.push_back(fused_instruction.get());
unfused_instructions.push_back(fused_instruction);
}
}
CHECK(unfused_instructions.front() == clone->fused_expression_root());
......@@ -1007,8 +1012,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
std::vector<HloInstruction*> new_fused_parameters;
const std::vector<HloInstruction*>& fused_parameters_ =
fused_instructions_computation()->parameter_instructions();
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions_ =
fused_instructions_computation()->instructions();
for (HloInstruction* old_fused_parameter : fused_parameters_) {
new_fused_instructions.push_back(old_fused_parameter->Clone());
......@@ -1017,10 +1020,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
new_fused_parameters.push_back(new_fusion_parameter);
InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
}
for (auto old_fused_instruction_iter = fused_instructions_.rbegin();
old_fused_instruction_iter != fused_instructions_.rend();
++old_fused_instruction_iter) {
HloInstruction* old_fused_instruction = old_fused_instruction_iter->get();
for (auto old_fused_instruction :
fused_instructions_computation()->MakeInstructionPostOrder()) {
if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
FindOrDie(old_to_new, old_fused_instruction);
continue;
......@@ -1667,11 +1668,11 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.mutable_fused_instructions_computation();
proto_fused_computation->set_name(name());
// Fill in fused instructions. Note that fused_instructions() returns in
// reverse post-order (i.e. root first), so we reverse to get post-order.
for (auto fused_it = fused_instructions().rbegin();
fused_it != fused_instructions().rend(); ++fused_it) {
HloInstructionProto fused_proto = (*fused_it)->ToProto();
// Fill in fused instructions in post order.
auto fused_instructions =
fused_instructions_computation()->MakeInstructionPostOrder();
for (auto fused_instruction : fused_instructions) {
HloInstructionProto fused_proto = fused_instruction->ToProto();
proto_fused_computation->add_instructions()->Swap(&fused_proto);
}
break;
......
......@@ -599,9 +599,7 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kFusion
HloComputation* fused_instructions_computation() const;
// Returns the vector of fused instructions inside this fusion
// instruction. The order is a reverse postorder of the fused expression (root
// is first in the order).
// Returns the list of fused instructions inside this fusioninstruction.
//
// Note: although the list itself is const, the instructions contained in the
// list returned here are mutable.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册