From c67c8758cbc3e4bac3544f9e981f1b02e4bd4d60 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Mon, 16 Sep 2019 13:14:15 +0800 Subject: [PATCH] Enhance fc_fuse_pass to enable fusing relu to fc_op (#19733) * Refine the codes related to fc op. * Add GPU implementation for fc functor. * Apply fc_fuse_pass in GPU inference. test=develop * Change the cmake for fc op. * Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ. * Add an attribute to set the activation type in fc_op. * Enhance the unittest of fc_op. test=develop * Remove the declaration of FCOpGrad back to the header file. test=develop * Set default value for newly added arguments in test_fc_op. test=develop * Enhance fc_fuse_pass to enable fusing relu. * Allow print the shapes of var_desc in graph. test=develop * Enhance fc_fuse_pass_tester. * Remove the use of PADDLE_ENFORCE. test=develop * Correct the number of ops after fusing. test=develop * Fix a typo. test=develop * Set activation_type to null when there is no relu in fc. test=develop * Refine fc_fuse_pass's codes. * Enable the set of shape for tensor. * Refine repeated_fc_relu_pass and add unittest. test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/embedding_fc_lstm_fuse_pass.cc | 5 +- paddle/fluid/framework/ir/fc_fuse_pass.cc | 105 ++++--- paddle/fluid/framework/ir/fc_fuse_pass.h | 4 +- .../fluid/framework/ir/fc_fuse_pass_tester.cc | 102 +++--- paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 4 +- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 5 +- .../framework/ir/graph_pattern_detector.cc | 33 +- .../framework/ir/graph_pattern_detector.h | 6 +- paddle/fluid/framework/ir/graph_viz_pass.cc | 11 + .../fluid/framework/ir/pass_tester_helper.h | 43 ++- .../ir/repeated_fc_relu_fuse_pass.cc | 297 +++++++++--------- .../ir/repeated_fc_relu_fuse_pass_tester.cc | 71 +++++ .../tests/api/analyzer_seq_conv1_tester.cc | 2 +- paddle/fluid/operators/fc_op.cc | 2 +- 15 files changed, 411 insertions(+), 280 deletions(-) create mode 100644 paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 792df4b30d..cae7e90255 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -119,6 +119,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto) +cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass) if(WITH_GPU) diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc index b29b37992d..21ceec7927 100644 --- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc @@ -44,7 +44,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, patterns::FC fc_pattern(pattern, name_scope); // fc_out is a tmp var, will be removed after fuse, so marked as 
intermediate. - auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate(); + auto* fc_out = fc_pattern(embedding_out, with_fc_bias, /* with_relu */ false) + ->AsIntermediate(); patterns::LSTM lstm_pattern(pattern, name_scope); lstm_pattern(fc_out); @@ -194,7 +195,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, } if (with_fc_bias) { - GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight, diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 102fd38865..b53e6a250c 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -25,83 +25,110 @@ namespace framework { namespace ir { void FCFusePass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE(graph); + PADDLE_ENFORCE_NOT_NULL(graph); FusePassBase::Init("fc_fuse", graph); - std::unordered_set nodes2delete; + int found_fc_count = 0; + for (bool with_relu : {true, false}) { + found_fc_count += ApplyFCPattern(graph, with_relu); + } + AddStatis(found_fc_count); +} + +int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { GraphPatternDetector gpd; auto* x = gpd.mutable_pattern() ->NewNode("fc_fuse/x") ->AsInput() ->assert_is_op_input("mul", "X"); patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse"); - fc_pattern(x, true /*with bias*/); + fc_pattern(x, true /*with bias*/, with_relu); int found_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (subgraph.count(x) <= 0) { + LOG(WARNING) << "The subgraph is empty."; + return; + } + VLOG(4) << "handle FC fuse"; GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); - GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); - GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, + fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); + Node* relu = nullptr; + Node* relu_out = nullptr; + if (with_relu) { + GET_IR_NODE_FROM_SUBGRAPH(tmp_relu, relu, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(tmp_relu_out, relu_out, fc_pattern); + relu = tmp_relu; + relu_out = tmp_relu_out; + } - auto base_op_desc = mul->Op(); // Create an FC Node. - // OpDesc desc(base_op_desc, nullptr); OpDesc desc; - std::string fc_x_in = subgraph.at(x)->Name(); - std::string fc_Y_in = w->Name(); - std::string fc_bias_in = fc_bias->Name(); - std::string fc_out_out = fc_out->Name(); - - desc.SetInput("Input", std::vector({fc_x_in})); - desc.SetInput("W", std::vector({fc_Y_in})); - desc.SetInput("Bias", std::vector({fc_bias_in})); - desc.SetOutput("Out", std::vector({fc_out_out})); + desc.SetType("fc"); + + // Set inputs of fc + desc.SetInput("Input", {subgraph.at(x)->Name()}); + desc.SetInput("W", {w->Name()}); + desc.SetInput("Bias", {bias->Name()}); + + // Set output of fc + std::string fc_out_name = + with_relu ? 
relu_out->Name() : elementwise_add_out->Name(); + desc.SetOutput("Out", std::vector({fc_out_name})); + + // Set attrs of fc desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims")); + std::string activation_type = with_relu ? "relu" : ""; + desc.SetAttr("activation_type", activation_type); // For anakin subgraph int8 // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul + - // fake_dequant" - // can be detected by the quant_dequant_fuse_pass. This pass will add - // "input_scale", - // "weight_scale" which are extracted from fake_quant op and fake_dequant op - // to mul op, - // and then delete the fake_quant op and fake_dequant op in the graph. If - // the mul op - // has the scale info, we should add those to the fused fc. - if (base_op_desc->HasAttr("enable_int8")) { - desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8")); - desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale")); - desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale")); - if (base_op_desc->HasAttr("out_scale")) - desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale")); + // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass + // will add "input_scale", "weight_scale" which are extracted from + // fake_quant op and fake_dequant op to mul op, and then delete the + // fake_quant op and fake_dequant op in the graph. If the mul op has the + // scale info, we should add those to the fused fc. + auto* mul_op_desc = mul->Op(); + if (mul_op_desc->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8")); + desc.SetAttr("input_scale", mul_op_desc->GetAttr("input_scale")); + desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale")); + if (mul_op_desc->HasAttr("out_scale")) + desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale")); auto elementwise_desc = elementwise_add->Op(); if (elementwise_desc->HasAttr("out_scale")) desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale")); } - desc.SetType("fc"); - auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. 
- GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out}); + if (with_relu) { + GraphSafeRemoveNodes( + graph, {mul, elementwise_add, mul_out, elementwise_add_out, relu}); + } else { + GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out}); + } - PADDLE_ENFORCE(subgraph.count(x)); IR_NODE_LINK_TO(subgraph.at(x), fc_node); IR_NODE_LINK_TO(w, fc_node); - IR_NODE_LINK_TO(fc_bias, fc_node); - IR_NODE_LINK_TO(fc_node, fc_out); + IR_NODE_LINK_TO(bias, fc_node); + if (with_relu) { + IR_NODE_LINK_TO(fc_node, relu_out); + } else { + IR_NODE_LINK_TO(fc_node, elementwise_add_out); + } found_fc_count++; }; - gpd(graph, handler); - - AddStatis(found_fc_count); + return found_fc_count; } } // namespace ir diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h index 0a0fcd2da8..ef6636d109 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_fuse_pass.h @@ -31,7 +31,9 @@ class FCFusePass : public FusePassBase { virtual ~FCFusePass() {} protected: - void ApplyImpl(ir::Graph* graph) const override; + void ApplyImpl(Graph* graph) const override; + + int ApplyFCPattern(Graph* graph, bool with_relu) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index affe506910..320d28f131 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -15,81 +15,53 @@ #include "paddle/fluid/framework/ir/fc_fuse_pass.h" #include -#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" namespace paddle { namespace framework { namespace ir { -void SetOp(ProgramDesc* prog, const std::string& type, - const std::vector& inputs, - const std::vector& outputs) { - auto* op = prog->MutableBlock(0)->AppendOp(); - op->SetType(type); - if (type == "mul") { - op->SetInput("X", {inputs[0]}); - op->SetInput("Y", {inputs[1]}); - op->SetAttr("x_num_col_dims", {1}); - } else if (type == "elementwise_add") { - op->SetInput("X", inputs); - } - op->SetOutput("Out", outputs); - op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), - static_cast(OpRole::kForward)); -} - -// a->OP0->b -// a->OP1->c -// (b, c)->mul->d -// (d, e)->elementwise_add->f -ProgramDesc BuildProgramDesc() { - ProgramDesc prog; - for (auto& v : std::vector({"a", "b", "c", "d", "e", "f"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::SELECTED_ROWS); - if (v == "c") { - var->SetPersistable(true); - } - } - - SetOp(&prog, "OP0", std::vector({"a"}), - std::vector({"b"})); - SetOp(&prog, "OP1", std::vector({"a"}), - std::vector({"c"})); - SetOp(&prog, "mul", std::vector({"b", "c"}), - std::vector({"d"})); - SetOp(&prog, "elementwise_add", std::vector({"d", "e"}), - std::vector({"f"})); - - return prog; -} - TEST(FCFusePass, basic) { - auto prog = BuildProgramDesc(); - - std::unique_ptr graph(new ir::Graph(prog)); - + // inputs operator output + // -------------------------------------------------------- + // (a, filters_0 bias_0) conv2d -> conv2d_out + // conv2d_out relu -> relu_out_0 + // (relu_out_0, weights_0) mul -> mul_out_0 + // (mul_out_0, bias_1) elementwise_add -> add_out_0 + // add_out_0 relu -> relu_out_1 + // (relu_out_1, weights_1) mul -> mul_out_1 + // (mul_out_1, bias_2) elementwise_add -> add_out_1 + Layers layers; + auto* a = layers.data("a"); + auto* filters_0 = layers.data("conv2d_filters_0", {}, true); + auto* bias_0 = 
layers.data("conv2d_bias_0", {}, true); + auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false); + auto* relu_out_0 = layers.relu(conv2d_out); + auto* weights_0 = layers.data("weights_0", {}, true); + auto* mul_out_0 = layers.mul(relu_out_0, weights_0); + auto* bias_1 = layers.data("bias_1", {}, true); + auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1); + auto* relu_out_1 = layers.relu(add_out_0); + auto* weights_1 = layers.data("weights_1", {}, true); + auto* mul_out_1 = layers.mul(relu_out_1, weights_1); + auto* bias_2 = layers.data("bias_2", {}, true); + auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2); + VLOG(4) << add_out_1; + + std::unique_ptr graph(new ir::Graph(layers.main_program())); auto pass = PassRegistry::Instance().Get("fc_fuse_pass"); - - int pre_nodes = graph->Nodes().size(); + int num_nodes_before = graph->Nodes().size(); + int num_mul_nodes_before = GetNumOpNodes(graph, "mul"); + VLOG(3) << DebugString(graph); graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fc_nodes_after = GetNumOpNodes(graph, "fc"); + VLOG(3) << DebugString(graph); - int after_nodes = graph->Nodes().size(); - - // Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out - // Add 1 Node: FC - EXPECT_EQ(pre_nodes - 2, after_nodes); - - // Assert fc op in newly generated graph - int fc_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "fc") { - ++fc_count; - } - } - EXPECT_EQ(fc_count, 1); + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6); + PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2); + PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after); } } // namespace ir diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 10cbe319ac..287c6dc407 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -33,7 +33,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, PDNode* x = pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable(); - auto* fc_out = fc_pattern(x, with_fc_bias); + auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false); fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. gru_pattern(fc_out); @@ -116,7 +116,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto* x_n = subgraph.at(x); GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); - GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 6858a98be3..a5a72e875e 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -33,7 +33,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, patterns::FC fc_pattern(pattern, name_scope); // fc_out is a tmp var, will be removed after fuse, so marked as intermediate. 
- auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate(); + auto* fc_out = + fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate(); patterns::LSTM lstm_pattern(pattern, name_scope); lstm_pattern(fc_out); @@ -132,7 +133,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); if (with_fc_bias) { - GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out, diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 8bec1f08b0..bbb2ee2f56 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -846,7 +846,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()( } PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, - bool with_bias) { + bool with_bias, bool with_relu) { // Create shared nodes. x->assert_is_op_input("mul", "X"); auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul"); @@ -859,11 +859,10 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, auto *mul_out_var = pattern->NewNode(mul_out_repr())->assert_is_op_output("mul"); + // Add links. + mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var}); if (!with_bias) { // not with bias - // Add links. - mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var}); return mul_out_var; - } else { // with bias mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); // Create operators. @@ -872,15 +871,29 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, // Create variables. auto *bias = pattern->NewNode(bias_repr()) ->assert_is_op_input("elementwise_add") + ->assert_is_persistable_var() ->AsInput(); - auto *fc_out = pattern->NewNode(Out_repr()) - ->AsOutput() - ->assert_is_op_output("elementwise_add"); + auto *elementwise_add_out_var = + pattern->NewNode(elementwise_add_out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add"); - mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var}); - elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); - return fc_out; + elementwise_add->LinksFrom({mul_out_var, bias}) + .LinksTo({elementwise_add_out_var}); + if (!with_relu) { + return elementwise_add_out_var; + } else { + elementwise_add_out_var->AsIntermediate()->assert_is_op_input("relu"); + // Create operators. 
+ auto *relu = pattern->NewNode(relu_repr())->assert_is_op("relu"); + auto *relu_out_var = pattern->NewNode(relu_out_repr()) + ->AsOutput() + ->assert_is_op_output("relu"); + + relu->LinksFrom({elementwise_add_out_var}).LinksTo({relu_out_var}); + return relu_out_var; + } } } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index a99889f7cc..0d7d56cabf 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -487,17 +487,19 @@ struct FC : public PatternBase { FC(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "fc") {} - PDNode* operator()(PDNode* x, bool with_bias); + PDNode* operator()(PDNode* x, bool with_bias, bool with_relu); // declare operator node's name PATTERN_DECL_NODE(fc); PATTERN_DECL_NODE(mul); PATTERN_DECL_NODE(elementwise_add); + PATTERN_DECL_NODE(relu); // declare variable node's name PATTERN_DECL_NODE(w); PATTERN_DECL_NODE(mul_out); // (x,w) -> mul_out PATTERN_DECL_NODE(bias); - PATTERN_DECL_NODE(Out); + PATTERN_DECL_NODE(elementwise_add_out); + PATTERN_DECL_NODE(relu_out); }; // MKL-DNN's FC with bias diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index 1da3c9fe69..fa7263b7e7 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -89,6 +89,17 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const { marked_nodes.count(n) ? marked_op_attrs : op_attrs; dot.AddNode(node_id, attr, node_id); } else if (n->IsVar()) { + if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) { + bool is_first = true; + for (int64_t length : n->Var()->GetShape()) { + if (is_first) { + node_id += "\n" + std::to_string(length); + is_first = false; + } else { + node_id += "," + std::to_string(length); + } + } + } decltype(op_attrs)* attr; if (marked_nodes.count(n)) { attr = &marked_var_attrs; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 26eeacab6e..38faf85cf0 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -28,10 +28,13 @@ struct Layers { public: const ProgramDesc& main_program() { return program_; } - VarDesc* data(std::string name) { return lod_tensor(name); } + VarDesc* data(std::string name, std::vector shape = {}, + bool is_persistable = false) { + return lod_tensor(name, shape, is_persistable); + } VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias, - bool use_cudnn) { + bool use_cudnn = false) { VarDesc* out = lod_tensor(unique_name()); OpDesc* op = program_.MutableBlock(0)->AppendOp(); op->SetType("conv2d"); @@ -76,8 +79,27 @@ struct Layers { return unary_op("relu", x, out); } - VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { - return binary_op("mul", x, y, out); + VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias, + int in_num_col_dims = 1, std::string activation_type = "") { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("fc"); + op->SetInput("Input", {input->Name()}); + op->SetInput("W", {w->Name()}); + op->SetInput("Bias", {bias->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("in_num_col_dims", in_num_col_dims); + op->SetAttr("activation_type", activation_type); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); + 
return out; + } + + VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, + int x_num_col_dims = 1) { + AttributeMap attrs; + attrs["x_num_col_dims"] = 1; + return binary_op("mul", x, y, out, &attrs); } VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) { @@ -116,9 +138,12 @@ struct Layers { } private: - VarDesc* lod_tensor(std::string name) { + VarDesc* lod_tensor(std::string name, std::vector shape = {}, + bool is_persistable = false) { auto* var = program_.MutableBlock(0)->Var(name); var->SetType(proto::VarType::LOD_TENSOR); + var->SetShape(shape); + var->SetPersistable(is_persistable); return var; } @@ -136,7 +161,8 @@ struct Layers { } VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y, - VarDesc* out = nullptr) { + VarDesc* out = nullptr, + const AttributeMap* attrs = nullptr) { if (!out) { out = lod_tensor(unique_name()); } @@ -145,6 +171,11 @@ struct Layers { op->SetInput("X", {x->Name()}); op->SetInput("Y", {y->Name()}); op->SetOutput("Out", {out->Name()}); + if (attrs) { + for (auto& iter : *attrs) { + op->SetAttr(iter.first, iter.second); + } + } op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast(OpRole::kForward)); return out; diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 00263b8a34..45157ca18b 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. */ + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h" #include // for max @@ -25,55 +25,84 @@ namespace paddle { namespace framework { namespace ir { -PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, - const std::string& name_scope, int num_fc) { +static bool IsInputOfFC(Node* n) { + if (n && n->IsVar() && VarLinksToOp(n, "fc")) { + return true; + } + return false; +} + +static bool IsOutputOfFC(Node* n) { + if (n && n->IsVar() && VarLinksFromOp(n, "fc") && n->inputs.size() == 1U) { + return true; + } + return false; +} + +static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") { + if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" && + n->inputs.size() == 3U && n->outputs.size() == 1U) { + return boost::get(n->Op()->GetAttr("activation_type")) == + act_type; + } + return false; +} + +static bool IsParamOfFC(Node* n, const std::string& param_name) { + if (IsInputOfFC(n) && n->inputs.empty() && + (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) { + return true; + } + return false; +} + +static int FindFCIdx(Node* x, const std::string& act_type = "relu") { + if (!IsInputOfFC(x)) { + return -1; + } + for (size_t k = 0; k < x->outputs.size(); ++k) { + auto* out_op = x->outputs[k]; + if (IsFCWithAct(out_op, act_type) && out_op->outputs.size() == 1U) { + return k; + } + } + return -1; +} + +static int FindInputIdx(Node* n, const std::string& name, + const std::string& act_type = "relu") { + if (!IsFCWithAct(n, act_type)) { + return -1; + } + for (size_t i = 0; i < n->inputs.size(); ++i) { + if (n->inputs[i]->Name() == n->Op()->Input(name)[0]) { + return i; + } + } + return -1; +} + +void BuildRepeatedFCReluPattern(PDPattern* pattern, + const std::string& name_scope, int num_fc) { auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu", bool check_in_has_only_one_out = true, int fc_idx = 0) -> bool { - bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); - if (check_in_has_only_one_out) { - next_is_fc = next_is_fc && x->outputs.size() == 1; - } - if (!next_is_fc) { + if (!IsInputOfFC(x)) { return false; } - auto* fc_op = x->outputs[fc_idx]; - bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && - fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && - VarLinksToOp(fc_op->outputs[0], act_type) && - fc_op->outputs[0]->outputs.size() == 1; - if (!next_is_act) { + if (check_in_has_only_one_out && x->outputs.size() != 1U) { return false; } - auto* act_op = fc_op->outputs[0]->outputs[0]; - return act_op && act_op->IsOp() && act_op->outputs.size() == 1; - }; - - auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int { - bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); - if (!next_is_fc) { - return 0; - } - for (size_t k = 0; k < x->outputs.size(); ++k) { - auto* fc_op = x->outputs[k]; - bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && - fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && - VarLinksToOp(fc_op->outputs[0], act_type) && - fc_op->outputs[0]->outputs.size() == 1; - if (!next_is_act) { - continue; - } - auto* act_op = fc_op->outputs[0]->outputs[0]; - if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) { - return k; - } - } - return 0; + auto* fc_op = x->outputs[fc_idx]; + return IsFCWithAct(fc_op, act_type) && fc_op->outputs.size() == 1U; }; + // in -> fc -> out + // Current x is in, return fc's out which is next fc's input. 
auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* { - return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0]; + return x->outputs[fc_idx]->outputs[0]; }; + auto var_next_is_fc_act_repeated_n_times = [=]( Node* x, int repeated_times, const std::string& act_type = "relu", bool check_in_has_only_one_out = true) -> bool { @@ -87,25 +116,14 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, return true; }; + // x is output of fc auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu", bool at_top = false) -> bool { - bool before_is_act = - x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, "relu"); - if (!before_is_act) { + if (!IsOutputOfFC(x)) { return false; } - auto* relu_op = x->inputs[0]; - bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 && - relu_op->inputs[0]->IsVar() && - VarLinksFromOp(relu_op->inputs[0], "fc") && - relu_op->inputs[0]->inputs.size() == 1; - - if (!before_is_fc) { - return false; - } - auto* fc_op = relu_op->inputs[0]->inputs[0]; - bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3; - if (!is_fc) { + auto* fc_op = x->inputs[0]; + if (!IsFCWithAct(fc_op, act_type) || fc_op->inputs.size() != 3U) { return false; } for (auto* fc_i : fc_op->inputs) { @@ -113,7 +131,7 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, if (at_top) { return true; } else { - return VarLinksFromOp(fc_i, "relu"); + return VarLinksFromOp(fc_i, "fc"); } } } @@ -121,10 +139,11 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, }; auto before_var_of_part = [=](Node* x) -> Node* { - auto* fc_op = x->inputs[0]->inputs[0]; - for (auto* fc_i : fc_op->inputs) { - if (!fc_i->inputs.empty()) { - return fc_i->inputs[0]; + auto* fc_op = x->inputs[0]; + for (auto* in : fc_op->inputs) { + if (!in->inputs.empty()) { + // w and bias has no input. 
+ return in; } } return nullptr; @@ -142,76 +161,76 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, return true; }; - std::vector fc_input_var(num_fc); + PDNode* fc_input_var_0 = nullptr; std::vector fc_output_var(num_fc); std::vector fc_weight_var(num_fc); std::vector fc_bias_var(num_fc); std::vector fc_ops(num_fc); - std::vector relu_ops(num_fc); for (int i = 0; i < num_fc; ++i) { - fc_input_var[i] = pattern->NewNode( - [=](Node* x) { - if (i == 0 && x->outputs.size() > 0) { - bool ok = x->inputs.size() > 0; - if (!ok) { + if (i == 0) { + fc_input_var_0 = pattern->NewNode( + [=](Node* x) { + if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) { return false; } - int idx = find_fc_idx(x); - if (idx == 0) { + int fc_idx = FindFCIdx(x); + if (fc_idx < 0) { + return false; + } else if (fc_idx == 0) { return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); } else { - x = next_var_of_part(x, idx); + x = next_var_of_part(x, fc_idx); return var_next_is_fc_act_repeated_n_times( x, std::max(1, num_fc - i - 1), "relu"); } - } else { - return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && - x->inputs.size() > 0 && - var_before_is_fc_act_repeated_n_times(x, i, "relu"); - } - }, - name_scope + "/fc_in_" + std::to_string(i)); + }, + name_scope + "/fc_in_0"); + } fc_weight_var[i] = pattern->NewNode( [=](Node* x) { + if (!IsParamOfFC(x, "W")) { + return false; + } + auto* fc_op = x->outputs[0]; + int input_idx = FindInputIdx(fc_op, "Input", "relu"); return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && - x->inputs.empty() && - var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], - i, "relu") && - x->Name() == x->outputs[0]->Op()->Input("W")[0]; + var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx], + i, "relu"); }, name_scope + "/fc_weight_" + std::to_string(i)); fc_bias_var[i] = pattern->NewNode( [=](Node* x) { + if (!IsParamOfFC(x, "Bias")) { + return false; + } + auto* fc_op = x->outputs[0]; + int input_idx = FindInputIdx(fc_op, "Input", "relu"); return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && - x->inputs.empty() && - var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], - i, "relu") && - x->Name() == x->outputs[0]->Op()->Input("Bias")[0]; + var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx], + i, "relu"); }, name_scope + "/fc_bias_" + std::to_string(i)); fc_output_var[i] = pattern->NewNode( [=](Node* x) { - bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") && - VarLinksToOp(x, "relu") && x->inputs.size() == 1 && - x->inputs[0]->inputs.size() == 3; - if (!basic) { + if (!IsOutputOfFC(x)) { return false; } - x = x->inputs[0]->inputs[0]; - if (i == 0 && x->outputs.size() > 0) { - bool ok = x->inputs.size() > 0; - if (!ok) { + x = before_var_of_part(x); + if (i == 0 && x->outputs.size() > 0U) { + if (x->inputs.size() <= 0U) { return false; } - int idx = find_fc_idx(x); - if (idx == 0) { + int fc_idx = FindFCIdx(x); + if (fc_idx < 0) { + return false; + } else if (fc_idx == 0) { return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); } else { - x = next_var_of_part(x, idx); + x = next_var_of_part(x, fc_idx); return var_next_is_fc_act_repeated_n_times( x, std::max(1, num_fc - i - 1), "relu"); } @@ -225,53 +244,29 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, fc_ops[i] = pattern->NewNode( [=](Node* x) { - bool basic = x && x->IsOp() && x->Op()->Type() == "fc" && - x->inputs.size() == 3 && x->outputs.size() == 1; - if (!basic) { + if (!IsFCWithAct(x, "relu")) { 
return false; } auto* fc_out_var = x->outputs[0]; return fc_out_var && fc_out_var->IsVar() && fc_out_var->outputs.size() == 1 && - VarLinksToOp(fc_out_var, "relu") && - fc_out_var->outputs[0]->outputs.size() == 1 && - var_next_is_fc_act_repeated_n_times( - fc_out_var->outputs[0]->outputs[0], num_fc - i - 1, - "relu") && - var_before_is_fc_act_repeated_n_times( - fc_out_var->outputs[0]->outputs[0], i + 1, "relu"); - }, - name_scope + "/fc_op_" + std::to_string(i)); - - relu_ops[i] = pattern->NewNode( - [=](Node* x) { - return x && x->IsOp() && x->Op()->Type() == "relu" && - x->inputs.size() == 1 && x->outputs.size() == 1 && - x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") && - x->outputs[0]->IsVar() && - var_next_is_fc_act_repeated_n_times(x->outputs[0], - num_fc - i - 1, "relu") && - var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1, + var_next_is_fc_act_repeated_n_times(fc_out_var, num_fc - i - 1, + "relu") && + var_before_is_fc_act_repeated_n_times(fc_out_var, i + 1, "relu"); }, - name_scope + "/act_op_" + std::to_string(i)); - - fc_ops[i] - ->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]}) - .LinksTo({fc_output_var[i]}); - relu_ops[i]->LinksFrom({fc_output_var[i]}); - } + name_scope + "/fc_op_" + std::to_string(i)); - auto* last_out_var = pattern->NewNode( - [=](Node* x) { - return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu"); - }, - name_scope + "/act_out"); - for (int i = 0; i < num_fc - 1; ++i) { - relu_ops[i]->LinksTo({fc_input_var[i + 1]}); + if (i == 0) { + fc_ops[i] + ->LinksFrom({fc_input_var_0, fc_weight_var[i], fc_bias_var[i]}) + .LinksTo({fc_output_var[i]}); + } else { + fc_ops[i] + ->LinksFrom({fc_output_var[i - 1], fc_weight_var[i], fc_bias_var[i]}) + .LinksTo({fc_output_var[i]}); + } } - relu_ops[num_fc - 1]->LinksTo({last_out_var}); - return last_out_var; } static int BuildFusion(Graph* graph, const std::string& name_scope, @@ -304,11 +299,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto& fused_pattern = gpd.pattern(); for (int i = 0; i < num_fc; ++i) { - if (i >= 1) { - relu_vars[i - 1] = - retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph, + if (i < num_fc - 1) { + relu_vars[i] = + retrieve_node(name_scope + "/fc_out_" + std::to_string(i), subgraph, fused_pattern); - relu_names[i - 1] = relu_vars[i - 1]->Name(); + relu_names[i] = relu_vars[i]->Name(); } weights_vars[i] = @@ -324,7 +319,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto* input_var = retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern); auto* last_out_var = - retrieve_node(name_scope + "/act_out", subgraph, fused_pattern); + retrieve_node(name_scope + "/fc_out_" + std::to_string(num_fc - 1), + subgraph, fused_pattern); // Create New OpDesc OpDesc op_desc; @@ -334,6 +330,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, op_desc.SetInput("Bias", bias_names); op_desc.SetOutput("ReluOut", relu_names); op_desc.SetOutput("Out", {last_out_var->Name()}); + auto* op = graph->CreateOpNode(&op_desc); IR_NODE_LINK_TO(input_var, op); for (size_t i = 0; i < weights_vars.size(); ++i) { @@ -367,7 +364,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, } void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL(graph); FusePassBase::Init(name_scope_, graph); + int fusion_count = 0; for (int i = MAX_NUM_FC; i > 1; --i) { fusion_count += diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass_tester.cc 
b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass_tester.cc new file mode 100644 index 0000000000..81d9476d40 --- /dev/null +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass_tester.cc @@ -0,0 +1,71 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h" + +#include +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void TestMain(int num_fc) { + // inputs operator output + // ------------------------------------------------------------- + // (x, filters, bias_0) conv2d -> conv2d_out + // (conv2d_out, fc_weights_0, fc_bias_0) fc -> fc_out_0 + // (fc_out_0, fc_weights_1, fc_bias_1) fc -> fc_out_1 + // ... + Layers layers; + VarDesc* x = layers.data("x"); + VarDesc* filters = layers.data("filters", {}, true); + VarDesc* bias_0 = layers.data("bias_0", {}, true); + VarDesc* conv2d_out = layers.conv2d(x, filters, bias_0); + VarDesc* fc_in = conv2d_out; + for (int i = 0; i < num_fc; ++i) { + VarDesc* weights_i = + layers.data("fc_weights_" + std::to_string(i), {}, true); + VarDesc* bias_i = layers.data("fc_bias_" + std::to_string(i), {}, true); + std::string activation_type = i < (num_fc - 1) ? 
"relu" : ""; + VarDesc* fc_out = layers.fc(fc_in, weights_i, bias_i, 1, activation_type); + fc_in = fc_out; + } + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get("repeated_fc_relu_fuse_pass"); + int num_nodes_before = graph->Nodes().size(); + int num_fc_nodes_before = GetNumOpNodes(graph, "fc"); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = GetNumOpNodes(graph, "fusion_repeated_fc_relu"); + VLOG(3) << DebugString(graph); + + // Delete (num_fc_nodes_before - 1) fc ops + PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1, + num_nodes_after); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); +} + +TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); } + +TEST(RepeatedFCReluFusePass, basic_9) { TestMain(9); } + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(repeated_fc_relu_fuse_pass); diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc index 5ee848c3cf..e3f8b835f7 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc @@ -146,7 +146,7 @@ TEST(Analyzer_seq_conv1, fuse_statis) { ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse")); EXPECT_EQ(fuse_statis.at("fc_fuse"), 2); EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6); - EXPECT_EQ(num_ops, 32); + EXPECT_EQ(num_ops, 31); } // Compare result of NativeConfig and AnalysisConfig diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index bc0edd780c..da30fef555 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -132,7 +132,7 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(1) .EqualGreaterThan(1); AddAttr("activation_type", - "Avctivation type used in fully connected operator.") + "Activation type used in fully connected operator.") .SetDefault(""); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") -- GitLab