Unverified commit c67c8758, authored by Yiqun Liu, committed by GitHub

Enhance fc_fuse_pass to enable fusing relu to fc_op (#19733)

* Refine the codes related to fc op.

* Add GPU implementation for fc functor.

* Apply fc_fuse_pass in GPU inference.
test=develop

* Change the cmake for fc op.

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.

* Add an attribute to set the activation type in fc_op.

* Enhance the unittest of fc_op.
test=develop

* Move the declaration of FCOpGrad back to the header file.
test=develop

* Set default value for newly added arguments in test_fc_op.
test=develop

* Enhance fc_fuse_pass to enable fusing relu.

* Allow printing the shapes of var_desc in graph.
test=develop

* Enhance fc_fuse_pass_tester.

* Remove the use of PADDLE_ENFORCE.
test=develop

* Correct the number of ops after fusing.
test=develop

* Fix a typo.
test=develop

* Set activation_type to null when there is no relu in fc.
test=develop

* Refine fc_fuse_pass's codes.

* Enable setting the shape of a tensor.

* Refine repeated_fc_relu_pass and add unittest.
test=develop
Parent commit: b5a5d93b
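For orientation, the sketch below shows the three-op subgraph that the enhanced fc_fuse_pass collapses into a single fc op with activation_type = "relu". It is an illustration only, built with the ProgramDesc/OpDesc API that appears in the diffs below; the variable names ("x", "w", "bias", ...) are placeholders, not names used by the pass itself.

```cpp
#include "paddle/fluid/framework/program_desc.h"

// Illustration only: build the mul + elementwise_add + relu subgraph that
// fc_fuse_pass matches, and note what it is rewritten to.
paddle::framework::ProgramDesc BuildMulAddReluProgram() {
  paddle::framework::ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  for (const char* name : {"x", "w", "mul_out", "bias", "add_out", "relu_out"}) {
    block->Var(name);  // placeholder variables
  }

  auto* mul = block->AppendOp();
  mul->SetType("mul");
  mul->SetInput("X", {"x"});
  mul->SetInput("Y", {"w"});
  mul->SetOutput("Out", {"mul_out"});
  mul->SetAttr("x_num_col_dims", 1);

  auto* add = block->AppendOp();
  add->SetType("elementwise_add");
  add->SetInput("X", {"mul_out"});
  add->SetInput("Y", {"bias"});
  add->SetOutput("Out", {"add_out"});

  auto* relu = block->AppendOp();
  relu->SetType("relu");
  relu->SetInput("X", {"add_out"});
  relu->SetOutput("Out", {"relu_out"});

  // After fc_fuse_pass (with_relu = true), the three ops above become one op:
  //   fc(Input="x", W="w", Bias="bias", activation_type="relu") -> "relu_out"
  return prog;
}
```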
@@ -119,6 +119,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
 cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
 cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
+cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)
 cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
 cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
 if(WITH_GPU)
......
@@ -44,7 +44,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   patterns::FC fc_pattern(pattern, name_scope);
   // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
-  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
+  auto* fc_out = fc_pattern(embedding_out, with_fc_bias, /* with_relu */ false)
+                     ->AsIntermediate();
   patterns::LSTM lstm_pattern(pattern, name_scope);
   lstm_pattern(fc_out);
@@ -194,7 +195,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     }
     if (with_fc_bias) {
-      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
       embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
......
@@ -25,83 +25,110 @@ namespace framework {
 namespace ir {
 
 void FCFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE_NOT_NULL(graph);
   FusePassBase::Init("fc_fuse", graph);
-  std::unordered_set<Node*> nodes2delete;
 
+  int found_fc_count = 0;
+  for (bool with_relu : {true, false}) {
+    found_fc_count += ApplyFCPattern(graph, with_relu);
+  }
+  AddStatis(found_fc_count);
+}
+
+int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   GraphPatternDetector gpd;
   auto* x = gpd.mutable_pattern()
                ->NewNode("fc_fuse/x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");
   patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
-  fc_pattern(x, true /*with bias*/);
+  fc_pattern(x, true /*with bias*/, with_relu);
 
   int found_fc_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (subgraph.count(x) <= 0) {
+      LOG(WARNING) << "The subgraph is empty.";
+      return;
+    }
+
     VLOG(4) << "handle FC fuse";
     GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
+                              fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
+    Node* relu = nullptr;
+    Node* relu_out = nullptr;
+    if (with_relu) {
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_relu, relu, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(tmp_relu_out, relu_out, fc_pattern);
+      relu = tmp_relu;
+      relu_out = tmp_relu_out;
+    }
 
-    auto base_op_desc = mul->Op();
     // Create an FC Node.
-    // OpDesc desc(base_op_desc, nullptr);
     OpDesc desc;
-    std::string fc_x_in = subgraph.at(x)->Name();
-    std::string fc_Y_in = w->Name();
-    std::string fc_bias_in = fc_bias->Name();
-    std::string fc_out_out = fc_out->Name();
+    desc.SetType("fc");
 
-    desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
-    desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
-    desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
-    desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
+    // Set inputs of fc
+    desc.SetInput("Input", {subgraph.at(x)->Name()});
+    desc.SetInput("W", {w->Name()});
+    desc.SetInput("Bias", {bias->Name()});
+
+    // Set output of fc
+    std::string fc_out_name =
+        with_relu ? relu_out->Name() : elementwise_add_out->Name();
+    desc.SetOutput("Out", std::vector<std::string>({fc_out_name}));
+
+    // Set attrs of fc
     desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
+    std::string activation_type = with_relu ? "relu" : "";
+    desc.SetAttr("activation_type", activation_type);
 
     // For anakin subgraph int8
     // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
-    // fake_dequant"
-    // can be detected by the quant_dequant_fuse_pass. This pass will add
-    // "input_scale",
-    // "weight_scale" which are extracted from fake_quant op and fake_dequant op
-    // to mul op,
-    // and then delete the fake_quant op and fake_dequant op in the graph. If
-    // the mul op
-    // has the scale info, we should add those to the fused fc.
-    if (base_op_desc->HasAttr("enable_int8")) {
-      desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
-      desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
-      desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
-      if (base_op_desc->HasAttr("out_scale"))
-        desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
+    // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
+    // will add "input_scale", "weight_scale" which are extracted from
+    // fake_quant op and fake_dequant op to mul op, and then delete the
+    // fake_quant op and fake_dequant op in the graph. If the mul op has the
+    // scale info, we should add those to the fused fc.
+    auto* mul_op_desc = mul->Op();
+    if (mul_op_desc->HasAttr("enable_int8")) {
+      desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
+      desc.SetAttr("input_scale", mul_op_desc->GetAttr("input_scale"));
+      desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
+      if (mul_op_desc->HasAttr("out_scale"))
+        desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
       auto elementwise_desc = elementwise_add->Op();
       if (elementwise_desc->HasAttr("out_scale"))
         desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
     }
 
-    desc.SetType("fc");
     auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
-    GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
+    if (with_relu) {
+      GraphSafeRemoveNodes(
+          graph, {mul, elementwise_add, mul_out, elementwise_add_out, relu});
+    } else {
+      GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
+    }
 
-    PADDLE_ENFORCE(subgraph.count(x));
     IR_NODE_LINK_TO(subgraph.at(x), fc_node);
     IR_NODE_LINK_TO(w, fc_node);
-    IR_NODE_LINK_TO(fc_bias, fc_node);
-    IR_NODE_LINK_TO(fc_node, fc_out);
+    IR_NODE_LINK_TO(bias, fc_node);
+    if (with_relu) {
+      IR_NODE_LINK_TO(fc_node, relu_out);
+    } else {
+      IR_NODE_LINK_TO(fc_node, elementwise_add_out);
+    }
 
     found_fc_count++;
   };
   gpd(graph, handler);
-
-  AddStatis(found_fc_count);
+  return found_fc_count;
 }
 
 }  // namespace ir
......
@@ -31,7 +31,9 @@ class FCFusePass : public FusePassBase {
   virtual ~FCFusePass() {}
 
  protected:
-  void ApplyImpl(ir::Graph* graph) const override;
+  void ApplyImpl(Graph* graph) const override;
+
+  int ApplyFCPattern(Graph* graph, bool with_relu) const;
 };
 
 }  // namespace ir
......
@@ -15,81 +15,53 @@
 #include "paddle/fluid/framework/ir/fc_fuse_pass.h"
 
 #include <gtest/gtest.h>
-#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
 
 namespace paddle {
 namespace framework {
 namespace ir {
 
-void SetOp(ProgramDesc* prog, const std::string& type,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs) {
-  auto* op = prog->MutableBlock(0)->AppendOp();
-  op->SetType(type);
-  if (type == "mul") {
-    op->SetInput("X", {inputs[0]});
-    op->SetInput("Y", {inputs[1]});
-    op->SetAttr("x_num_col_dims", {1});
-  } else if (type == "elementwise_add") {
-    op->SetInput("X", inputs);
-  }
-  op->SetOutput("Out", outputs);
-  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-              static_cast<int>(OpRole::kForward));
-}
-
-// a->OP0->b
-// a->OP1->c
-// (b, c)->mul->d
-// (d, e)->elementwise_add->f
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-  for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::SELECTED_ROWS);
-    if (v == "c") {
-      var->SetPersistable(true);
-    }
-  }
-
-  SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
-        std::vector<std::string>({"b"}));
-  SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
-        std::vector<std::string>({"c"}));
-  SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
-        std::vector<std::string>({"d"}));
-  SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
-        std::vector<std::string>({"f"}));
-
-  return prog;
-}
-
 TEST(FCFusePass, basic) {
-  auto prog = BuildProgramDesc();
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  // inputs                        operator            output
+  // --------------------------------------------------------
+  // (a, filters_0, bias_0)        conv2d           -> conv2d_out
+  // conv2d_out                    relu             -> relu_out_0
+  // (relu_out_0, weights_0)       mul              -> mul_out_0
+  // (mul_out_0, bias_1)           elementwise_add  -> add_out_0
+  // add_out_0                     relu             -> relu_out_1
+  // (relu_out_1, weights_1)       mul              -> mul_out_1
+  // (mul_out_1, bias_2)           elementwise_add  -> add_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* filters_0 = layers.data("conv2d_filters_0", {}, true);
+  auto* bias_0 = layers.data("conv2d_bias_0", {}, true);
+  auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false);
+  auto* relu_out_0 = layers.relu(conv2d_out);
+  auto* weights_0 = layers.data("weights_0", {}, true);
+  auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
+  auto* bias_1 = layers.data("bias_1", {}, true);
+  auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1);
+  auto* relu_out_1 = layers.relu(add_out_0);
+  auto* weights_1 = layers.data("weights_1", {}, true);
+  auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
+  auto* bias_2 = layers.data("bias_2", {}, true);
+  auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2);
+  VLOG(4) << add_out_1;
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
-
-  int pre_nodes = graph->Nodes().size();
-
+  int num_nodes_before = graph->Nodes().size();
+  int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
+  VLOG(3) << DebugString(graph);
+
   graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
+  VLOG(3) << DebugString(graph);
 
-  int after_nodes = graph->Nodes().size();
-
-  // Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out
-  // Add 1 Node: FC
-  EXPECT_EQ(pre_nodes - 2, after_nodes);
-
-  // Assert fc op in newly generated graph
-  int fc_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "fc") {
-      ++fc_count;
-    }
-  }
-  EXPECT_EQ(fc_count, 1);
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6);
+  PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2);
+  PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after);
 }
 
 }  // namespace ir
......
@@ -33,7 +33,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   PDNode* x =
       pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
-  auto* fc_out = fc_pattern(x, with_fc_bias);
+  auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
   fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
   gru_pattern(fc_out);
@@ -116,7 +116,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     auto* x_n = subgraph.at(x);
     GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
......
@@ -33,7 +33,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
   patterns::FC fc_pattern(pattern, name_scope);
   // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
-  auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate();
+  auto* fc_out =
+      fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate();
   patterns::LSTM lstm_pattern(pattern, name_scope);
   lstm_pattern(fc_out);
@@ -132,7 +133,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
     GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
     if (with_fc_bias) {
-      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
       lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
......
@@ -846,7 +846,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
 }
 
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
-                                 bool with_bias) {
+                                 bool with_bias, bool with_relu) {
   // Create shared nodes.
   x->assert_is_op_input("mul", "X");
   auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
@@ -859,11 +859,10 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   auto *mul_out_var =
       pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
 
+  // Add links.
+  mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
   if (!with_bias) {  // not with bias
-    // Add links.
-    mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
     return mul_out_var;
   } else {  // with bias
     mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
     // Create operators.
@@ -872,15 +871,29 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
     // Create variables.
     auto *bias = pattern->NewNode(bias_repr())
                      ->assert_is_op_input("elementwise_add")
+                     ->assert_is_persistable_var()
                      ->AsInput();
-    auto *fc_out = pattern->NewNode(Out_repr())
-                       ->AsOutput()
-                       ->assert_is_op_output("elementwise_add");
+    auto *elementwise_add_out_var =
+        pattern->NewNode(elementwise_add_out_repr())
+            ->AsOutput()
+            ->assert_is_op_output("elementwise_add");
 
-    mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var});
-    elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
-    return fc_out;
+    elementwise_add->LinksFrom({mul_out_var, bias})
+        .LinksTo({elementwise_add_out_var});
+    if (!with_relu) {
+      return elementwise_add_out_var;
+    } else {
+      elementwise_add_out_var->AsIntermediate()->assert_is_op_input("relu");
+      // Create operators.
+      auto *relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
+      auto *relu_out_var = pattern->NewNode(relu_out_repr())
+                               ->AsOutput()
+                               ->assert_is_op_output("relu");
+
+      relu->LinksFrom({elementwise_add_out_var}).LinksTo({relu_out_var});
+      return relu_out_var;
+    }
   }
 }
......
@@ -487,17 +487,19 @@ struct FC : public PatternBase {
   FC(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "fc") {}
 
-  PDNode* operator()(PDNode* x, bool with_bias);
+  PDNode* operator()(PDNode* x, bool with_bias, bool with_relu);
 
   // declare operator node's name
   PATTERN_DECL_NODE(fc);
   PATTERN_DECL_NODE(mul);
   PATTERN_DECL_NODE(elementwise_add);
+  PATTERN_DECL_NODE(relu);
   // declare variable node's name
   PATTERN_DECL_NODE(w);
   PATTERN_DECL_NODE(mul_out);  // (x,w) -> mul_out
   PATTERN_DECL_NODE(bias);
-  PATTERN_DECL_NODE(Out);
+  PATTERN_DECL_NODE(elementwise_add_out);
+  PATTERN_DECL_NODE(relu_out);
 };
 
 // MKL-DNN's FC with bias
......
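A minimal usage sketch of the extended pattern helper (it only uses the GraphPatternDetector and patterns::FC API shown in these diffs; the pass name "my_fuse" and the function name are placeholders). With `with_relu = true` the returned PDNode is relu's output; with `with_relu = false` it is elementwise_add's output, which matches the two branches handled in fc_fuse_pass above.

```cpp
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

// Sketch only: how a custom fuse pass would request the relu-terminated
// FC pattern (mul -> elementwise_add -> relu) from patterns::FC.
void BuildMyFcReluPattern(paddle::framework::ir::GraphPatternDetector* gpd) {
  using paddle::framework::ir::PDNode;
  namespace patterns = paddle::framework::ir::patterns;

  // The input variable of the first mul op.
  auto* x = gpd->mutable_pattern()
                ->NewNode("my_fuse/x")
                ->AsInput()
                ->assert_is_op_input("mul", "X");

  patterns::FC fc_pattern(gpd->mutable_pattern(), "my_fuse");
  // with_bias = true, with_relu = true: the returned node is relu's output.
  PDNode* fc_out = fc_pattern(x, /*with_bias=*/true, /*with_relu=*/true);
  // Mark the matched output as intermediate if the pass will remove it.
  fc_out->AsIntermediate();
}
```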
@@ -89,6 +89,17 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
           marked_nodes.count(n) ? marked_op_attrs : op_attrs;
       dot.AddNode(node_id, attr, node_id);
     } else if (n->IsVar()) {
+      if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
+        bool is_first = true;
+        for (int64_t length : n->Var()->GetShape()) {
+          if (is_first) {
+            node_id += "\n" + std::to_string(length);
+            is_first = false;
+          } else {
+            node_id += "," + std::to_string(length);
+          }
+        }
+      }
       decltype(op_attrs)* attr;
       if (marked_nodes.count(n)) {
         attr = &marked_var_attrs;
......
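With this change the DOT dump produced by graph_viz_pass labels each LOD_TENSOR variable node with its shape. A hedged sketch of how a test could trigger such a dump, assuming graph_viz_pass still reads its output path from the "graph_viz_path" pass attribute and that `graph` is a `std::unique_ptr<ir::Graph>` as in the testers above; the output path is a placeholder:

```cpp
// Sketch only: dump the graph to DOT; variable nodes now carry tensor shapes.
auto viz = PassRegistry::Instance().Get("graph_viz_pass");
viz->Set("graph_viz_path", new std::string("/tmp/fc_fuse_graph.dot"));
graph.reset(viz->Apply(graph.release()));
```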
@@ -28,10 +28,13 @@ struct Layers {
  public:
   const ProgramDesc& main_program() { return program_; }
 
-  VarDesc* data(std::string name) { return lod_tensor(name); }
+  VarDesc* data(std::string name, std::vector<int64_t> shape = {},
+                bool is_persistable = false) {
+    return lod_tensor(name, shape, is_persistable);
+  }
 
   VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
-                  bool use_cudnn) {
+                  bool use_cudnn = false) {
     VarDesc* out = lod_tensor(unique_name());
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("conv2d");
@@ -76,8 +79,27 @@ struct Layers {
     return unary_op("relu", x, out);
   }
 
-  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
-    return binary_op("mul", x, y, out);
+  VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias,
+              int in_num_col_dims = 1, std::string activation_type = "") {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("fc");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("W", {w->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    op->SetOutput("Out", {out->Name()});
+    op->SetAttr("in_num_col_dims", in_num_col_dims);
+    op->SetAttr("activation_type", activation_type);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+    return out;
+  }
+
+  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
+               int x_num_col_dims = 1) {
+    AttributeMap attrs;
+    attrs["x_num_col_dims"] = 1;
+    return binary_op("mul", x, y, out, &attrs);
   }
 
   VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
@@ -116,9 +138,12 @@ struct Layers {
   }
 
  private:
-  VarDesc* lod_tensor(std::string name) {
+  VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
+                      bool is_persistable = false) {
     auto* var = program_.MutableBlock(0)->Var(name);
     var->SetType(proto::VarType::LOD_TENSOR);
+    var->SetShape(shape);
+    var->SetPersistable(is_persistable);
     return var;
   }
@@ -136,7 +161,8 @@ struct Layers {
   }
 
   VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
-                     VarDesc* out = nullptr) {
+                     VarDesc* out = nullptr,
+                     const AttributeMap* attrs = nullptr) {
     if (!out) {
       out = lod_tensor(unique_name());
     }
@@ -145,6 +171,11 @@ struct Layers {
     op->SetInput("X", {x->Name()});
     op->SetInput("Y", {y->Name()});
     op->SetOutput("Out", {out->Name()});
+    if (attrs) {
+      for (auto& iter : *attrs) {
+        op->SetAttr(iter.first, iter.second);
+      }
+    }
     op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                 static_cast<int>(OpRole::kForward));
     return out;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <algorithm>  // for max
@@ -25,55 +25,84 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
-                                   const std::string& name_scope, int num_fc) {
+static bool IsInputOfFC(Node* n) {
+  if (n && n->IsVar() && VarLinksToOp(n, "fc")) {
+    return true;
+  }
+  return false;
+}
+
+static bool IsOutputOfFC(Node* n) {
+  if (n && n->IsVar() && VarLinksFromOp(n, "fc") && n->inputs.size() == 1U) {
+    return true;
+  }
+  return false;
+}
+
+static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") {
+  if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" &&
+      n->inputs.size() == 3U && n->outputs.size() == 1U) {
+    return boost::get<std::string>(n->Op()->GetAttr("activation_type")) ==
+           act_type;
+  }
+  return false;
+}
+
+static bool IsParamOfFC(Node* n, const std::string& param_name) {
+  if (IsInputOfFC(n) && n->inputs.empty() &&
+      (n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) {
+    return true;
+  }
+  return false;
+}
+
+static int FindFCIdx(Node* x, const std::string& act_type = "relu") {
+  if (!IsInputOfFC(x)) {
+    return -1;
+  }
+  for (size_t k = 0; k < x->outputs.size(); ++k) {
+    auto* out_op = x->outputs[k];
+    if (IsFCWithAct(out_op, act_type) && out_op->outputs.size() == 1U) {
+      return k;
+    }
+  }
+  return -1;
+}
+
+static int FindInputIdx(Node* n, const std::string& name,
+                        const std::string& act_type = "relu") {
+  if (!IsFCWithAct(n, act_type)) {
+    return -1;
+  }
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    if (n->inputs[i]->Name() == n->Op()->Input(name)[0]) {
+      return i;
+    }
+  }
+  return -1;
+}
+
+void BuildRepeatedFCReluPattern(PDPattern* pattern,
+                                const std::string& name_scope, int num_fc) {
   auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
                                 bool check_in_has_only_one_out = true,
                                 int fc_idx = 0) -> bool {
-    bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
-    if (check_in_has_only_one_out) {
-      next_is_fc = next_is_fc && x->outputs.size() == 1;
-    }
-    if (!next_is_fc) {
+    if (!IsInputOfFC(x)) {
       return false;
     }
-    auto* fc_op = x->outputs[fc_idx];
-    bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
-                       fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
-                       VarLinksToOp(fc_op->outputs[0], act_type) &&
-                       fc_op->outputs[0]->outputs.size() == 1;
-    if (!next_is_act) {
+    if (check_in_has_only_one_out && x->outputs.size() != 1U) {
       return false;
     }
-    auto* act_op = fc_op->outputs[0]->outputs[0];
-    return act_op && act_op->IsOp() && act_op->outputs.size() == 1;
+    auto* fc_op = x->outputs[fc_idx];
+    return IsFCWithAct(fc_op, act_type) && fc_op->outputs.size() == 1U;
   };
 
-  auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int {
-    bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
-    if (!next_is_fc) {
-      return 0;
-    }
-    for (size_t k = 0; k < x->outputs.size(); ++k) {
-      auto* fc_op = x->outputs[k];
-      bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
-                         fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
-                         VarLinksToOp(fc_op->outputs[0], act_type) &&
-                         fc_op->outputs[0]->outputs.size() == 1;
-      if (!next_is_act) {
-        continue;
-      }
-      auto* act_op = fc_op->outputs[0]->outputs[0];
-      if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) {
-        return k;
-      }
-    }
-    return 0;
-  };
-
+  // in -> fc -> out
+  // Current x is in, return fc's out which is next fc's input.
   auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* {
-    return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0];
+    return x->outputs[fc_idx]->outputs[0];
   };
   auto var_next_is_fc_act_repeated_n_times = [=](
       Node* x, int repeated_times, const std::string& act_type = "relu",
      bool check_in_has_only_one_out = true) -> bool {
@@ -87,25 +116,14 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
     return true;
   };
 
+  // x is output of fc
   auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
                                   bool at_top = false) -> bool {
-    bool before_is_act =
-        x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, "relu");
-    if (!before_is_act) {
+    if (!IsOutputOfFC(x)) {
       return false;
     }
-    auto* relu_op = x->inputs[0];
-    bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
-                        relu_op->inputs[0]->IsVar() &&
-                        VarLinksFromOp(relu_op->inputs[0], "fc") &&
-                        relu_op->inputs[0]->inputs.size() == 1;
-    if (!before_is_fc) {
-      return false;
-    }
-    auto* fc_op = relu_op->inputs[0]->inputs[0];
-    bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
-    if (!is_fc) {
+    auto* fc_op = x->inputs[0];
+    if (!IsFCWithAct(fc_op, act_type) || fc_op->inputs.size() != 3U) {
       return false;
     }
     for (auto* fc_i : fc_op->inputs) {
@@ -113,7 +131,7 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
       if (at_top) {
         return true;
       } else {
-        return VarLinksFromOp(fc_i, "relu");
+        return VarLinksFromOp(fc_i, "fc");
       }
     }
   }
@@ -121,10 +139,11 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
   };
 
   auto before_var_of_part = [=](Node* x) -> Node* {
-    auto* fc_op = x->inputs[0]->inputs[0];
-    for (auto* fc_i : fc_op->inputs) {
-      if (!fc_i->inputs.empty()) {
-        return fc_i->inputs[0];
+    auto* fc_op = x->inputs[0];
+    for (auto* in : fc_op->inputs) {
+      if (!in->inputs.empty()) {
+        // w and bias has no input.
+        return in;
       }
     }
     return nullptr;
@@ -142,76 +161,76 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
     return true;
   };
 
-  std::vector<PDNode*> fc_input_var(num_fc);
+  PDNode* fc_input_var_0 = nullptr;
   std::vector<PDNode*> fc_output_var(num_fc);
   std::vector<PDNode*> fc_weight_var(num_fc);
   std::vector<PDNode*> fc_bias_var(num_fc);
   std::vector<PDNode*> fc_ops(num_fc);
-  std::vector<PDNode*> relu_ops(num_fc);
 
   for (int i = 0; i < num_fc; ++i) {
-    fc_input_var[i] = pattern->NewNode(
-        [=](Node* x) {
-          if (i == 0 && x->outputs.size() > 0) {
-            bool ok = x->inputs.size() > 0;
-            if (!ok) {
+    if (i == 0) {
+      fc_input_var_0 = pattern->NewNode(
+          [=](Node* x) {
+            if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) {
               return false;
             }
-            int idx = find_fc_idx(x);
-            if (idx == 0) {
+            int fc_idx = FindFCIdx(x);
+            if (fc_idx < 0) {
+              return false;
+            } else if (fc_idx == 0) {
               return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
             } else {
-              x = next_var_of_part(x, idx);
+              x = next_var_of_part(x, fc_idx);
               return var_next_is_fc_act_repeated_n_times(
                   x, std::max(1, num_fc - i - 1), "relu");
             }
-          } else {
-            return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
-                   x->inputs.size() > 0 &&
-                   var_before_is_fc_act_repeated_n_times(x, i, "relu");
-          }
-        },
-        name_scope + "/fc_in_" + std::to_string(i));
+          },
+          name_scope + "/fc_in_0");
+    }
 
     fc_weight_var[i] = pattern->NewNode(
         [=](Node* x) {
+          if (!IsParamOfFC(x, "W")) {
+            return false;
+          }
+          auto* fc_op = x->outputs[0];
+          int input_idx = FindInputIdx(fc_op, "Input", "relu");
           return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
-                 x->inputs.empty() &&
-                 var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
-                                                       i, "relu") &&
-                 x->Name() == x->outputs[0]->Op()->Input("W")[0];
+                 var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx],
+                                                       i, "relu");
         },
         name_scope + "/fc_weight_" + std::to_string(i));

     fc_bias_var[i] = pattern->NewNode(
         [=](Node* x) {
+          if (!IsParamOfFC(x, "Bias")) {
+            return false;
+          }
+          auto* fc_op = x->outputs[0];
+          int input_idx = FindInputIdx(fc_op, "Input", "relu");
           return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
-                 x->inputs.empty() &&
-                 var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
-                                                       i, "relu") &&
-                 x->Name() == x->outputs[0]->Op()->Input("Bias")[0];
+                 var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx],
+                                                       i, "relu");
         },
         name_scope + "/fc_bias_" + std::to_string(i));

     fc_output_var[i] = pattern->NewNode(
         [=](Node* x) {
-          bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") &&
-                       VarLinksToOp(x, "relu") && x->inputs.size() == 1 &&
-                       x->inputs[0]->inputs.size() == 3;
-          if (!basic) {
+          if (!IsOutputOfFC(x)) {
             return false;
           }
-          x = x->inputs[0]->inputs[0];
-          if (i == 0 && x->outputs.size() > 0) {
-            bool ok = x->inputs.size() > 0;
-            if (!ok) {
+          x = before_var_of_part(x);
+          if (i == 0 && x->outputs.size() > 0U) {
+            if (x->inputs.size() <= 0U) {
               return false;
             }
-            int idx = find_fc_idx(x);
-            if (idx == 0) {
+            int fc_idx = FindFCIdx(x);
+            if (fc_idx < 0) {
+              return false;
+            } else if (fc_idx == 0) {
               return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
             } else {
-              x = next_var_of_part(x, idx);
+              x = next_var_of_part(x, fc_idx);
               return var_next_is_fc_act_repeated_n_times(
                   x, std::max(1, num_fc - i - 1), "relu");
             }
@@ -225,53 +244,29 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
     fc_ops[i] = pattern->NewNode(
         [=](Node* x) {
-          bool basic = x && x->IsOp() && x->Op()->Type() == "fc" &&
-                       x->inputs.size() == 3 && x->outputs.size() == 1;
-          if (!basic) {
+          if (!IsFCWithAct(x, "relu")) {
             return false;
           }
           auto* fc_out_var = x->outputs[0];
           return fc_out_var && fc_out_var->IsVar() &&
                  fc_out_var->outputs.size() == 1 &&
-                 VarLinksToOp(fc_out_var, "relu") &&
-                 fc_out_var->outputs[0]->outputs.size() == 1 &&
-                 var_next_is_fc_act_repeated_n_times(
-                     fc_out_var->outputs[0]->outputs[0], num_fc - i - 1,
-                     "relu") &&
-                 var_before_is_fc_act_repeated_n_times(
-                     fc_out_var->outputs[0]->outputs[0], i + 1, "relu");
-        },
-        name_scope + "/fc_op_" + std::to_string(i));
-
-    relu_ops[i] = pattern->NewNode(
-        [=](Node* x) {
-          return x && x->IsOp() && x->Op()->Type() == "relu" &&
-                 x->inputs.size() == 1 && x->outputs.size() == 1 &&
-                 x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") &&
-                 x->outputs[0]->IsVar() &&
-                 var_next_is_fc_act_repeated_n_times(x->outputs[0],
-                                                     num_fc - i - 1, "relu") &&
-                 var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1,
+                 var_next_is_fc_act_repeated_n_times(fc_out_var, num_fc - i - 1,
+                                                     "relu") &&
+                 var_before_is_fc_act_repeated_n_times(fc_out_var, i + 1,
                                                        "relu");
         },
-        name_scope + "/act_op_" + std::to_string(i));
+        name_scope + "/fc_op_" + std::to_string(i));
 
-    fc_ops[i]
-        ->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]})
-        .LinksTo({fc_output_var[i]});
-    relu_ops[i]->LinksFrom({fc_output_var[i]});
-  }
-
-  auto* last_out_var = pattern->NewNode(
-      [=](Node* x) {
-        return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu");
-      },
-      name_scope + "/act_out");
-  for (int i = 0; i < num_fc - 1; ++i) {
-    relu_ops[i]->LinksTo({fc_input_var[i + 1]});
+    if (i == 0) {
+      fc_ops[i]
+          ->LinksFrom({fc_input_var_0, fc_weight_var[i], fc_bias_var[i]})
+          .LinksTo({fc_output_var[i]});
+    } else {
+      fc_ops[i]
+          ->LinksFrom({fc_output_var[i - 1], fc_weight_var[i], fc_bias_var[i]})
+          .LinksTo({fc_output_var[i]});
+    }
   }
-  relu_ops[num_fc - 1]->LinksTo({last_out_var});
-  return last_out_var;
 }
 
 static int BuildFusion(Graph* graph, const std::string& name_scope,
@@ -304,11 +299,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     auto& fused_pattern = gpd.pattern();
 
     for (int i = 0; i < num_fc; ++i) {
-      if (i >= 1) {
-        relu_vars[i - 1] =
-            retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph,
-                          fused_pattern);
-        relu_names[i - 1] = relu_vars[i - 1]->Name();
+      if (i < num_fc - 1) {
+        relu_vars[i] =
+            retrieve_node(name_scope + "/fc_out_" + std::to_string(i), subgraph,
+                          fused_pattern);
+        relu_names[i] = relu_vars[i]->Name();
       }
 
       weights_vars[i] =
@@ -324,7 +319,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     auto* input_var =
         retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern);
     auto* last_out_var =
-        retrieve_node(name_scope + "/act_out", subgraph, fused_pattern);
+        retrieve_node(name_scope + "/fc_out_" + std::to_string(num_fc - 1),
+                      subgraph, fused_pattern);
 
     // Create New OpDesc
     OpDesc op_desc;
@@ -334,6 +330,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
     op_desc.SetInput("Bias", bias_names);
     op_desc.SetOutput("ReluOut", relu_names);
     op_desc.SetOutput("Out", {last_out_var->Name()});
+
     auto* op = graph->CreateOpNode(&op_desc);
     IR_NODE_LINK_TO(input_var, op);
     for (size_t i = 0; i < weights_vars.size(); ++i) {
@@ -367,7 +364,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 }
 
 void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph);
   FusePassBase::Init(name_scope_, graph);
 
   int fusion_count = 0;
   for (int i = MAX_NUM_FC; i > 1; --i) {
     fusion_count +=
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void TestMain(int num_fc) {
// inputs operator output
// -------------------------------------------------------------
// (x, filters, bias_0) conv2d -> conv2d_out
// (conv2d_out, fc_weights_0, fc_bias_0) fc -> fc_out_0
// (fc_out_0, fc_weights_1, fc_bias_1) fc -> fc_out_1
// ...
Layers layers;
VarDesc* x = layers.data("x");
VarDesc* filters = layers.data("filters", {}, true);
VarDesc* bias_0 = layers.data("bias_0", {}, true);
VarDesc* conv2d_out = layers.conv2d(x, filters, bias_0);
VarDesc* fc_in = conv2d_out;
for (int i = 0; i < num_fc; ++i) {
VarDesc* weights_i =
layers.data("fc_weights_" + std::to_string(i), {}, true);
VarDesc* bias_i = layers.data("fc_bias_" + std::to_string(i), {}, true);
std::string activation_type = i < (num_fc - 1) ? "relu" : "";
VarDesc* fc_out = layers.fc(fc_in, weights_i, bias_i, 1, activation_type);
fc_in = fc_out;
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("repeated_fc_relu_fuse_pass");
int num_nodes_before = graph->Nodes().size();
int num_fc_nodes_before = GetNumOpNodes(graph, "fc");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fused_nodes_after = GetNumOpNodes(graph, "fusion_repeated_fc_relu");
VLOG(3) << DebugString(graph);
// Delete (num_fc_nodes_before - 1) fc ops
PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1,
num_nodes_after);
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
}
TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }
TEST(RepeatedFCReluFusePass, basic_9) { TestMain(9); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(repeated_fc_relu_fuse_pass);
@@ -146,7 +146,7 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
   ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
   EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
   EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
-  EXPECT_EQ(num_ops, 32);
+  EXPECT_EQ(num_ops, 31);
 }
 
 // Compare result of NativeConfig and AnalysisConfig
......
@@ -132,7 +132,7 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(1)
         .EqualGreaterThan(1);
     AddAttr<std::string>("activation_type",
-                         "Avctivation type used in fully connected operator.")
+                         "Activation type used in fully connected operator.")
         .SetDefault("");
     AddAttr<bool>("use_mkldnn",
                   "(bool, default false) Only used in mkldnn kernel")
......