提交 e702e0bc 编写于 作者: Y yujianfeng

Add tuple_getitem check for outputs of bn

上级 817b0e4a
......@@ -81,6 +81,7 @@
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
#include "pre_activate/ascend/ir_fission/addn_fission.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include "pre_activate/ascend/ir_fission/batch_norm_grad_infer_fission.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
......@@ -116,6 +117,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<BatchNorm2BNInfer>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
}
} // namespace
......
......@@ -34,6 +34,9 @@ bool CheckOutputsIndex(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
for (const auto &node_index : manager->node_users()[node]) {
AnfNodePtr output = node_index.first;
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
......
......@@ -274,6 +274,9 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
MS_EXCEPTION_IF_NULL(manager);
for (const auto &output : bn_outputs) {
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
continue;
}
auto tuple_getitem_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem_cnode);
AnfNodePtr index_node = tuple_getitem_cnode->input(kInputNodeOutputIndexInTupleGetItem);
......
......@@ -32,7 +32,21 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
}
void AddInputToOutput(const FuncGraphPtr &func_graph, const CNodePtr &old_cnode, const AnfNodePtr &new_node,
std::vector<AnfNodePtr> *new_outputs) {
MS_EXCEPTION_IF_NULL(old_cnode);
MS_EXCEPTION_IF_NULL(new_node);
MS_EXCEPTION_IF_NULL(new_outputs);
auto node_to_output = old_cnode->input(kAccumIndex + 1);
MS_EXCEPTION_IF_NULL(node_to_output);
AbstractBasePtrList abstract_list{old_cnode->abstract(), node_to_output->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
new_node->set_abstract(abstract_tuple);
// Create Output
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, new_outputs);
}
} // namespace
const BaseRef MomentumLossscaleFusion::DefinePattern() const {
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr X0 = std::make_shared<Var>();
......@@ -80,15 +94,10 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph
input_names_value[3] = "x1";
input_names_value.emplace_back("x2");
AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_value), new_node);
auto node_to_output = cnode->input(kAccumIndex + 1);
MS_EXCEPTION_IF_NULL(node_to_output);
AbstractBasePtrList abstract_list{node->abstract(), node_to_output->abstract()};
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
new_node->set_abstract(abstract_tuple);
new_node->set_scope(node->scope());
// Create Output
// Create Outputs
std::vector<AnfNodePtr> new_outputs;
CreateMultipleOutputsOfAnfNode(func_graph, new_node, kFusedMulApplyMomentumOutputNum, &new_outputs);
AddInputToOutput(func_graph, cnode, new_node, &new_outputs);
if (new_outputs.size() != kFusedMulApplyMomentumOutputNum) {
MS_LOG(EXCEPTION) << "Failed to create outputs of " << new_node->DebugString();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册