提交 27573ece 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: trailing spaces removed

上级 7f5c8a95
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional>
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
......@@ -8,15 +24,14 @@ namespace patterns {
struct Pattern : public PatternBase {
Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, ""}
{ }
private:
: PatternBase{pattern, name_scope, ""} {}
private:
std::string name_scope() { return name_scope_; }
std::string repr() { return repr_; }
std::string repr() { return repr_; }
size_t id() { return id_; }
PDPattern* node_pattern() { return pattern; }
public:
std::string node_name(std::string op_name) {
return PDNodeName(name_scope(), repr(), id(), op_name);
......@@ -37,22 +52,18 @@ struct Conv {
std::string filter_name() { return "Filter"; }
std::string output_name() { return "Output"; }
std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
std::function<PDNode*()> operator()(std::shared_ptr<Pattern> pattern) {
return [&]() -> PDNode* {
auto conv_op = pattern->new_node(op_name())
->assert_is_op("conv2d");
auto conv_op = pattern->new_node(op_name())->assert_is_op("conv2d");
auto input_var = pattern->new_node(input_name())
->assert_is_op_input(op_name(),
input_name());
->assert_is_op_input(op_name(), input_name());
auto filter_var = pattern->new_node(filter_name())
->assert_is_op_input(op_name(),
filter_name());
->assert_is_op_input(op_name(), filter_name());
auto output_var = pattern->new_node(output_name())
->assert_is_op_output(op_name(),
output_name());
->assert_is_op_output(op_name(), output_name());
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var});
......@@ -68,22 +79,19 @@ struct ElementwiseAdd {
std::string y_name() { return "Y"; }
std::string out_name() { return "Out"; }
std::function<PDNode* (PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
std::function<PDNode*(PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
return [&](PDNode* conv_output) -> PDNode* {
auto elementwise_add_op = pattern->new_node(op_name())
->assert_is_op("elementwise_add");
auto elementwise_add_op =
pattern->new_node(op_name())->assert_is_op("elementwise_add");
auto x_var =
pattern->new_node(x_name())->assert_is_op_input(op_name(), x_name());
auto x_var = pattern->new_node(x_name())
->assert_is_op_input(op_name(),
x_name());
conv_output->assert_is_op_input(op_name(),
y_name());
conv_output->assert_is_op_input(op_name(), y_name());
auto out_var = pattern->new_node(out_name())
->AsOutput()
->assert_is_op_output(op_name(),
out_name());
->AsOutput()
->assert_is_op_output(op_name(), out_name());
elementwise_add_op->LinksFrom({x_var, conv_output});
elementwise_add_op->LinksTo({out_var});
......@@ -94,13 +102,13 @@ struct ElementwiseAdd {
};
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) {
std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) {
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
"Node not found for PDNode %s", pattern->node_name(op_name));
Node* var = subgraph.at(pattern->retrieve_node(op_name));
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph");
return var;
}
......@@ -109,10 +117,9 @@ void LinkNodes(Node* from, Node* to) {
to->inputs.push_back(from);
}
template<typename IT, typename FindFunc, typename ReplaceFunc>
template <typename IT, typename FindFunc, typename ReplaceFunc>
void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
if (s == e)
return;
if (s == e) return;
auto it = std::find_if(s, e, f);
......@@ -126,8 +133,7 @@ void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto same = std::find_if(std::begin(node.inputs),
std::end(node.inputs),
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
[from](Node* n) { return n == from; });
if (same != std::end(node.inputs)) {
......@@ -137,17 +143,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
using input_type = VariableNameMap::value_type;
ReplaceAllOccurances(std::begin(inputs), std::end(inputs),
[from](const input_type& i) -> bool {
auto params = i.second;
auto pi = std::find_if(std::begin(params), std::end(params),
std::bind(std::equal_to<std::string>(),
from->Name(), std::placeholders::_1));
return pi != std::end(params);
},
[to, &node](const input_type& i) {
node.Op()->SetInput(i.first, {to->Name()});
});
ReplaceAllOccurances(
std::begin(inputs), std::end(inputs),
[from](const input_type& i) -> bool {
auto params = i.second;
auto pi =
std::find_if(std::begin(params), std::end(params),
std::bind(std::equal_to<std::string>(),
from->Name(), std::placeholders::_1));
return pi != std::end(params);
},
[to, &node](const input_type& i) {
node.Op()->SetInput(i.first, {to->Name()});
});
}
}
}
......@@ -169,7 +177,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
conv_output->AsIntermediate();
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter, Node* conv_output, Node* elementwise_add_x) {
auto fuse_conv = [](Graph* g, Node* conv_input, Node* conv_filter,
Node* conv_output, Node* elementwise_add_x) {
OpDesc op_desc;
op_desc.SetType("conv2d");
......@@ -189,22 +198,23 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
patterns::LinkNodes(fused_conv_op, conv_output);
};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.op_name());
conv_pattern.op_name());
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.input_name());
auto conv_filter = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.filter_name());
auto conv_output = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.output_name());
auto elementwise_add_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
elementwise_add_pattern.op_name());
auto elementwise_add_x = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
elementwise_add_pattern.x_name());
auto elementwise_add_out = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
elementwise_add_pattern.out_name());
conv_pattern.input_name());
auto conv_filter = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, conv_pattern.output_name());
auto elementwise_add_op = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.op_name());
auto elementwise_add_x = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.x_name());
auto elementwise_add_out = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.out_name());
fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
......@@ -219,4 +229,5 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
......@@ -33,10 +47,11 @@ void SetOp(ProgramDesc* prog, const std::string& type,
}
struct IsReachable {
using func = std::function<bool (const std::string&, const std::string&)>;
using func = std::function<bool(const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph, const std::string& name) -> Node* {
auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) {
return &node;
......@@ -47,8 +62,7 @@ struct IsReachable {
};
return [&](std::string from, const std::string to) -> bool {
if (from == to)
return true;
if (from == to) return true;
std::map<std::string, bool> visited;
......@@ -61,16 +75,14 @@ struct IsReachable {
std::list<std::string> queue;
queue.push_back(from);
while(!queue.empty()) {
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr)
return false;
if (cur == nullptr) return false;
for (auto n : cur->outputs) {
if (n->Name() == to)
return true;
if (n->Name() == to) return true;
if (!visited[n->Name()]) {
visited[n->Name()] = true;
......@@ -87,14 +99,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights") {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"});
......@@ -109,14 +121,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_count = 0;
int elementwise_add_count = 0;
......@@ -136,15 +150,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights"})) {
for (auto& v : std::vector<std::string>({"a", "b", "weights"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights" || v == "bias") {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
......@@ -157,14 +170,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_count = 0;
int elementwise_add_count = 0;
......@@ -185,14 +200,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v.find("weights")) {
var->SetPersistable(true);
}
}
SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
......@@ -208,14 +223,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
auto pass =
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
current_nodes_num);
// Assert conv_relu op in newly generated graph
int conv_count = 0;
int elementwise_add_count = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册