From 2dcff5caf80700c990445d1ab0f6fcdfc4caaf9a Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Wed, 11 Sep 2019 13:56:04 +0800
Subject: [PATCH] fix conv-act-fuse-pass when there is no "bias" (#2003)

test=develop
---
 .../mir/fusion/conv_activation_fuse_pass.cc   | 13 +++--
 lite/core/mir/fusion/conv_activation_fuser.cc | 52 ++++++-------------
 lite/core/mir/fusion/conv_activation_fuser.h  |  5 +-
 .../generate_npu_program_pass_test.cc         |  3 +-
 4 files changed, 30 insertions(+), 43 deletions(-)

diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index cad98cb26c..510009605b 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -23,11 +23,14 @@ namespace lite {
 namespace mir {
 
 void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
-  fusion::ConvActivationFuser fuser("conv2d", "relu");
-  fuser(graph.get());
-
-  fusion::ConvActivationFuser depthwise_fuser("depthwise_conv2d", "relu");
-  depthwise_fuser(graph.get());
+  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+    for (auto act_type : {"relu"}) {
+      for (auto has_bias : {true, false}) {
+        fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
+        fuser(graph.get());
+      }
+    }
+  }
 }
 
 }  // namespace mir
diff --git a/lite/core/mir/fusion/conv_activation_fuser.cc b/lite/core/mir/fusion/conv_activation_fuser.cc
index c49a9ad4f0..8e18b368f4 100644
--- a/lite/core/mir/fusion/conv_activation_fuser.cc
+++ b/lite/core/mir/fusion/conv_activation_fuser.cc
@@ -22,35 +22,33 @@ namespace mir {
 namespace fusion {
 
 void ConvActivationFuser::BuildPattern() {
-  // create input nodes.
+  // create nodes.
   auto* input =
       VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
   auto* filter =
       VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
-  auto* bias =
-      VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
-
-  // create op nodes
-  auto* conv2d =
-      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  PMNode* bias = nullptr;
+  if (has_bias_) {
+    bias = VarNode("bias")->assert_is_op_input(conv_type_, "Bias")->AsInput();
+  }
+  auto* conv2d = OpNode("conv2d", conv_type_)->AsIntermediate();
 
-  auto* act =
-      OpNode("act", act_type_)->assert_is_op(act_type_)->AsIntermediate();
+  auto* act = OpNode("act", act_type_)->AsIntermediate();
 
-  // create intermediate nodes
   auto* conv2d_out = VarNode("conv2d_out")
                          ->assert_is_op_output(conv_type_, "Output")
                          ->assert_is_op_input(act_type_, "X")
                          ->AsIntermediate();
 
-  // create output node
   auto* out =
       VarNode("output")->assert_is_op_output(act_type_, "Out")->AsOutput();
 
   // create topology.
-  std::vector<PMNode*> conv2d_inputs{filter, input, bias};
-  conv2d_inputs >> *conv2d >> *conv2d_out;
-  *conv2d_out >> *act >> *out;
+  std::vector<PMNode*> conv2d_inputs{filter, input};
+  conv2d_inputs >> *conv2d >> *conv2d_out >> *act >> *out;
+  if (has_bias_) {
+    *bias >> *conv2d;
+  }
 }
 
 void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
@@ -66,33 +64,15 @@ void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
 
   IR_NODE_LINK_TO(matched.at("input"), new_op_node);
   IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
-  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  if (has_bias_) {
+    IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  }
   IR_NODE_LINK_TO(new_op_node, matched.at("output"));
 }
 
 cpp::OpDesc ConvActivationFuser::GenOpDesc(const key2nodes_t& matched) {
-  auto* desc = matched.at("conv2d")->stmt()->op_info();
-
-  cpp::OpDesc op_desc = *desc;
-  op_desc.SetType(conv_type_);
-  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
-  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
-  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
+  cpp::OpDesc op_desc = *matched.at("conv2d")->stmt()->op_info();
   op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
-  // Other inputs. See operators/conv_op.h
-  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
-
-  if (std::find(input_arg_names.begin(),
-                input_arg_names.end(),
-                "ResidualData") != input_arg_names.end()) {
-    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
-  }
-  // Only consider strides, padding, groups, dilations, fuse_relu for now
-  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
-  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
-  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
-  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
-  // TODO(sangoly): support other activation types
   op_desc.SetAttr("fuse_relu", true);
   return op_desc;
 }
diff --git a/lite/core/mir/fusion/conv_activation_fuser.h b/lite/core/mir/fusion/conv_activation_fuser.h
index 3377e28e29..0d09c9dce2 100644
--- a/lite/core/mir/fusion/conv_activation_fuser.h
+++ b/lite/core/mir/fusion/conv_activation_fuser.h
@@ -26,10 +26,12 @@ namespace fusion {
 class ConvActivationFuser : public FuseBase {
  public:
   explicit ConvActivationFuser(const std::string& conv_type,
-                               const std::string& act_type) {
+                               const std::string& act_type,
+                               bool has_bias) {
     CHECK(act_type == "relu") << "Only relu activation be supported now";
     conv_type_ = conv_type;
     act_type_ = act_type;
+    has_bias_ = has_bias;
   }
 
   void BuildPattern() override;
@@ -39,6 +41,7 @@ class ConvActivationFuser : public FuseBase {
   cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
   std::string conv_type_;
   std::string act_type_;
+  bool has_bias_;
 };
 
 }  // namespace fusion
diff --git a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
index 8bfdb7381b..a1f39441cb 100644
--- a/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
+++ b/lite/core/mir/subgraph/generate_npu_program_pass_test.cc
@@ -70,7 +70,8 @@ void CompareOutData(const lite::Predictor& tgt, const lite::Predictor& ref) {
   const auto* ref_pdata = ref_otensor->data<float>();
   EXPECT_EQ(tgt_otensor->dims().production(), ref_otensor->dims().production());
   for (size_t i = 0; i < tgt_otensor->dims().production(); ++i) {
-    auto diff = std::fabs((tgt_pdata[i] - ref_pdata[i]) / ref_pdata[i]);
+    auto diff = std::fabs(tgt_pdata[i] - ref_pdata[i]) /
+                (std::fabs(ref_pdata[i]) + 1e-6);
     VLOG(3) << diff;
     EXPECT_LT(diff, 0.1);
   }
-- 
GitLab
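
Note on the test change above: the old check in CompareOutData computed std::fabs((tgt_pdata[i] - ref_pdata[i]) / ref_pdata[i]), which divides by the raw reference value and so yields inf/nan (or a hugely inflated ratio) whenever a reference output element is zero or near zero. The patched check divides the absolute difference by std::fabs(ref_pdata[i]) + 1e-6, keeping the denominator strictly positive. Below is a minimal standalone sketch of that guarded relative-error metric; it is not part of the patch, and the helper name RelDiff and the sample values are illustrative only (the 0.1 threshold mirrors the EXPECT_LT tolerance in the test).

#include <cmath>
#include <cstdio>
#include <vector>

// Guarded relative error as used by the patched CompareOutData:
// |tgt - ref| / (|ref| + eps) keeps the denominator away from zero.
float RelDiff(float tgt, float ref, float eps = 1e-6f) {
  return std::fabs(tgt - ref) / (std::fabs(ref) + eps);
}

int main() {
  // Hypothetical reference and target outputs; the first reference element
  // is exactly zero, which the old formula could not handle.
  std::vector<float> ref = {0.0f, 1.5f, -2.0f};
  std::vector<float> tgt = {5e-8f, 1.49f, -2.01f};
  for (size_t i = 0; i < ref.size(); ++i) {
    float diff = RelDiff(tgt[i], ref[i]);
    std::printf("i=%zu diff=%g %s\n", i, diff, diff < 0.1f ? "OK" : "FAIL");
  }
  return 0;
}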