From c9bd2d50f1d9c0db255ebc132b7c74438f3b3bba Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 7 Sep 2018 12:51:36 +0800
Subject: [PATCH] refine fc and gru pattern

---
 .../framework/ir/graph_pattern_detector.cc | 45 +++++++++----------
 1 file changed, 20 insertions(+), 25 deletions(-)

diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 37566b7621..69a323a8bd 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -519,50 +519,41 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
 
 PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
                      PDNode* x, bool with_bias) {
-  // Create Operators
-  PDNode* elementwise_add_op{nullptr};
+  // mul op
   auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
-  if (with_bias) {
-    elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
-                             ->assert_is_op("elementwise_add");
-  }
-  // Create variables
-  // w
   auto* mul_weight_var = pattern->NewNode(name_scope, "w")
                              ->AsInput()
                              ->assert_is_persistable_var()
-                             ->assert_is_op_nth_input("mul", "Y", 0);
-  PDNode* mul_out_var{nullptr};
+                             ->assert_is_op_input("mul", "Y");
+
+  PDNode* fc_out{nullptr};
   if (with_bias) {
+    PDNode* elementwise_add_op{nullptr};
+    PDNode *mul_out_var{nullptr}, *bias{nullptr};
+    elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
+                             ->assert_is_op("elementwise_add");
     // intermediate variable, will be removed in the IR after fuse.
     mul_out_var = pattern->NewNode(name_scope, "mul_out")
                       ->AsIntermediate()
                       ->assert_is_only_output_of_op("mul")
-                      ->assert_is_op_input("elementwise_add");
-  }
-  PDNode *bias{nullptr}, *fc_out{nullptr};
-  if (with_bias) {
+                      ->assert_is_op_input("elementwise_add", "X");
     // bias
     bias = pattern->NewNode(name_scope, "fc_bias")
-               ->assert_is_op_input("elementwise_add")
-               ->AsInput();
+               ->AsInput()
+               ->assert_is_persistable_var()
+               ->assert_is_op_input("elementwise_add", "Y");
     // output
     fc_out = pattern->NewNode(name_scope, "fc_out")
                  ->AsOutput()
-                 ->assert_is_op_output("elementwise_add");
+                 ->assert_is_op_output("elementwise_add", "Out");
+    mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
+    elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
   } else {
     fc_out = pattern->NewNode(name_scope, "fc_out")
                  ->AsOutput()
-                 ->assert_is_op_output("mul");
-  }
-
-  if (with_bias) {
-    mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
-    elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
-  } else {
+                 ->assert_is_op_output("mul", "Out");
     mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
   }
-
   return fc_out;
 }
 
@@ -609,6 +600,10 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
   NEW_NODE(gru, BatchResetHiddenPrev, output);
   NEW_NODE(gru, BatchHidden, output);
 
+  BatchGate->AsIntermediate();
+  BatchResetHiddenPrev->AsIntermediate();
+  BatchHidden->AsIntermediate();
+
   gru_op->LinksFrom({x, Weight, Bias});
   gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
   return Hidden;
--
GitLab
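
For reference, a minimal sketch of how a fuse pass typically drives the refined patterns::FC helper through GraphPatternDetector. The scope name "fc_fuse", the x-node assertions, and the empty handler body are illustrative assumptions for this sketch and are not part of the patch; only the calls already visible in the diff (NewNode, AsInput, assert_is_op_input, patterns::FC) and the standard detector workflow are assumed.

// Illustrative sketch only; not part of this commit.
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

void ScanFCSubgraphs(Graph* graph) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Root the pattern at the non-persistable input of mul.
  auto* x = pattern->NewNode("fc_fuse", "x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");

  // Builds mul (+ optional elementwise_add) and returns the fc output node.
  patterns::FC(pattern, "fc_fuse", x, true /*with_bias*/);

  gpd(graph, [](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
    // A real pass would look up the matched nodes here and replace the
    // mul + elementwise_add pair with a single fused fc op.
  });
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

Building the LinksFrom/LinksTo wiring inside patterns::FC, as this patch does, keeps both the with_bias and bias-free variants self-contained: a caller only supplies the x node and consumes the returned output node.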