提交 c9bd2d50 编写于 作者: T tensor-tang

refine fc and gru pattern

上级 7eebb905
......@@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
PDNode* x, bool with_bias) {
// Create Operators
PDNode* elementwise_add_op{nullptr};
// mul op
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
if (with_bias) {
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
->assert_is_op("elementwise_add");
}
// Create variables
// w
auto* mul_weight_var = pattern->NewNode(name_scope, "w")
->AsInput()
->assert_is_persistable_var()
->assert_is_op_nth_input("mul", "Y", 0);
PDNode* mul_out_var{nullptr};
->assert_is_op_input("mul", "Y");
PDNode* fc_out{nullptr};
if (with_bias) {
PDNode* elementwise_add_op{nullptr};
PDNode *mul_out_var{nullptr}, *bias{nullptr};
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
->assert_is_op("elementwise_add");
// intermediate variable, will be removed in the IR after fuse.
mul_out_var = pattern->NewNode(name_scope, "mul_out")
->AsIntermediate()
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add");
}
PDNode *bias{nullptr}, *fc_out{nullptr};
if (with_bias) {
->assert_is_op_input("elementwise_add", "X");
// bias
bias = pattern->NewNode(name_scope, "fc_bias")
->assert_is_op_input("elementwise_add")
->AsInput();
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("elementwise_add", "Y");
// output
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("elementwise_add");
->assert_is_op_output("elementwise_add", "Out");
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
} else {
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("mul");
}
if (with_bias) {
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
} else {
->assert_is_op_output("mul", "Out");
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
}
return fc_out;
}
......@@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
NEW_NODE(gru, BatchResetHiddenPrev, output);
NEW_NODE(gru, BatchHidden, output);
BatchGate->AsIntermediate();
BatchResetHiddenPrev->AsIntermediate();
BatchHidden->AsIntermediate();
gru_op->LinksFrom({x, Weight, Bias});
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
return Hidden;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册