From a89296ac1fa9fd91eccde23955ac07590988c62b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Sat, 12 Jan 2019 17:27:26 +0000 Subject: [PATCH] add repeated fc relu pass --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/repeated_fc_relu_fuse_pass.cc | 409 ++++++++++++++++++ .../framework/ir/repeated_fc_relu_fuse_pass.h | 41 ++ .../framework/ir/seqpool_concat_fuse_pass.cc | 3 +- .../fluid/inference/api/paddle_pass_builder.h | 1 + 5 files changed, 454 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 42fb6a1aa..c888f96d9 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -43,6 +43,7 @@ pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) pass_library(seqpool_concat_fuse_pass inference) +pass_library(repeated_fc_relu_fuse_pass inference) pass_library(is_test_pass base) pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc new file mode 100644 index 000000000..6f619181f --- /dev/null +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -0,0 +1,409 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h" +#include // for max +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" + +#define MAX_NUM_FC 10 + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, + const std::string& name_scope, int num_fc) { + auto var_next_is_fc_act = [=](Node* x, const std::string& act_type = "relu", + bool check_in_has_only_one_out = true, + int fc_idx = 0) -> bool { + bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); + if (check_in_has_only_one_out) { + next_is_fc = next_is_fc && x->outputs.size() == 1; + } + if (!next_is_fc) { + return false; + } + auto* fc_op = x->outputs[fc_idx]; + bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && + fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && + VarLinksToOp(fc_op->outputs[0], act_type) && + fc_op->outputs[0]->outputs.size() == 1; + if (!next_is_act) { + return false; + } + auto* act_op = fc_op->outputs[0]->outputs[0]; + return act_op && act_op->IsOp() && act_op->outputs.size() == 1; + }; + + auto find_fc_idx = [=](Node* x, const std::string& act_type = "relu") -> int { + bool next_is_fc = x && x->IsVar() && VarLinksToOp(x, "fc"); + if (!next_is_fc) { + return 0; + } + for (size_t k = 0; k < x->outputs.size(); ++k) { + auto* fc_op = x->outputs[k]; + bool next_is_act = fc_op && fc_op->IsOp() && fc_op->outputs.size() == 1 && + fc_op->outputs[0] && fc_op->outputs[0]->IsVar() && + VarLinksToOp(fc_op->outputs[0], act_type) && + fc_op->outputs[0]->outputs.size() == 1; + if (!next_is_act) { + continue; + } + auto* act_op = fc_op->outputs[0]->outputs[0]; + if (act_op && act_op->IsOp() && act_op->outputs.size() == 1) { + return k; + } + } + return 0; + }; + + auto next_var_of_part = [=](Node* x, int fc_idx = 0) -> Node* { + return x->outputs[fc_idx]->outputs[0]->outputs[0]->outputs[0]; + }; + auto var_next_is_fc_act_repeated_n_times = [=]( + Node* x, int repeated_times, const std::string& act_type = "relu", + bool check_in_has_only_one_out = true) -> bool { + for (int i = 0; i < repeated_times; ++i) { + if (!var_next_is_fc_act(x, act_type, + i == 0 && check_in_has_only_one_out)) { + return false; + } + x = next_var_of_part(x); + } + return true; + }; + + auto var_before_is_fc_act = [=](Node* x, const std::string& act_type = "relu", + bool at_top = false) -> bool { + bool before_is_act = + x && x->IsVar() && x->inputs.size() == 1 && VarLinksFromOp(x, "relu"); + if (!before_is_act) { + return false; + } + auto* relu_op = x->inputs[0]; + // std::cout << "xxxx" << std::endl; + bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 && + relu_op->inputs[0]->IsVar() && + VarLinksFromOp(relu_op->inputs[0], "fc") && + relu_op->inputs[0]->inputs.size() == 1; + + if (!before_is_fc) { + return false; + } + auto* fc_op = relu_op->inputs[0]->inputs[0]; + bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3; + // std::cout << "*****" << fc_op->inputs.size() << std::endl; + if (!is_fc) { + return false; + } + for (size_t kkk = 0; kkk < 3; ++kkk) { + // std::cout << "++++++" << kkk << std::endl; + if (!fc_op->inputs[kkk]->inputs.empty()) { + if (at_top) { + return true; + } else { + bool res = VarLinksFromOp(fc_op->inputs[kkk], "relu"); + // std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk << + // ":" + // << res << std::endl; + return res; + } + } + } + // for (auto* fc_i : fc_op->inputs) { + // if (!fc_i->inputs.empty()) { + // std::cout << "++++++" << fc_op->inputs.size()< Node* { + auto* fc_op = x->inputs[0]->inputs[0]; + for (auto* fc_i : fc_op->inputs) { + if (!fc_i->inputs.empty()) { + return fc_i->inputs[0]; + } + } + return nullptr; + }; + + auto var_before_is_fc_act_repeated_n_times = [=]( + Node* x, int repeated_times, + const std::string& act_type = "relu") -> bool { + for (int i = 0; i < repeated_times; ++i) { + // std::cout << "----" << i << std::endl; + if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) { + return false; + } + x = before_var_of_part(x); + } + return true; + }; + + std::vector fc_input_var(num_fc); + std::vector fc_output_var(num_fc); + std::vector fc_weight_var(num_fc); + std::vector fc_bias_var(num_fc); + std::vector fc_ops(num_fc); + std::vector relu_ops(num_fc); + + for (int i = 0; i < num_fc; ++i) { + fc_input_var[i] = pattern->NewNode( + [=](Node* x) { + if (i == 0 && x->outputs.size() > 0) { + bool ok = x->inputs.size() > 0; + if (!ok) { + return false; + } + int idx = find_fc_idx(x); + if (idx == 0) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); + } else { + x = next_var_of_part(x, idx); + return var_next_is_fc_act_repeated_n_times( + x, std::max(1, num_fc - i - 1), "relu"); + } + } else { + bool part1 = + var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.size() > 0; + if (x->Name() == "fc_0.tmp_1" && x->IsVar() && part1) { + // std::cout << "testes" << std::endl; + } + bool part2 = var_before_is_fc_act_repeated_n_times(x, i, "relu"); + if (x->Name() == "fc_0.tmp_1") { + // std::cout << "========" << part1 << "," << part2 << std::endl; + } + return part1 && part2; + } + }, + name_scope + "/fc_in_" + std::to_string(i)); + + fc_weight_var[i] = pattern->NewNode( + [=](Node* x) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.empty() && + var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], + i, "relu") && + x->Name() == x->outputs[0]->Op()->Input("W")[0]; + }, + name_scope + "/fc_weight_" + std::to_string(i)); + + fc_bias_var[i] = pattern->NewNode( + [=](Node* x) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.empty() && + var_before_is_fc_act_repeated_n_times(x->outputs[0]->inputs[0], + i, "relu") && + x->Name() == x->outputs[0]->Op()->Input("Bias")[0]; + }, + name_scope + "/fc_bias_" + std::to_string(i)); + + fc_output_var[i] = pattern->NewNode( + [=](Node* x) { + bool basic = x && x->IsVar() && VarLinksFromOp(x, "fc") && + VarLinksToOp(x, "relu") && x->inputs.size() == 1 && + x->inputs[0]->inputs.size() == 3; + if (!basic) { + return false; + } + x = x->inputs[0]->inputs[0]; + if (i == 0 && x->outputs.size() > 0) { + bool ok = x->inputs.size() > 0; + if (!ok) { + return false; + } + int idx = find_fc_idx(x); + if (idx == 0) { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu"); + } else { + x = next_var_of_part(x, idx); + return var_next_is_fc_act_repeated_n_times( + x, std::max(1, num_fc - i - 1), "relu"); + } + } else { + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.size() > 0 && + var_before_is_fc_act_repeated_n_times(x, i, "relu"); + } + }, + name_scope + "/fc_out_" + std::to_string(i)); + + fc_ops[i] = pattern->NewNode( + [=](Node* x) { + bool basic = x && x->IsOp() && x->Op()->Type() == "fc" && + x->inputs.size() == 3 && x->outputs.size() == 1; + if (!basic) { + return false; + } + auto* fc_out_var = x->outputs[0]; + return fc_out_var && fc_out_var->IsVar() && + fc_out_var->outputs.size() == 1 && + VarLinksToOp(fc_out_var, "relu") && + fc_out_var->outputs[0]->outputs.size() == 1 && + var_next_is_fc_act_repeated_n_times( + fc_out_var->outputs[0]->outputs[0], num_fc - i - 1, + "relu") && + var_before_is_fc_act_repeated_n_times( + fc_out_var->outputs[0]->outputs[0], i + 1, "relu"); + }, + name_scope + "/fc_op_" + std::to_string(i)); + + relu_ops[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "relu" && + x->inputs.size() == 1 && x->outputs.size() == 1 && + x->inputs[0]->IsVar() && VarLinksFromOp(x->inputs[0], "fc") && + x->outputs[0]->IsVar() && + var_next_is_fc_act_repeated_n_times(x->outputs[0], + num_fc - i - 1, "relu") && + var_before_is_fc_act_repeated_n_times(x->outputs[0], i + 1, + "relu"); + }, + name_scope + "/act_op_" + std::to_string(i)); + + fc_ops[i] + ->LinksFrom({fc_input_var[i], fc_weight_var[i], fc_bias_var[i]}) + .LinksTo({fc_output_var[i]}); + relu_ops[i]->LinksFrom({fc_output_var[i]}); + } + + auto* last_out_var = pattern->NewNode( + [=](Node* x) { + return var_before_is_fc_act_repeated_n_times(x, num_fc, "relu"); + }, + name_scope + "/act_out"); + for (int i = 0; i < num_fc - 1; ++i) { + relu_ops[i]->LinksTo({fc_input_var[i + 1]}); + } + relu_ops[num_fc - 1]->LinksTo({last_out_var}); + return last_out_var; +} + +static int BuildFusion(Graph* graph, const std::string& name_scope, + int num_fc) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + BuildRepeatedFCReluPattern(pattern, name_scope, num_fc); + + auto retrieve_node = [](const std::string& name, + const GraphPatternDetector::subgraph_t& subgraph, + const PDPattern& pat) -> Node* { + PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), + "pattern has no Node called %s", name.c_str()); + Node* p = subgraph.at(pat.RetrieveNode(name)); + PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str()); + return p; + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + LOG(INFO) << "handle Repeated FC Act fuse"; + std::vector weights_vars(num_fc); + std::vector bias_vars(num_fc); + std::vector relu_vars(num_fc - 1); + + std::vector weight_names(num_fc); + std::vector bias_names(num_fc); + std::vector relu_names(num_fc - 1); + + auto& fused_pattern = gpd.pattern(); + for (int i = 0; i < num_fc; ++i) { + if (i >= 1) { + relu_vars[i - 1] = + retrieve_node(name_scope + "/fc_in_" + std::to_string(i), subgraph, + fused_pattern); + relu_names[i - 1] = relu_vars[i - 1]->Name(); + } + + weights_vars[i] = + retrieve_node(name_scope + "/fc_weight_" + std::to_string(i), + subgraph, fused_pattern); + weight_names[i] = weights_vars[i]->Name(); + + bias_vars[i] = retrieve_node(name_scope + "/fc_bias_" + std::to_string(i), + subgraph, fused_pattern); + bias_names[i] = bias_vars[i]->Name(); + } + + auto* input_var = + retrieve_node(name_scope + "/fc_in_0", subgraph, fused_pattern); + auto* last_out_var = + retrieve_node(name_scope + "/act_out", subgraph, fused_pattern); + + // Create New OpDesc + OpDesc op_desc; + op_desc.SetType("fusion_repeated_fc_relu"); + op_desc.SetInput("X", {input_var->Name()}); + op_desc.SetInput("W", weight_names); + op_desc.SetInput("Bias", bias_names); + op_desc.SetOutput("ReluOut", relu_names); + op_desc.SetOutput("Out", {last_out_var->Name()}); + auto* op = graph->CreateOpNode(&op_desc); + IR_NODE_LINK_TO(input_var, op); + for (size_t i = 0; i < weights_vars.size(); ++i) { + IR_NODE_LINK_TO(weights_vars[i], op); + IR_NODE_LINK_TO(bias_vars[i], op); + } + for (size_t i = 0; i < relu_vars.size(); ++i) { + IR_NODE_LINK_TO(op, relu_vars[i]); + } + IR_NODE_LINK_TO(op, last_out_var); + + std::unordered_set marked_nodes; + for (auto& item : subgraph) { + marked_nodes.insert(item.second); + } + for (size_t i = 0; i < weights_vars.size(); ++i) { + marked_nodes.erase(weights_vars[i]); + marked_nodes.erase(bias_vars[i]); + } + for (size_t i = 0; i < relu_vars.size(); ++i) { + marked_nodes.erase(relu_vars[i]); + } + marked_nodes.erase(input_var); + marked_nodes.erase(last_out_var); + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +std::unique_ptr RepeatedFCReluFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + int fusion_count = 0; + for (int i = MAX_NUM_FC; i > 1; --i) { + fusion_count += + BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(3), 3); + } + AddStatis(fusion_count); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(repeated_fc_relu_fuse_pass, + paddle::framework::ir::RepeatedFCReluFusePass); diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h new file mode 100644 index 000000000..9e66d891f --- /dev/null +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +/** + * Fuse Repeated FC Relu + */ +class RepeatedFCReluFusePass : public FusePassBase { + public: + virtual ~RepeatedFCReluFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"repeated_fc_relu"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index fa75e3b4a..63a0c24f2 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -129,7 +129,8 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, return concat_out_var; } -int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) { +static int BuildFusion(Graph* graph, const std::string& name_scope, + int num_inputs) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index de9650735..aea0a6914 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -98,6 +98,7 @@ class CpuPassStrategy : public PassStrategy { "mul_gru_fuse_pass", // "seq_concat_fc_fuse_pass", // "fc_fuse_pass", // + "repeated_fc_relu_fuse_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // "is_test_pass", // -- GitLab