未验证 提交 b5f3be83 编写于 作者: Y Yiqun Liu 提交者: GitHub

Implement a pass detect fusion group of elementwise op (#19884)

* Add fusion_group_pass and elementwise pattern.

* Rewrite the detector of elementwise group.
test=develop

* Add a comment in codegen.

* Add more unittest cases.
test=develop

* Move code_generator related code to fusion_group directory.

* Correct the including path.

* Add the definition of SubGraph and finish the insert of fusion_group op in pass.

* Insert graph_vis_pass in tester to visualize the graph for debug.
上级 da9e9dd0
...@@ -6,6 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") ...@@ -6,6 +6,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass) add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass) add_subdirectory(multi_devices_graph_pass)
add_subdirectory(fusion_group)
# Usage: pass_library(target inference) will append to paddle_inference_pass.h # Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable unset(INFER_IR_PASSES CACHE) # clear the global variable
...@@ -30,8 +31,6 @@ function(pass_library TARGET DEST) ...@@ -30,8 +31,6 @@ function(pass_library TARGET DEST)
endif() endif()
endfunction() endfunction()
cc_library(codegen SRCS codegen.cc DEPS codegen_helper)
cc_library(codegen_helper SRCS codegen_helper.cc DEPS graph node graph_helper)
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
...@@ -111,11 +110,6 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -111,11 +110,6 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen device_code lod_tensor)
endif()
endif()
cc_test(node_test SRCS node_test.cc DEPS node) cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
......
cc_library(code_generator SRCS code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor)
endif()
endif()
cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
...@@ -11,10 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,10 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set> #include <set>
#include <sstream> #include <sstream>
#include "paddle/fluid/framework/ir/codegen_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -23,9 +25,8 @@ CodeGenerator::CodeGenerator(CodeTemplate code_template) { ...@@ -23,9 +25,8 @@ CodeGenerator::CodeGenerator(CodeTemplate code_template) {
code_template_ = code_template; code_template_ = code_template;
} }
// in order to get the right result of expression, we need to calculate, we // In order to get the right result of expression, we need to calculate and
// store the expression as // store the expression as suffix Expressions using vector.
// suffix Expressions using vector
std::string CodeGenerator::GenerateCode(TemplateVariable template_var) { std::string CodeGenerator::GenerateCode(TemplateVariable template_var) {
auto cuda_kernel = kernel_function + code_template_.Format(template_var); auto cuda_kernel = kernel_function + code_template_.Format(template_var);
return cuda_kernel; return cuda_kernel;
......
...@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -23,8 +24,11 @@ namespace ir { ...@@ -23,8 +24,11 @@ namespace ir {
class CodeGenerator { class CodeGenerator {
public: public:
explicit CodeGenerator(CodeTemplate code_template); explicit CodeGenerator(CodeTemplate code_template);
std::string GenerateCode(TemplateVariable template_var); std::string GenerateCode(TemplateVariable template_var);
// TODO(wangchao66) std::string GenerateCode(const Graph& graph)
// TODO(wangchao): add a more general interface
// std::string Generate(const std::string name, const SubGraph& subgraph);
private: private:
CodeTemplate code_template_; CodeTemplate code_template_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
* You may obtain a copy of the License at You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and See the License for the specific language governing permissions and
* limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/codegen_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include <algorithm> #include <algorithm>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -50,6 +52,7 @@ std::string OperationExpression::GetLHSTemplate() { ...@@ -50,6 +52,7 @@ std::string OperationExpression::GetLHSTemplate() {
bool OperationExpression::SupportState() { bool OperationExpression::SupportState() {
return (support_table.find(op_) == support_table.end()); return (support_table.find(op_) == support_table.end());
} }
// we Traverse the graph and get the group , all input id and output id is // we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group // unique for the node which belong the group
std::string OperationExpression::GetExpression() { std::string OperationExpression::GetExpression() {
......
...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
...@@ -81,6 +82,7 @@ class TemplateVariable { ...@@ -81,6 +82,7 @@ class TemplateVariable {
private: private:
std::unordered_map<std::string, std::string> strings_; std::unordered_map<std::string, std::string> strings_;
}; };
class CodeTemplate { class CodeTemplate {
public: public:
CodeTemplate() = default; CodeTemplate() = default;
...@@ -110,6 +112,7 @@ class CodeTemplate { ...@@ -110,6 +112,7 @@ class CodeTemplate {
return EmitIndents(ret); return EmitIndents(ret);
} }
std::string EmitIndents(std::string str) { std::string EmitIndents(std::string str) {
std::string ret = str; std::string ret = str;
int space_num = 0; int space_num = 0;
...@@ -147,6 +150,7 @@ static std::string EmitUniqueName(std::vector<OperationExpression> expression) { ...@@ -147,6 +150,7 @@ static std::string EmitUniqueName(std::vector<OperationExpression> expression) {
} }
return ret.str(); return ret.str();
} }
// we get the parameter list code for the expression information // we get the parameter list code for the expression information
static std::string EmitDeclarationCode( static std::string EmitDeclarationCode(
std::vector<OperationExpression> expression, std::string type) { std::vector<OperationExpression> expression, std::string type) {
......
...@@ -11,19 +11,20 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,19 +11,20 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h"
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device_code.h" #include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/init.h"
#ifdef PADDLE_WITH_CUDA
TEST(codegen, cuda) { #ifdef PADDLE_WITH_CUDA
TEST(code_generator, cuda) {
std::vector<int> mul_input{1, 2}; std::vector<int> mul_input{1, 2};
std::vector<int> add_input{3, 4}; std::vector<int> add_input{3, 4};
std::vector<int> sub_input{5, 6}; std::vector<int> sub_input{5, 6};
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
static std::unordered_set<std::string> binary_op_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul",
"elementwise_div", "elementwise_min", "elementwise_max"};
static std::unordered_set<std::string> unary_op_types = {"relu", "sigmoid",
"tanh"};
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
if (n && n->IsOp() && n->Op() && n->outputs.size() > 0U) {
auto iter = op_types.find(n->Op()->Type());
if (iter != op_types.end()) {
return true;
}
}
return false;
}
static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(binary_op_types, n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];
std::vector<int64_t> x_shape;
std::vector<int64_t> y_shape;
if (x && x->IsVar() && x->Var()) {
x_shape = x->Var()->GetShape();
}
if (y && y->IsVar() && y->Var()) {
y_shape = y->Var()->GetShape();
}
if (x_shape.size() == 0U || x_shape.size() != y_shape.size()) {
return false;
}
for (size_t i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] != y_shape[i]) {
return false;
}
}
return true;
}
return false;
}
static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(unary_op_types, n); }
bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
}
bool ElementwiseGroupDetector::IsInputOfElementwiseOp(Node* n,
std::string name) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->outputs) {
if (IsElementwiseOp(op)) {
if (name.empty()) {
return true;
} else if (IsNthInput(n, op, name, 0)) {
return true;
}
}
}
}
return false;
}
bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op)) {
return true;
}
}
}
return false;
}
void ElementwiseGroupDetector::Insert(Node* n) {
if (subgraph_.nodes_set.find(n) == subgraph_.nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << name_;
subgraph_.nodes_set.insert(n);
}
}
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) {
except_nodes_set.insert(except_nodes[i]);
}
int num_operations = 0;
if (IsElementwiseOp(n)) {
Insert(n);
num_operations += 1;
for (auto* var : n->inputs) {
Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
for (auto* var : n->outputs) {
Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n});
}
}
} else if (n && n->IsVar() && n->Var()) {
for (auto* op : n->inputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
for (auto* op : n->outputs) {
if (IsElementwiseOp(op) &&
except_nodes_set.find(op) == except_nodes_set.end()) {
num_operations += Search(op, {n});
}
}
}
return num_operations;
}
int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name();
Insert(n);
num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes()
<< " nodes";
}
return num_operations_;
}
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
struct ElementwiseGroupDetector {
public:
int operator()(Node* n);
SubGraph GetSubgraph() const { return subgraph_; }
private:
bool IsElementwiseOp(Node* n);
bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n);
void Insert(Node* n);
int Search(Node* n, std::vector<Node*> except_nodes = {});
private:
std::string name_;
int num_operations_{0};
SubGraph subgraph_;
};
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
int num_elementwise_groups = DetectFusionGroup(graph, 0);
LOG(INFO) << "Detect " << num_elementwise_groups
<< " elementwise fusion groups.";
}
int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
std::vector<fusion_group::SubGraph> subgraphs;
std::unordered_set<Node*> all_nodes = graph->Nodes();
for (Node* n : all_nodes) {
bool is_found = false;
for (auto& subgraph : subgraphs) {
if (subgraph.nodes_set.find(n) != subgraph.nodes_set.end()) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}
fusion_group::SubGraph subgraph;
if (type == 0) {
fusion_group::ElementwiseGroupDetector detector;
int num_operations = detector(n);
if (num_operations >= 2) {
subgraph = detector.GetSubgraph();
}
}
if (!subgraph.IsEmpty()) {
subgraphs.push_back(subgraph);
}
}
// TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) {
InsertFusionGroupOp(graph, subgraphs[i]);
}
return subgraphs.size();
}
void FusionGroupPass::InsertFusionGroupOp(
Graph* graph, const fusion_group::SubGraph& subgraph) const {
std::vector<Node*> input_vars_of_subgraph = subgraph.GetInputVarNodes();
std::vector<Node*> output_vars_of_subgraph = subgraph.GetOutputVarNodes();
std::unordered_set<Node*> external_nodes;
OpDesc op_desc;
op_desc.SetType("fusion_group");
std::vector<std::string> input_names;
for (auto* n : input_vars_of_subgraph) {
input_names.push_back(n->Name());
external_nodes.insert(n);
}
op_desc.SetInput("Xs", input_names);
std::vector<std::string> output_names;
for (auto* n : output_vars_of_subgraph) {
output_names.push_back(n->Name());
external_nodes.insert(n);
}
op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("type", subgraph.type);
op_desc.SetAttr("func_name", subgraph.func_name);
auto fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) {
IR_NODE_LINK_TO(in, fusion_group_node);
}
for (auto* out : output_vars_of_subgraph) {
IR_NODE_LINK_TO(fusion_group_node, out);
}
std::unordered_set<const Node*> internal_nodes;
for (auto* n : subgraph.nodes_set) {
if (external_nodes.find(n) == external_nodes.end()) {
internal_nodes.insert(n);
}
}
GraphSafeRemoveNodes(graph, internal_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fusion_group_pass, paddle::framework::ir::FusionGroupPass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FusionGroupPass : public Pass {
protected:
void ApplyImpl(Graph* graph) const override;
private:
int DetectFusionGroup(Graph* graph, int type = 0) const;
void InsertFusionGroupOp(Graph* graph,
const fusion_group::SubGraph& subgraph) const;
const std::string name_scope_{"fusion_group"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(FusionGroupPass, elementwise_list) {
// inputs operator output
// --------------------------------------------------------
// (x, y) mul -> tmp_0
// (tmp_0, z) elementwise_add -> tmp_1
// tmp_1 relu -> tmp_2
// (tmp_2, w) elementwise_add -> tmp_3
//
// Expression: tmp_3 = relu(mul(x, y) + z) + w
Layers layers;
auto* x = layers.data("x", {16, 16});
auto* y = layers.data("y", {16, 32});
auto* tmp_0 = layers.mul(x, y);
tmp_0->SetShape({16, 32});
auto* z = layers.data("z", {16, 32});
auto* tmp_1 = layers.elementwise_add(tmp_0, z);
auto* tmp_2 = layers.relu(tmp_1);
tmp_2->SetShape({16, 32});
auto* w = layers.data("w", {16, 32});
layers.elementwise_add(tmp_2, w);
std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_list.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release()));
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_list.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
}
TEST(FusionGroupPass, elementwise_tree) {
// inputs operator output
// --------------------------------------------------------
// (x0, y0) mul -> tmp_0
// x1 sigmoid -> tmp_1
// (tmp_0, tmp_1) elementwise_mul -> tmp_2
// x2 sigmoid -> tmp_3
// x3 tanh -> tmp_4
// (tmp_3, tmp_4) elementwise_mul -> tmp_5
// (tmp_2, tmp_5) elementwise_add -> tmp_6
// x4 tanh -> tmp_7
// x5 sigmoid -> tmp_8
// (tmp_7, tmp_8) elementwise_mul -> tmp_9
// (tmp_6, tmp_9) mul -> tmp_10
//
// Expression: tmp_6 = mul(x0, y0) * sigmoid(x1) + sigmoid(x2) * tanh(x3)
// tmp_9 = tanh(x4) * sigmoid(x5)
// tmp_10 = mul(tmp_6, tmp_9)
Layers layers;
auto* x0 = layers.data("x0", {16, 16});
auto* y0 = layers.data("y0", {16, 32});
auto* tmp_0 = layers.mul(x0, y0);
tmp_0->SetShape({16, 32});
auto* x1 = layers.data("x1", {16, 32});
auto* tmp_1 = layers.sigmoid(x1);
tmp_1->SetShape({16, 32});
auto* tmp_2 = layers.elementwise_mul(tmp_0, tmp_1);
tmp_2->SetShape({16, 32});
auto* x2 = layers.data("x2", {16, 32});
auto* tmp_3 = layers.sigmoid(x2);
tmp_3->SetShape({16, 32});
auto* x3 = layers.data("x3", {16, 32});
auto* tmp_4 = layers.tanh(x3);
tmp_4->SetShape({16, 32});
auto* tmp_5 = layers.elementwise_mul(tmp_3, tmp_4);
tmp_5->SetShape({16, 32});
auto* tmp_6 = layers.elementwise_add(tmp_2, tmp_5);
tmp_6->SetShape({16, 32});
auto* x4 = layers.data("x4", {16, 32});
auto* tmp_7 = layers.tanh(x4);
tmp_7->SetShape({16, 32});
auto* x5 = layers.data("x5", {16, 32});
auto* tmp_8 = layers.sigmoid(x5);
tmp_8->SetShape({16, 32});
auto* tmp_9 = layers.elementwise_mul(tmp_7, tmp_8);
tmp_9->SetShape({16, 32});
layers.mul(tmp_6, tmp_9);
std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_tree.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
LOG(INFO) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release()));
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
LOG(INFO) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_tree.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fusion_group_pass);
USE_PASS(graph_viz_pass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
struct SubGraph {
int type{-1};
std::string func_name;
std::unordered_set<Node*> nodes_set;
bool IsEmpty() { return nodes_set.empty(); }
size_t GetNumNodes() { return nodes_set.size(); }
int GetNumOperations() {
int num_operations = 0;
for (auto* n : nodes_set) {
if (n && n->IsOp() && n->Op()) {
num_operations++;
}
}
return num_operations;
}
std::vector<Node*> GetInputVarNodes() const {
// The order of input nodes should be consistent with that of the generated
// code.
std::vector<Node*> input_vars;
for (auto* n : nodes_set) {
if (n && n->IsVar() && n->Var()) {
bool is_found = true;
// When the inputs size is 0, it is also considered the input var of
// subgraph.
if (n->inputs.size() == 0U) {
is_found = false;
}
// Normally a var node has only one input op node.
for (auto* in : n->inputs) {
if (nodes_set.find(in) == nodes_set.end()) {
is_found = false;
}
}
if (!is_found) {
input_vars.push_back(n);
}
}
}
return input_vars;
}
std::vector<Node*> GetOutputVarNodes() const {
// The order of output nodes should be consistant with that of the generated
// code.
std::vector<Node*> output_vars;
for (auto* n : nodes_set) {
if (n && n->IsVar() && n->Var()) {
bool is_found = true;
if (n->outputs.size() == 0U) {
is_found = false;
}
for (auto* out : n->outputs) {
if (nodes_set.find(out) == nodes_set.end()) {
is_found = false;
}
}
if (!is_found) {
output_vars.push_back(n);
}
}
}
return output_vars;
}
};
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
...@@ -92,6 +93,14 @@ struct Layers { ...@@ -92,6 +93,14 @@ struct Layers {
return unary_op("relu", x, out); return unary_op("relu", x, out);
} }
VarDesc* sigmoid(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("sigmoid", x, out);
}
VarDesc* tanh(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("tanh", x, out);
}
VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* fc(VarDesc* input, VarDesc* w, VarDesc* bias,
int in_num_col_dims = 1, std::string activation_type = "") { int in_num_col_dims = 1, std::string activation_type = "") {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
...@@ -119,6 +128,10 @@ struct Layers { ...@@ -119,6 +128,10 @@ struct Layers {
return binary_op("elementwise_add", x, y, out); return binary_op("elementwise_add", x, y, out);
} }
VarDesc* elementwise_mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
return binary_op("elementwise_mul", x, y, out);
}
VarDesc* dropout(VarDesc* x, float dropout_prob, VarDesc* dropout(VarDesc* x, float dropout_prob,
std::string dropout_implementation) { std::string dropout_implementation) {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
...@@ -399,10 +412,9 @@ static std::string DebugString(Node* node) { ...@@ -399,10 +412,9 @@ static std::string DebugString(Node* node) {
return os.str(); return os.str();
} }
static std::string DebugString(const std::unique_ptr<Graph>& graph) { static std::string DebugString(const std::unordered_set<Node*>& nodes) {
std::ostringstream os; std::ostringstream os;
os << "Graph: {\n"; for (auto* node : nodes) {
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()) { if (node->IsOp() && node->Op()) {
os << " "; os << " ";
} else if (node->IsVar() && node->Var()) { } else if (node->IsVar() && node->Var()) {
...@@ -410,7 +422,12 @@ static std::string DebugString(const std::unique_ptr<Graph>& graph) { ...@@ -410,7 +422,12 @@ static std::string DebugString(const std::unique_ptr<Graph>& graph) {
} }
os << DebugString(node) << "\n"; os << DebugString(node) << "\n";
} }
os << "}\n"; return os.str();
}
static std::string DebugString(const std::unique_ptr<Graph>& graph) {
std::ostringstream os;
os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n";
return os.str(); return os.str();
} }
......
...@@ -59,12 +59,12 @@ TEST(SimplifyWithBasicOpsPass, dropout) { ...@@ -59,12 +59,12 @@ TEST(SimplifyWithBasicOpsPass, dropout) {
int num_scale_nodes_after = GetNumOpNodes(graph, "scale"); int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0UL); PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0);
if (dropout_implementation == "downgrade_in_infer") { if (dropout_implementation == "downgrade_in_infer") {
PADDLE_ENFORCE_EQ(num_dropout_nodes_before, PADDLE_ENFORCE_EQ(num_dropout_nodes_before,
num_scale_nodes_after - num_scale_nodes_before); num_scale_nodes_after - num_scale_nodes_before);
} else { } else {
PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0UL); PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册