Unverified commit c67c8758 authored by Yiqun Liu, committed by GitHub

Enhance fc_fuse_pass to enable fusing relu to fc_op (#19733)

* Refine the codes related to fc op.

* Add GPU implementation for fc functor.

* Apply fc_fuse_pass in GPU inference.
test=develop

* Change the cmake for fc op.

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_EQ.

* Add an attribute to set the activation type in fc_op.

* Enhance the unittest of fc_op.
test=develop

* Move the declaration of FCOpGrad back to the header file.
test=develop

* Set default value for newly added arguments in test_fc_op.
test=develop

* Enhance fc_fuse_pass to enable fusing relu.

* Allow printing the shapes of var_desc in the graph.
test=develop

* Enhance fc_fuse_pass_tester.

* Remove the use of PADDLE_ENFORCE.
test=develop

* Correct the number of ops after fusing.
test=develop

* Fix a typo.
test=develop

* Set activation_type to an empty string when there is no relu in fc.
test=develop

* Refine fc_fuse_pass's codes.

* Enable setting the shape of a tensor.

* Refine repeated_fc_relu_pass and add unittest.
test=develop
Parent b5a5d93b
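A minimal sketch of the subgraph the enhanced pass matches and the op it produces, written with the Layers test helper added in pass_tester_helper.h below (variable names are illustrative only):

    Layers layers;
    auto* x = layers.data("x");
    auto* w = layers.data("w", {}, true);   // persistable weight
    auto* b = layers.data("b", {}, true);   // persistable bias
    auto* mul_out = layers.mul(x, w);
    auto* add_out = layers.elementwise_add(mul_out, b);
    auto* relu_out = layers.relu(add_out);  // optional trailing relu
    // fc_fuse_pass replaces {mul, elementwise_add, relu} with a single fc op that
    // has activation_type = "relu" and writes directly to relu_out; without the
    // trailing relu, the fused fc gets activation_type = "" and writes to add_out.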
......@@ -119,6 +119,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
if(WITH_GPU)
......
......@@ -44,7 +44,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
auto* fc_out = fc_pattern(embedding_out, with_fc_bias, /* with_relu */ false)
->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
......@@ -194,7 +195,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
}
if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
......
......@@ -25,83 +25,110 @@ namespace framework {
namespace ir {
void FCFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init("fc_fuse", graph);
std::unordered_set<Node*> nodes2delete;
int found_fc_count = 0;
for (bool with_relu : {true, false}) {
found_fc_count += ApplyFCPattern(graph, with_relu);
}
AddStatis(found_fc_count);
}
int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_fuse/x")
->AsInput()
->assert_is_op_input("mul", "X");
patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
fc_pattern(x, true /*with bias*/);
fc_pattern(x, true /*with bias*/, with_relu);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (subgraph.count(x) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
VLOG(4) << "handle FC fuse";
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
Node* relu = nullptr;
Node* relu_out = nullptr;
if (with_relu) {
GET_IR_NODE_FROM_SUBGRAPH(tmp_relu, relu, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(tmp_relu_out, relu_out, fc_pattern);
relu = tmp_relu;
relu_out = tmp_relu_out;
}
auto base_op_desc = mul->Op();
// Create an FC Node.
// OpDesc desc(base_op_desc, nullptr);
OpDesc desc;
std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetType("fc");
// Set inputs of fc
desc.SetInput("Input", {subgraph.at(x)->Name()});
desc.SetInput("W", {w->Name()});
desc.SetInput("Bias", {bias->Name()});
// Set output of fc
std::string fc_out_name =
with_relu ? relu_out->Name() : elementwise_add_out->Name();
desc.SetOutput("Out", std::vector<std::string>({fc_out_name}));
// Set attrs of fc
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
std::string activation_type = with_relu ? "relu" : "";
desc.SetAttr("activation_type", activation_type);
// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant"
// can be detected by the quant_dequant_fuse_pass. This pass will add
// "input_scale",
// "weight_scale" which are extracted from fake_quant op and fake_dequant op
// to mul op,
// and then delete the fake_quant op and fake_dequant op in the graph. If
// the mul op
// has the scale info, we should add those to the fused fc.
if (base_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", base_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", base_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", base_op_desc->GetAttr("weight_scale"));
if (base_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", base_op_desc->GetAttr("out_scale"));
// fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
// will add "input_scale", "weight_scale" which are extracted from
// fake_quant op and fake_dequant op to mul op, and then delete the
// fake_quant op and fake_dequant op in the graph. If the mul op has the
// scale info, we should add those to the fused fc.
auto* mul_op_desc = mul->Op();
if (mul_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
desc.SetAttr("input_scale", mul_op_desc->GetAttr("input_scale"));
desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
if (mul_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
}
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
if (with_relu) {
GraphSafeRemoveNodes(
graph, {mul, elementwise_add, mul_out, elementwise_add_out, relu});
} else {
GraphSafeRemoveNodes(graph, {mul, elementwise_add, mul_out});
}
PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
IR_NODE_LINK_TO(w, fc_node);
IR_NODE_LINK_TO(fc_bias, fc_node);
IR_NODE_LINK_TO(fc_node, fc_out);
IR_NODE_LINK_TO(bias, fc_node);
if (with_relu) {
IR_NODE_LINK_TO(fc_node, relu_out);
} else {
IR_NODE_LINK_TO(fc_node, elementwise_add_out);
}
found_fc_count++;
};
gpd(graph, handler);
AddStatis(found_fc_count);
return found_fc_count;
}
} // namespace ir
......
......@@ -31,7 +31,9 @@ class FCFusePass : public FusePassBase {
virtual ~FCFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
void ApplyImpl(Graph* graph) const override;
int ApplyFCPattern(Graph* graph, bool with_relu) const;
};
} // namespace ir
......
......@@ -15,81 +15,53 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
if (type == "mul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetAttr("x_num_col_dims", {1});
} else if (type == "elementwise_add") {
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
// a->OP1->c
// (b, c)->mul->d
// (d, e)->elementwise_add->f
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
return prog;
}
TEST(FCFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// inputs operator output
// --------------------------------------------------------
// (a, filters_0 bias_0) conv2d -> conv2d_out
// conv2d_out relu -> relu_out_0
// (relu_out_0, weights_0) mul -> mul_out_0
// (mul_out_0, bias_1) elementwise_add -> add_out_0
// add_out_0 relu -> relu_out_1
// (relu_out_1, weights_1) mul -> mul_out_1
// (mul_out_1, bias_2) elementwise_add -> add_out_1
Layers layers;
auto* a = layers.data("a");
auto* filters_0 = layers.data("conv2d_filters_0", {}, true);
auto* bias_0 = layers.data("conv2d_bias_0", {}, true);
auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false);
auto* relu_out_0 = layers.relu(conv2d_out);
auto* weights_0 = layers.data("weights_0", {}, true);
auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
auto* bias_1 = layers.data("bias_1", {}, true);
auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1);
auto* relu_out_1 = layers.relu(add_out_0);
auto* weights_1 = layers.data("weights_1", {}, true);
auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
auto* bias_2 = layers.data("bias_2", {}, true);
auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2);
VLOG(4) << add_out_1;
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
int pre_nodes = graph->Nodes().size();
int num_nodes_before = graph->Nodes().size();
int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
VLOG(3) << DebugString(graph);
int after_nodes = graph->Nodes().size();
// Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out
// Add 1 Node: FC
EXPECT_EQ(pre_nodes - 2, after_nodes);
// Assert fc op in newly generated graph
int fc_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
++fc_count;
}
}
EXPECT_EQ(fc_count, 1);
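// The {mul, elementwise_add, relu} chain is fused into one fc (5 nodes removed,
// 1 added) and the plain {mul, elementwise_add} pair into another fc (3 nodes
// removed, 1 added), so the graph shrinks by 6 nodes and ends up with 2 fc ops.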
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6);
PADDLE_ENFORCE_EQ(num_fc_nodes_after, 2);
PADDLE_ENFORCE_EQ(num_mul_nodes_before, num_fc_nodes_after);
}
} // namespace ir
......
......@@ -33,7 +33,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
auto* fc_out = fc_pattern(x, with_fc_bias);
auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
gru_pattern(fc_out);
......@@ -116,7 +116,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* x_n = subgraph.at(x);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
......
......@@ -33,7 +33,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate();
auto* fc_out =
fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
......@@ -132,7 +133,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
......
......@@ -846,7 +846,7 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
}
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
bool with_bias, bool with_relu) {
// Create shared nodes.
x->assert_is_op_input("mul", "X");
auto *mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
......@@ -859,11 +859,10 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
auto *mul_out_var =
pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
// Add links.
mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
if (!with_bias) { // not with bias
// Add links.
mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
return mul_out_var;
} else { // with bias
mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
// Create operators.
......@@ -872,15 +871,29 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
// Create variables.
auto *bias = pattern->NewNode(bias_repr())
->assert_is_op_input("elementwise_add")
->assert_is_persistable_var()
->AsInput();
auto *fc_out = pattern->NewNode(Out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
auto *elementwise_add_out_var =
pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add");
mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var});
elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
return fc_out;
elementwise_add->LinksFrom({mul_out_var, bias})
.LinksTo({elementwise_add_out_var});
if (!with_relu) {
return elementwise_add_out_var;
} else {
elementwise_add_out_var->AsIntermediate()->assert_is_op_input("relu");
// Create operators.
auto *relu = pattern->NewNode(relu_repr())->assert_is_op("relu");
auto *relu_out_var = pattern->NewNode(relu_out_repr())
->AsOutput()
->assert_is_op_output("relu");
relu->LinksFrom({elementwise_add_out_var}).LinksTo({relu_out_var});
return relu_out_var;
}
}
}
......
......@@ -487,17 +487,19 @@ struct FC : public PatternBase {
FC(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc") {}
PDNode* operator()(PDNode* x, bool with_bias);
PDNode* operator()(PDNode* x, bool with_bias, bool with_relu);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(mul);
PATTERN_DECL_NODE(elementwise_add);
PATTERN_DECL_NODE(relu);
// declare variable node's name
PATTERN_DECL_NODE(w);
PATTERN_DECL_NODE(mul_out); // (x,w) -> mul_out
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(Out);
PATTERN_DECL_NODE(elementwise_add_out);
PATTERN_DECL_NODE(relu_out);
};
// MKL-DNN's FC with bias
......
......@@ -89,6 +89,17 @@ void GraphVizPass::ApplyImpl(ir::Graph* graph) const {
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
bool is_first = true;
for (int64_t length : n->Var()->GetShape()) {
if (is_first) {
node_id += "\n" + std::to_string(length);
is_first = false;
} else {
node_id += "," + std::to_string(length);
}
}
}
decltype(op_attrs)* attr;
if (marked_nodes.count(n)) {
attr = &marked_var_attrs;
......
......@@ -28,10 +28,13 @@ struct Layers {
public:
const ProgramDesc& main_program() { return program_; }
VarDesc* data(std::string name) { return lod_tensor(name); }
VarDesc* data(std::string name, std::vector<int64_t> shape = {},
bool is_persistable = false) {
return lod_tensor(name, shape, is_persistable);
}
VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn) {
bool use_cudnn = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
......@@ -76,8 +79,27 @@ struct Layers {
return unary_op("relu", x, out);
}
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
return binary_op("mul", x, y, out);
VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias,
int in_num_col_dims = 1, std::string activation_type = "") {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("fc");
op->SetInput("Input", {input->Name()});
op->SetInput("W", {w->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("in_num_col_dims", in_num_col_dims);
op->SetAttr("activation_type", activation_type);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
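// Example use (hypothetical names): an fc followed by relu can be emitted in one
// call, e.g. layers.fc(x, w, b, /*in_num_col_dims=*/1, "relu"), as the
// repeated_fc_relu_fuse_pass tester below does.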
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
int x_num_col_dims = 1) {
AttributeMap attrs;
attrs["x_num_col_dims"] = 1;
return binary_op("mul", x, y, out, &attrs);
}
VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
......@@ -116,9 +138,12 @@ struct Layers {
}
private:
VarDesc* lod_tensor(std::string name) {
VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
bool is_persistable = false) {
auto* var = program_.MutableBlock(0)->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
......@@ -136,7 +161,8 @@ struct Layers {
}
VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
VarDesc* out = nullptr) {
VarDesc* out = nullptr,
const AttributeMap* attrs = nullptr) {
if (!out) {
out = lod_tensor(unique_name());
}
......@@ -145,6 +171,11 @@ struct Layers {
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
op->SetOutput("Out", {out->Name()});
if (attrs) {
for (auto& iter : *attrs) {
op->SetAttr(iter.first, iter.second);
}
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <algorithm> // for max
......@@ -25,55 +25,84 @@ namespace paddle {
namespace framework {
namespace ir {
PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
const std::string& name_scope, int num_fc) {
static bool IsInputOfFC(Node* n) {
if (n && n->IsVar() && VarLinksToOp(n, "fc")) {
return true;
}
return false;
}
static bool IsOutputOfFC(Node* n) {
if (n && n->IsVar() && VarLinksFromOp(n, "fc") && n->inputs.size() == 1U) {
return true;
}
return false;
}
static bool IsFCWithAct(Node* n, const std::string& act_type = "relu") {
if (n && n->IsOp() && n->Op() && n->Op()->Type() == "fc" &&
n->inputs.size() == 3U && n->outputs.size() == 1U) {
return boost::get<std::string>(n->Op()->GetAttr("activation_type")) ==
act_type;
}
return false;
}
static bool IsParamOfFC(Node* n, const std::string& param_name) {
if (IsInputOfFC(n) && n->inputs.empty() &&
(n->Name() == n->outputs[0]->Op()->Input(param_name)[0])) {
return true;
}
return false;
}
static int FindFCIdx(Node* x, const std::string& act_type = "relu") {
if (!IsInputOfFC(x)) {
return -1;
}
for (size_t k = 0; k < x->outputs.size(); ++k) {
auto* out_op = x->outputs[k];
if (IsFCWithAct(out_op, act_type) && out_op->outputs.size() == 1U) {
return k;
}
}
return -1;
}
static int FindInputIdx(Node* n, const std::string& name,
const std::string& act_type = "relu") {
if (!IsFCWithAct(n, act_type)) {
return -1;
}
for (size_t i = 0; i < n->inputs.size(); ++i) {
if (n->inputs[i]->Name() == n->Op()->Input(name)[0]) {
return i;
}
}
return -1;
}
void BuildRepeatedFCReluPattern(PDPattern* pattern,
const std::string& name_scope, int num_fc) {
auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
bool check_in_has_only_one_out = true,
int fc_idx = 0) -> bool {
bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
if (check_in_has_only_one_out) {
next_is_fc = next_is_fc && x->outputs.size() == 1;
}
if (!next_is_fc) {
if (!IsInputOfFC(x)) {
return false;
}
auto* fc_op = x->outputs[fc_idx];
bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
VarLinksToOp(fc_op->outputs[0], act_type) &&
fc_op->outputs[0]->outputs.size() == 1;
if (!next_is_act) {
if (check_in_has_only_one_out && x->outputs.size() != 1U) {
return false;
}
auto* act_op = fc_op->outputs[0]->outputs[0];
return act_op && act_op->IsOp() && act_op->outputs.size() == 1;
};
auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int {
bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc");
if (!next_is_fc) {
return 0;
}
for (size_t k = 0; k < x->outputs.size(); ++k) {
auto* fc_op = x->outputs[k];
bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 &&
fc_op->outputs[0] && fc_op->outputs[0]->IsVar() &&
VarLinksToOp(fc_op->outputs[0], act_type) &&
fc_op->outputs[0]->outputs.size() == 1;
if (!next_is_act) {
continue;
}
auto* act_op = fc_op->outputs[0]->outputs[0];
if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) {
return k;
}
}
return 0;
auto* fc_op = x->outputs[fc_idx];
return IsFCWithAct(fc_op, act_type) && fc_op->outputs.size() == 1U;
};
// in -> fc -> out
// Current x is in, return fc's out which is next fc's input.
auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* {
return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0];
return x->outputs[fc_idx]->outputs[0];
};
auto var_next_is_fc_act_repeated_n_times = [=](
Node* x, int repeated_times, const std::string& act_type = "relu",
bool check_in_has_only_one_out = true) -> bool {
......@@ -87,25 +116,14 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
return true;
};
// x is output of fc
auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu",
bool at_top = false) -> bool {
bool before_is_act =
x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, "relu");
if (!before_is_act) {
if (!IsOutputOfFC(x)) {
return false;
}
auto* relu_op = x->inputs[0];
bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
relu_op->inputs[0]->IsVar() &&
VarLinksFromOp(relu_op->inputs[0], "fc") &&
relu_op->inputs[0]->inputs.size() == 1;
if (!before_is_fc) {
return false;
}
auto* fc_op = relu_op->inputs[0]->inputs[0];
bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
if (!is_fc) {
auto* fc_op = x->inputs[0];
if (!IsFCWithAct(fc_op, act_type) || fc_op->inputs.size() != 3U) {
return false;
}
for (auto* fc_i : fc_op->inputs) {
......@@ -113,7 +131,7 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
if (at_top) {
return true;
} else {
return VarLinksFromOp(fc_i, "relu");
return VarLinksFromOp(fc_i, "fc");
}
}
}
......@@ -121,10 +139,11 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
};
auto before_var_of_part = [=](Node* x) -> Node* {
auto* fc_op = x->inputs[0]->inputs[0];
for (auto* fc_i : fc_op->inputs) {
if (!fc_i->inputs.empty()) {
return fc_i->inputs[0];
auto* fc_op = x->inputs[0];
for (auto* in : fc_op->inputs) {
if (!in->inputs.empty()) {
// w and bias have no inputs.
return in;
}
}
return nullptr;
......@@ -142,76 +161,76 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
return true;
};
std::vector<PDNode*> fc_input_var(num_fc);
PDNode* fc_input_var_0 = nullptr;
std::vector<PDNode*> fc_output_var(num_fc);
std::vector<PDNode*> fc_weight_var(num_fc);
std::vector<PDNode*> fc_bias_var(num_fc);
std::vector<PDNode*> fc_ops(num_fc);
std::vector<PDNode*> relu_ops(num_fc);
for (int i = 0; i < num_fc; ++i) {
fc_input_var[i] = pattern->NewNode(
[=](Node* x) {
if (i == 0 && x->outputs.size() > 0) {
bool ok = x->inputs.size() > 0;
if (!ok) {
if (i == 0) {
fc_input_var_0 = pattern->NewNode(
[=](Node* x) {
if (x->outputs.size() <= 0 || x->inputs.size() <= 0U) {
return false;
}
int idx = find_fc_idx(x);
if (idx == 0) {
int fc_idx = FindFCIdx(x);
if (fc_idx < 0) {
return false;
} else if (fc_idx == 0) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
} else {
x = next_var_of_part(x, idx);
x = next_var_of_part(x, fc_idx);
return var_next_is_fc_act_repeated_n_times(
x, std::max(1, num_fc - i - 1), "relu");
}
} else {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.size() > 0 &&
var_before_is_fc_act_repeated_n_times(x, i, "relu");
}
},
name_scope + "/fc_in_" + std::to_string(i));
},
name_scope + "/fc_in_0");
}
fc_weight_var[i] = pattern->NewNode(
[=](Node* x) {
if (!IsParamOfFC(x, "W")) {
return false;
}
auto* fc_op = x->outputs[0];
int input_idx = FindInputIdx(fc_op, "Input", "relu");
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.empty() &&
var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
i, "relu") &&
x->Name() == x->outputs[0]->Op()->Input("W")[0];
var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx],
i, "relu");
},
name_scope + "/fc_weight_" + std::to_string(i));
fc_bias_var[i] = pattern->NewNode(
[=](Node* x) {
if (!IsParamOfFC(x, "Bias")) {
return false;
}
auto* fc_op = x->outputs[0];
int input_idx = FindInputIdx(fc_op, "Input", "relu");
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.empty() &&
var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0],
i, "relu") &&
x->Name() == x->outputs[0]->Op()->Input("Bias")[0];
var_before_is_fc_act_repeated_n_times(fc_op->inputs[input_idx],
i, "relu");
},
name_scope + "/fc_bias_" + std::to_string(i));
fc_output_var[i] = pattern->NewNode(
[=](Node* x) {
bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") &&
VarLinksToOp(x, "relu") && x->inputs.size() == 1 &&
x->inputs[0]->inputs.size() == 3;
if (!basic) {
if (!IsOutputOfFC(x)) {
return false;
}
x = x->inputs[0]->inputs[0];
if (i == 0 && x->outputs.size() > 0) {
bool ok = x->inputs.size() > 0;
if (!ok) {
x = before_var_of_part(x);
if (i == 0 && x->outputs.size() > 0U) {
if (x->inputs.size() <= 0U) {
return false;
}
int idx = find_fc_idx(x);
if (idx == 0) {
int fc_idx = FindFCIdx(x);
if (fc_idx < 0) {
return false;
} else if (fc_idx == 0) {
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu");
} else {
x = next_var_of_part(x, idx);
x = next_var_of_part(x, fc_idx);
return var_next_is_fc_act_repeated_n_times(
x, std::max(1, num_fc - i - 1), "relu");
}
......@@ -225,53 +244,29 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
fc_ops[i] = pattern->NewNode(
[=](Node* x) {
bool basic = x && x->IsOp() && x->Op()->Type() == "fc" &&
x->inputs.size() == 3 && x->outputs.size() == 1;
if (!basic) {
if (!IsFCWithAct(x, "relu")) {
return false;
}
auto* fc_out_var = x->outputs[0];
return fc_out_var && fc_out_var->IsVar() &&
fc_out_var->outputs.size() == 1 &&
VarLinksToOp(fc_out_var, "relu") &&
fc_out_var->outputs[0]->outputs.size() == 1 &&
var_next_is_fc_act_repeated_n_times(
fc_out_var->outputs[0]->outputs[0], num_fc - i - 1,
"relu") &&
var_before_is_fc_act_repeated_n_times(
fc_out_var->outputs[0]->outputs[0], i + 1, "relu");
},
name_scope + "/fc_op_" + std::to_string(i));
relu_ops[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "relu" &&
x->inputs.size() == 1 && x->outputs.size() == 1 &&
x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") &&
x->outputs[0]->IsVar() &&
var_next_is_fc_act_repeated_n_times(x->outputs[0],
num_fc - i - 1, "relu") &&
var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1,
var_next_is_fc_act_repeated_n_times(fc_out_var, num_fc - i - 1,
"relu") &&
var_before_is_fc_act_repeated_n_times(fc_out_var, i + 1,
"relu");
},
name_scope + "/act_op_" + std::to_string(i));
fc_ops[i]
->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]})
.LinksTo({fc_output_var[i]});
relu_ops[i]->LinksFrom({fc_output_var[i]});
}
name_scope + "/fc_op_" + std::to_string(i));
auto* last_out_var = pattern->NewNode(
[=](Node* x) {
return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu");
},
name_scope + "/act_out");
for (int i = 0; i < num_fc - 1; ++i) {
relu_ops[i]->LinksTo({fc_input_var[i + 1]});
if (i == 0) {
fc_ops[i]
->LinksFrom({fc_input_var_0, fc_weight_var[i], fc_bias_var[i]})
.LinksTo({fc_output_var[i]});
} else {
fc_ops[i]
->LinksFrom({fc_output_var[i - 1], fc_weight_var[i], fc_bias_var[i]})
.LinksTo({fc_output_var[i]});
}
}
relu_ops[num_fc - 1]->LinksTo({last_out_var});
return last_out_var;
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
......@@ -304,11 +299,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto& fused_pattern = gpd.pattern();
for (int i = 0; i < num_fc; ++i) {
if (i >= 1) {
relu_vars[i - 1] =
retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph,
if (i < num_fc - 1) {
relu_vars[i] =
retrieve_node(name_scope + "/fc_out_" + std::to_string(i), subgraph,
fused_pattern);
relu_names[i - 1] = relu_vars[i - 1]->Name();
relu_names[i] = relu_vars[i]->Name();
}
weights_vars[i] =
......@@ -324,7 +319,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* input_var =
retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern);
auto* last_out_var =
retrieve_node(name_scope + "/act_out", subgraph, fused_pattern);
retrieve_node(name_scope + "/fc_out_" + std::to_string(num_fc - 1),
subgraph, fused_pattern);
// Create New OpDesc
OpDesc op_desc;
......@@ -334,6 +330,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
op_desc.SetInput("Bias", bias_names);
op_desc.SetOutput("ReluOut", relu_names);
op_desc.SetOutput("Out", {last_out_var->Name()});
auto* op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(input_var, op);
for (size_t i = 0; i < weights_vars.size(); ++i) {
......@@ -367,7 +364,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
}
void RepeatedFCReluFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init(name_scope_, graph);
int fusion_count = 0;
for (int i = MAX_NUM_FC; i > 1; --i) {
fusion_count +=
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void TestMain(int num_fc) {
// inputs operator output
// -------------------------------------------------------------
// (x, filters, bias_0) conv2d -> conv2d_out
// (conv2d_out, fc_weights_0, fc_bias_0) fc -> fc_out_0
// (fc_out_0, fc_weights_1, fc_bias_1) fc -> fc_out_1
// ...
Layers layers;
VarDesc* x = layers.data("x");
VarDesc* filters = layers.data("filters", {}, true);
VarDesc* bias_0 = layers.data("bias_0", {}, true);
VarDesc* conv2d_out = layers.conv2d(x, filters, bias_0);
VarDesc* fc_in = conv2d_out;
for (int i = 0; i < num_fc; ++i) {
VarDesc* weights_i =
layers.data("fc_weights_" + std::to_string(i), {}, true);
VarDesc* bias_i = layers.data("fc_bias_" + std::to_string(i), {}, true);
std::string activation_type = i < (num_fc - 1) ? "relu" : "";
VarDesc* fc_out = layers.fc(fc_in, weights_i, bias_i, 1, activation_type);
fc_in = fc_out;
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("repeated_fc_relu_fuse_pass");
int num_nodes_before = graph->Nodes().size();
int num_fc_nodes_before = GetNumOpNodes(graph, "fc");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fused_nodes_after = GetNumOpNodes(graph, "fusion_repeated_fc_relu");
VLOG(3) << DebugString(graph);
// Delete (num_fc_nodes_before - 1) fc ops
PADDLE_ENFORCE_EQ(num_nodes_before - (num_fc_nodes_before - 1) + 1,
num_nodes_after);
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1);
}
TEST(RepeatedFCReluFusePass, basic_3) { TestMain(3); }
TEST(RepeatedFCReluFusePass, basic_9) { TestMain(9); }
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(repeated_fc_relu_fuse_pass);
......@@ -146,7 +146,7 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 32);
EXPECT_EQ(num_ops, 31);
}
// Compare result of NativeConfig and AnalysisConfig
......
......@@ -132,7 +132,7 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<std::string>("activation_type",
"Avctivation type used in fully connected operator.")
"Activation type used in fully connected operator.")
.SetDefault("");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
......