// Patch summary (preamble text ahead of the first "diff --git" header).
//
// Purpose: let the mul + elementwise_add -> fc fusion also match a trailing
// relu op, producing an fc op that carries an activation attribute.
//
// lite/core/mir/fusion/fc_fuse_pass.cc
//   FcFusePass::Apply no longer builds a default-constructed FcFuser. When
//   LITE_WITH_X86 is defined it first runs FcFuser(true) on the graph, and
//   then, on every build, runs FcFuser(false) on the same graph.
//   NOTE(review): the relu variant running first presumably gives the longer
//   mul/add/relu pattern first chance at a match - confirm.
//
// lite/core/mir/fusion/fc_fuse_pass_test.cc
//   Adds USE_LITE_OP(relu) next to the other USE_LITE_OP lines.
//
// lite/core/mir/fusion/fc_fuser.cc
//   FcFuser::BuildPattern: when with_relu_ is set, the add op writes to a new
//   "add_out" var node which feeds a new "relu" op node whose output is Out;
//   add_out and relu are both marked AsIntermediate(). Otherwise the add op
//   writes straight to Out, as it did before this patch.
//   FcFuser::GenOpDesc: when with_relu_ is set, sets the "activation_type"
//   attribute to the string "relu" on the generated op desc.
//
// lite/core/mir/fusion/fc_fuser.h
//   Adds an explicit FcFuser(bool with_relu) constructor and the private
//   member with_relu_ that it initializes.
//
// NOTE(review): this copy of the patch has lost its line breaks, and template
// arguments look stripped (e.g. "std::unique_ptr&", "std::vector mul_inputs").
// Presumably an extraction artifact - verify against the original commit
// before applying.
diff --git a/lite/core/mir/fusion/fc_fuse_pass.cc b/lite/core/mir/fusion/fc_fuse_pass.cc index 7fc449219251bbd7e639e8092099f43fe8eca626..120b312aebb972caa7c58609f97b01db95c9e862 100644 --- a/lite/core/mir/fusion/fc_fuse_pass.cc +++ b/lite/core/mir/fusion/fc_fuse_pass.cc @@ -23,8 +23,13 @@ namespace lite { namespace mir { void FcFusePass::Apply(const std::unique_ptr& graph) { - fusion::FcFuser fuser; +#ifdef LITE_WITH_X86 + fusion::FcFuser fuser(true); fuser(graph.get()); +#endif + + fusion::FcFuser fuser2(false); + fuser2(graph.get()); } } // namespace mir diff --git a/lite/core/mir/fusion/fc_fuse_pass_test.cc b/lite/core/mir/fusion/fc_fuse_pass_test.cc index f7aa4bb5adcb848531ecc3a8f63bace1c2e3e0ff..54260732c5efe788f0d3740197253fa2321a7d02 100644 --- a/lite/core/mir/fusion/fc_fuse_pass_test.cc +++ b/lite/core/mir/fusion/fc_fuse_pass_test.cc @@ -88,6 +88,7 @@ USE_LITE_OP(mul); USE_LITE_OP(elementwise_add); USE_LITE_OP(elementwise_sub); USE_LITE_OP(fc); +USE_LITE_OP(relu); USE_LITE_OP(feed); USE_LITE_OP(fetch); USE_LITE_OP(io_copy); diff --git a/lite/core/mir/fusion/fc_fuser.cc b/lite/core/mir/fusion/fc_fuser.cc index 460c0fdf7a4309638b9852a315ca0efda02801ab..3c99131083d37ea2c8511ed136bff17c891529af 100644 --- a/lite/core/mir/fusion/fc_fuser.cc +++ b/lite/core/mir/fusion/fc_fuser.cc @@ -35,12 +35,23 @@ void FcFuser::BuildPattern() { std::vector mul_inputs{W, x}; std::vector add_inputs{mul_out, b}; mul_inputs >> *mul >> *mul_out; - add_inputs >> *add >> *Out; // Some op specialities. 
mul_out->AsIntermediate(); mul->AsIntermediate(); add->AsIntermediate(); + + if (with_relu_) { + auto* add_out = VarNode("add_out"); + auto* relu = OpNode("relu", "relu"); + std::vector relu_inputs{add_out}; + add_inputs >> *add >> *add_out; + relu_inputs >> *relu >> *Out; + add_out->AsIntermediate(); + relu->AsIntermediate(); + } else { + add_inputs >> *add >> *Out; + } } void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { @@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetAttr( "in_num_col_dims", matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + if (with_relu_) { + op_desc.SetAttr("activation_type", std::string{"relu"}); + } return op_desc; } diff --git a/lite/core/mir/fusion/fc_fuser.h b/lite/core/mir/fusion/fc_fuser.h index 7ba07527898c7e648c5f7f9151642ab0928fa496..6cb08f41574b67df1c78fa296d2d395771a66ee1 100644 --- a/lite/core/mir/fusion/fc_fuser.h +++ b/lite/core/mir/fusion/fc_fuser.h @@ -25,11 +25,13 @@ namespace fusion { class FcFuser : public FuseBase { public: + explicit FcFuser(bool with_relu) : with_relu_(with_relu) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + bool with_relu_; }; } // namespace fusion