From 910b76550b80b9611bbb7cd1acb1115994f30e77 Mon Sep 17 00:00:00 2001 From: Yuan Shuai Date: Thu, 31 Oct 2019 20:20:42 +0800 Subject: [PATCH] [BugFix] Fix conv bn bug, check bias in pass (#2313) (#2318) * Fix conv bn bug. test=develop * Fix bug in pattern_matcher. test=develop --- lite/core/mir/fusion/conv_bn_fuser.cc | 4 ++-- lite/core/mir/pattern_matcher.cc | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc index 099fc55583..ec07278eed 100644 --- a/lite/core/mir/fusion/conv_bn_fuser.cc +++ b/lite/core/mir/fusion/conv_bn_fuser.cc @@ -66,7 +66,6 @@ void ConvBNFuser::BuildPattern() { if (conv_has_bias_) { auto* conv_bias = VarNode("conv_bias") ->assert_is_op_input(conv_type_, "Bias") - ->AsInput() ->AsIntermediate(); conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out}); } else { @@ -172,7 +171,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { } // compute new conv_bias - if (conv_has_bias_) { + if (conv_has_bias_ && conv_op_desc->HasInput("Bias") && + conv_op_desc->Input("Bias").size() > 0) { auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name) ->GetMutable(); auto conv_bias_d = conv_bias_t->data(); diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index 8ec85a4ef1..8e0fc55be2 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -415,7 +415,8 @@ bool IsNthOutput(const Node *var, CHECK(var->IsArg()); CHECK(op->IsStmt()); auto op_info = op->stmt()->op_info(); - if (op_info->Output(argument).size() <= nth) return false; + if (!op_info->HasOutput(argument) || op_info->Output(argument).size() <= nth) + return false; return var->arg()->name == op_info->Output(argument)[nth]; } @@ -426,7 +427,8 @@ bool IsNthInput(const Node *var, CHECK(var->IsArg()); CHECK(op->IsStmt()); auto op_info = op->stmt()->op_info(); - if (op_info->Input(argument).size() <= nth) return false; + if (!op_info->HasInput(argument) || op_info->Input(argument).size() <= nth) + return false; return var->arg()->name == op_info->Input(argument)[nth]; } -- GitLab