diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 72a6a4df9fa9505afa6a8c12c19234b1df59e9d9..bcc6bf311f40b316e73bb94c7540e4894a50b04e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -98,12 +98,19 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nullptr; } + auto cnode_ = node->cast(); + if (cnode_->size() < 1) { + return nullptr; + } + + auto node_ = cnode_->input(0); + PatternNode cond, true_br, false_br; - auto ConvertSwitchLambda = [&node, &cond, &true_br, &false_br]() -> AnfNodePtr { - auto g1_ = GetValueNode(true_br.GetNode(node)); - auto g2_ = GetValueNode(false_br.GetNode(node)); - auto x_ = cond.GetNode(node); + auto ConvertSwitchLambda = [&node_, &cond, &true_br, &false_br]() -> AnfNodePtr { + auto g1_ = GetValueNode(true_br.GetNode(node_)); + auto g2_ = GetValueNode(false_br.GetNode(node_)); + auto x_ = cond.GetNode(node_); // for switch replace method, only graphs without graph inside can be replaced for (auto &item : g1_->value_nodes()) { @@ -126,7 +133,7 @@ class ConvertSwitchReplacement : public OptimizerCaller { auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; - auto fg = node->func_graph(); + auto fg = node_->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); auto nnode = internal::TransformMergeBranches(cloned_g1, cloned_g2, true_output, false_output, x_, fg); @@ -135,8 +142,8 @@ class ConvertSwitchReplacement : public OptimizerCaller { }; MATCH_REPLACE_LAMBDA_IF( - node, PCNode(PPrimitive(prim::kPrimSwitch, cond, true_br, false_br)).MinExtraNodes(0), ConvertSwitchLambda, - true_br.CheckFunc(IsValueNode, node) && false_br.CheckFunc(IsValueNode, node)); + node_, PPrimitive(prim::kPrimSwitch, cond, true_br, false_br), ConvertSwitchLambda, + true_br.CheckFunc(IsValueNode, node_) && false_br.CheckFunc(IsValueNode, node_)); return nullptr; }