diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 099fc55583925988c37d85f965900f7e4dfa1e98..ec07278eed1f259c45e225497f94d682b544c57c 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -66,7 +66,6 @@ void ConvBNFuser::BuildPattern() { if (conv_has_bias_) { auto* conv_bias = VarNode("conv_bias") ->assert_is_op_input(conv_type_, "Bias") - ->AsInput() ->AsIntermediate(); conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out}); } else { @@ -172,7 +171,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } // compute new conv_bias - if (conv_has_bias_) { + if (conv_has_bias_ && conv_op_desc->HasInput("Bias") && + conv_op_desc->Input("Bias").size() > 0) { auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name) ->GetMutable(); auto conv_bias_d = conv_bias_t->data(); diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8ec85a4ef124ae461cb1e16cf56717a0227b06e6..8e0fc55be2389244ae065b4c2809bbdd74be370c 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -415,7 +415,8 @@ bool IsNthOutput(const Node *var, CHECK(var->IsArg()); CHECK(op->IsStmt()); auto op_info = op->stmt()->op_info(); - if (op_info->Output(argument).size() <= nth) return false; + if (!op_info->HasOutput(argument) || op_info->Output(argument).size() <= nth) + return false; return var->arg()->name == op_info->Output(argument)[nth]; } @@ -426,7 +427,8 @@ bool IsNthInput(const Node *var, CHECK(var->IsArg()); CHECK(op->IsStmt()); auto op_info = op->stmt()->op_info(); - if (op_info->Input(argument).size() <= nth) return false; + if (!op_info->HasInput(argument) || op_info->Input(argument).size() <= nth) + return false; return var->arg()->name == op_info->Input(argument)[nth]; }