未验证 提交 73450636 编写于 作者: G GaoWei8 提交者: GitHub

[X86] Enhance fc_fuse_pass to enable fusing relu to fc_op (#2701)

* Enhance fc_fuse_pass to enable fusing relu to fc_op
test=develop

* restrict fusing relu in x86
test=develop
上级 8c0397c6
......@@ -23,8 +23,13 @@ namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser;
#ifdef LITE_WITH_X86
fusion::FcFuser fuser(true);
fuser(graph.get());
#endif
fusion::FcFuser fuser2(false);
fuser2(graph.get());
}
} // namespace mir
......
......@@ -88,6 +88,7 @@ USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc);
USE_LITE_OP(relu);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
......
......@@ -35,12 +35,23 @@ void FcFuser::BuildPattern() {
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
if (with_relu_) {
auto* add_out = VarNode("add_out");
auto* relu = OpNode("relu", "relu");
std::vector<PMNode*> relu_inputs{add_out};
add_inputs >> *add >> *add_out;
relu_inputs >> *relu >> *Out;
add_out->AsIntermediate();
relu->AsIntermediate();
} else {
add_inputs >> *add >> *Out;
}
}
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
......@@ -71,6 +82,9 @@ cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
if (with_relu_) {
op_desc.SetAttr("activation_type", std::string{"relu"});
}
return op_desc;
}
......
......@@ -25,11 +25,13 @@ namespace fusion {
class FcFuser : public FuseBase {
public:
explicit FcFuser(bool with_relu) : with_relu_(with_relu) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
bool with_relu_;
};
} // namespace fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册