未验证 提交 910b7655 编写于 作者: Y Yuan Shuai 提交者: GitHub

[BugFix] Fix conv bn bug, check bias in pass (#2313) (#2318)

* Fix conv bn bug. test=develop

* Fix bug in pattern_matcher. test=develop
上级 f3733908
...@@ -66,7 +66,6 @@ void ConvBNFuser::BuildPattern() { ...@@ -66,7 +66,6 @@ void ConvBNFuser::BuildPattern() {
if (conv_has_bias_) { if (conv_has_bias_) {
auto* conv_bias = VarNode("conv_bias") auto* conv_bias = VarNode("conv_bias")
->assert_is_op_input(conv_type_, "Bias") ->assert_is_op_input(conv_type_, "Bias")
->AsInput()
->AsIntermediate(); ->AsIntermediate();
conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out}); conv->LinksFrom({conv_input, conv_weight, conv_bias}).LinksTo({conv_out});
} else { } else {
...@@ -172,7 +171,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -172,7 +171,8 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
} }
// compute new conv_bias // compute new conv_bias
if (conv_has_bias_) { if (conv_has_bias_ && conv_op_desc->HasInput("Bias") &&
conv_op_desc->Input("Bias").size() > 0) {
auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name) auto conv_bias_t = scope->FindVar(matched.at("conv_bias")->arg()->name)
->GetMutable<lite::Tensor>(); ->GetMutable<lite::Tensor>();
auto conv_bias_d = conv_bias_t->data<float>(); auto conv_bias_d = conv_bias_t->data<float>();
......
...@@ -415,7 +415,8 @@ bool IsNthOutput(const Node *var, ...@@ -415,7 +415,8 @@ bool IsNthOutput(const Node *var,
CHECK(var->IsArg()); CHECK(var->IsArg());
CHECK(op->IsStmt()); CHECK(op->IsStmt());
auto op_info = op->stmt()->op_info(); auto op_info = op->stmt()->op_info();
if (op_info->Output(argument).size() <= nth) return false; if (!op_info->HasOutput(argument) || op_info->Output(argument).size() <= nth)
return false;
return var->arg()->name == op_info->Output(argument)[nth]; return var->arg()->name == op_info->Output(argument)[nth];
} }
...@@ -426,7 +427,8 @@ bool IsNthInput(const Node *var, ...@@ -426,7 +427,8 @@ bool IsNthInput(const Node *var,
CHECK(var->IsArg()); CHECK(var->IsArg());
CHECK(op->IsStmt()); CHECK(op->IsStmt());
auto op_info = op->stmt()->op_info(); auto op_info = op->stmt()->op_info();
if (op_info->Input(argument).size() <= nth) return false; if (!op_info->HasInput(argument) || op_info->Input(argument).size() <= nth)
return false;
return var->arg()->name == op_info->Input(argument)[nth]; return var->arg()->name == op_info->Input(argument)[nth];
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册