未验证 提交 73450636 编写于 作者: G GaoWei8 提交者: GitHub

[X86] Enhance fc_fuse_pass to enable fusing relu to fc_op (#2701)

* Enhance fc_fuse_pass to enable fusing relu to fc_op
test=develop

* restrict fusing relu in x86
test=develop
上级 8c0397c6
...@@ -23,8 +23,13 @@ namespace lite { ...@@ -23,8 +23,13 @@ namespace lite {
namespace mir { namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser; #ifdef LITE_WITH_X86
fusion::FcFuser fuser(true);
fuser(graph.get()); fuser(graph.get());
#endif
fusion::FcFuser fuser2(false);
fuser2(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -88,6 +88,7 @@ USE_LITE_OP(mul); ...@@ -88,6 +88,7 @@ USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add); USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub); USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc); USE_LITE_OP(fc);
USE_LITE_OP(relu);
USE_LITE_OP(feed); USE_LITE_OP(feed);
USE_LITE_OP(fetch); USE_LITE_OP(fetch);
USE_LITE_OP(io_copy); USE_LITE_OP(io_copy);
......
...@@ -35,12 +35,23 @@ void FcFuser::BuildPattern() { ...@@ -35,12 +35,23 @@ void FcFuser::BuildPattern() {
std::vector<PMNode*> mul_inputs{W, x}; std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b}; std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out; mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities. // Some op specialities.
mul_out->AsIntermediate(); mul_out->AsIntermediate();
mul->AsIntermediate(); mul->AsIntermediate();
add->AsIntermediate(); add->AsIntermediate();
if (with_relu_) {
auto* add_out = VarNode("add_out");
auto* relu = OpNode("relu", "relu");
std::vector<PMNode*> relu_inputs{add_out};
add_inputs >> *add >> *add_out;
relu_inputs >> *relu >> *Out;
add_out->AsIntermediate();
relu->AsIntermediate();
} else {
add_inputs >> *add >> *Out;
}
} }
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
...@@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr( op_desc.SetAttr(
"in_num_col_dims", "in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims")); matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
}
return op_desc; return op_desc;
} }
......
...@@ -25,11 +25,13 @@ namespace fusion { ...@@ -25,11 +25,13 @@ namespace fusion {
class FcFuser : public FuseBase { class FcFuser : public FuseBase {
public: public:
explicit FcFuser(bool with_relu) : with_relu_(with_relu) {}
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private: private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
bool with_relu_;
}; };
} // namespace fusion } // namespace fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册