From 7345063611a031b545b56b3ca6b13ab48b0d08aa Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Thu, 2 Jan 2020 17:37:48 +0800 Subject: [PATCH] [X86] Enhance fc_fuse_pass to enable fusing relu to fc_op (#2701) * Enhance fc_fuse_pass to enable fusing relu to fc_op test=develop * restrict fusing relu in x86 test=develop --- lite/core/mir/fusion/fc_fuse_pass.cc | 7 ++++++- lite/core/mir/fusion/fc_fuse_pass_test.cc | 1 + lite/core/mir/fusion/fc_fuser.cc | 16 +++++++++++++++- lite/core/mir/fusion/fc_fuser.h | 2 ++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index 7fc4492192..120b312aeb 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -23,8 +23,13 @@ namespace lite { namespace mir { void FcFusePass::Apply(const std::unique_ptr& graph) { - fusion::FcFuser fuser; +#ifdef LITE_WITH_X86 + fusion::FcFuser fuser(true); fuser(graph.get()); +#endif + + fusion::FcFuser fuser2(false); + fuser2(graph.get()); } } // namespace mir diff --git a/lite/core/mir/fusion/fc_fuse_pass_test.cc b/lite/core/mir/fusion/fc_fuse_pass_test.cc index f7aa4bb5ad..54260732c5 100644 --- a/lite/core/mir/fusion/fc_fuse_pass_test.cc +++ b/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -88,6 +88,7 @@ USE_LITE_OP(mul); USE_LITE_OP(elementwise_add); USE_LITE_OP(elementwise_sub); USE_LITE_OP(fc); +USE_LITE_OP(relu); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc index 460c0fdf7a..3c99131083 100644 --- a/lite/core/mir/fusion/fc_fuser.cc +++ b/lite/core/mir/fusion/fc_fuser.cc @@ -35,12 +35,23 @@ void FcFuser::BuildPattern() { std::vector mul_inputs{W, x}; std::vector add_inputs{mul_out, b}; mul_inputs >> *mul >> *mul_out; - add_inputs >> *add >> *Out; // Some op specialities. mul_out->AsIntermediate(); mul->AsIntermediate(); add->AsIntermediate(); + + if (with_relu_) { + auto* add_out = VarNode("add_out"); + auto* relu = OpNode("relu", "relu"); + std::vector relu_inputs{add_out}; + add_inputs >> *add >> *add_out; + relu_inputs >> *relu >> *Out; + add_out->AsIntermediate(); + relu->AsIntermediate(); + } else { + add_inputs >> *add >> *Out; + } } void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { @@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetAttr( "in_num_col_dims", matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + if (with_relu_) { + op_desc.SetAttr("activation_type", std::string{"relu"}); + } return op_desc; } diff --git a/lite/core/mir/fusion/fc_fuser.h b/lite/core/mir/fusion/fc_fuser.h index 7ba0752789..6cb08f4157 100644 --- a/lite/core/mir/fusion/fc_fuser.h +++ b/lite/core/mir/fusion/fc_fuser.h @@ -25,11 +25,13 @@ namespace fusion { class FcFuser : public FuseBase { public: + explicit FcFuser(bool with_relu) : with_relu_(with_relu) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + bool with_relu_; }; } // namespace fusion -- GitLab