提交 377c326f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2203 Fix the back outputs of BNTrainingUpdateV3 may cover the front whitch is empty

Merge pull request !2203 from huanghui/single-batchnorm-fission-pass
......@@ -23,42 +23,8 @@
namespace mindspore {
namespace opt {
namespace {
const std::vector<int> kOutputIndex{0, 1, 2, 3, 4};
constexpr size_t kBatchNormLeastOutputNum = 1;
constexpr size_t kBatchNormRealInputNum = 3;
bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vector<AnfNodePtr> *bn_outputs) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn_outputs);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto iter = manager->node_users().find(bn);
if (iter == manager->node_users().end()) {
return false;
}
size_t output_num = 0;
for (const auto &node_index : iter->second) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(index_node);
auto value_node = index_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
int index = GetValue<int>(value_node->value());
if (std::find(kOutputIndex.begin(), kOutputIndex.end(), index) == kOutputIndex.end()) {
return false;
}
bn_outputs->push_back(output);
output_num++;
}
return output_num >= kBatchNormLeastOutputNum;
}
AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(bn);
......@@ -140,34 +106,12 @@ const AnfNodePtr SingleBatchNormFission::Process(const FuncGraphPtr &func_graph,
MS_LOG(INFO) << "is training should be true if do fusion";
return nullptr;
}
std::vector<AnfNodePtr> bn_outputs;
if (!GetBatchNormOutputs(func_graph, node, &bn_outputs)) {
MS_LOG(INFO) << "The BatchNorm node should only have output 0, 3 and 4. The node should not be changed";
return nullptr;
}
AnfNodePtr bn_training_reduce = CreateBNTrainingReduce(func_graph, node);
std::vector<AnfNodePtr> bn_training_reduce_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_reduce, kBNTrainingReduceOutputNum,
&bn_training_reduce_outputs);
AnfNodePtr bn_training_update_v3 = CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs);
std::vector<AnfNodePtr> bn_training_update_v3_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, bn_training_update_v3, kBNTrainingUpdateV3OutputNum,
&bn_training_update_v3_outputs);
if (bn_training_update_v3_outputs.size() != kBNTrainingUpdateV3OutputNum) {
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingUpdateV2OutputNum
<< ", but it is " << bn_training_update_v3_outputs.size();
}
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
sort(bn_outputs.begin(), bn_outputs.end(), CompareTupleGetitem);
size_t output_index = 0;
for (const auto &output : bn_outputs) {
(void)manager->Replace(output, bn_training_update_v3_outputs[output_index]);
output_index++;
}
// Return the new node for control depends.
return bn_training_update_v3;
return CreateBNTrainingUpdateV3(func_graph, node, bn_training_reduce_outputs);
}
} // namespace opt
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册