Commit 38b7b34b authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: added reachability tests, inputs and outputs in graph nodes are transformed
Parent 16eaaf3f
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle {
namespace framework {
@@ -90,7 +91,7 @@ struct ElementwiseAdd {
};
} // namespace patterns
Node* node_from_subgraph(const GraphPatternDetector::subgraph_t& subgraph,
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern, const std::string& op_name)
{
PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
@@ -103,6 +104,20 @@ Node* node_from_subgraph(const GraphPatternDetector::subgraph_t& subgraph,
using graph_ptr = std::unique_ptr<ir::Graph>;
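// Attach 'to' as an extra input of every node that currently consumes 'from',
// so downstream edges stay valid once 'from' is removed from the graph.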
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
// Check whether this node already lists 'from' among its inputs.
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs),
                         [from](Node* n) { return n == from; });
if (same != std::end(node.inputs)) {
node.inputs.push_back(to);
to->outputs.push_back(&node);
}
}
}
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init("conv_elementwise_add_mkldnn_fuse_pass", graph.get());
@@ -145,16 +160,18 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
};
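// Handler invoked for each matched conv2d + elementwise_add subgraph: it builds
// the fused conv op, rewires consumers of the old elementwise_add output, and
// removes the nodes that became unused.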
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
auto conv_op = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.conv_name());
auto conv_input = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.input_name());
auto conv_filter = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = node_from_subgraph(subgraph, pattern_ptr, conv_pattern.output_name());
auto conv_op = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.conv_name());
auto conv_input = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.input_name());
auto conv_filter = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = GetNodeFromSubgraph(subgraph, pattern_ptr, conv_pattern.output_name());
auto elementwise_add_op = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
auto elementwise_add_y = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
auto elementwise_add_out = node_from_subgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
auto elementwise_add_op = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
auto elementwise_add_y = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.y_name());
auto elementwise_add_out = GetNodeFromSubgraph(subgraph, pattern_ptr, elementwise_add_pattern.out_name());
fuse_conv(g, conv_input, conv_filter, elementwise_add_y);
CorrectGraphEdges(g, elementwise_add_out, elementwise_add_y);
remove_unused_nodes(g, {conv_output, elementwise_add_out, conv_op, elementwise_add_op});
};
......
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include <string>
#include <gtest/gtest.h>
namespace paddle {
@@ -21,37 +23,96 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", outputs);
} else if (type == "relu" || type == "sigmoid") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", outputs);
}
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddWithOps) {
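// Helper functor for the tests below: builds a BFS-based query that tells
// whether one named node can reach another in the graph.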
struct IsReachable {
using func = std::function<bool (const std::string&, const std::string&)>;
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
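// Find a node by name with a DFS walk; returns nullptr when it is absent.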
auto find_node = [](const std::unique_ptr<ir::Graph>& graph, const std::string& name) -> Node* {
for (auto& node : GraphTraits::DFS(*graph)) {
if (name == node.Name()) {
return &node;
}
}
return nullptr;
};
return [&](std::string from, const std::string& to) -> bool {
if (from == to)
return true;
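// Standard BFS over node names, starting from 'from'.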
std::map<std::string, bool> visited;
for (auto& node : GraphTraits::DFS(*graph)) {
visited[node.Name()] = false;
}
visited[from] = true;
std::list<std::string> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (cur == nullptr)
return false;
for (auto n : cur->outputs) {
if (n->Name() == to)
return true;
if (!visited[n->Name()]) {
visited[n->Name()] = true;
queue.push_back(n->Name());
}
}
}
return false;
};
}
};
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f", "g"})) {
std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights" || v == "bias") {
if (v == "weights") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", {"a"}, {"b"});
SetOp(&prog, "OP1", {"c"}, {"d"});
SetOp(&prog, "conv2d", {"b", "weights"}, {"e"});
SetOp(&prog, "elementwise_add", {"e", "d"}, {"f"});
SetOp(&prog, "OP3", {"f"}, {"g"});
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
SetOp(&prog, "relu", {"d"}, {"e"});
return prog;
};
auto prog = build_program_desc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
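// Before the pass runs, "a" reaches the relu op through conv2d and elementwise_add.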
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
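// Four nodes (the conv2d and elementwise_add ops plus their output vars) are
// removed and one fused conv op is added.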
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
// Assert that conv2d + elementwise_add were fused into a single conv op in the new graph
int conv_count = 0;
@@ -64,26 +125,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddWithOps) {
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
/*
if (node->Op()->HasAttr("use_mkldnn")) {
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
if (use_mkldnn) {
if (node->Op()->HasAttr("fuse_sum")) {
// bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
if (fuse_sum) {
++conv_elementwise_add_count;
}
}
}
}
}
*/
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
}
TEST(ConvElementwiseAddMKLDNNFusePass, OnlyConvolutionElementwiseAdd) {
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
@@ -103,10 +150,16 @@ TEST(ConvElementwiseAddMKLDNNFusePass, OnlyConvolutionElementwiseAdd) {
auto prog = build_program_desc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "d"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_FALSE(is_reachable(graph)("a", "d"));
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
// Assert that conv2d + elementwise_add were fused into a single conv op in the new graph
@@ -120,20 +173,57 @@ TEST(ConvElementwiseAddMKLDNNFusePass, OnlyConvolutionElementwiseAdd) {
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
/*
if (node->Op()->HasAttr("use_mkldnn")) {
  bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
  if (use_mkldnn) {
    if (node->Op()->HasAttr("fuse_sum")) {
      // bool fuse_sum = boost::get<bool>(node->Op()->GetAttr("fuse_sum"));
      if (fuse_sum) {
        ++conv_elementwise_add_count;
      }
    }
  }
}
*/
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
}
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
auto build_program_desc = [&]() -> ProgramDesc {
ProgramDesc prog;
for (auto& v :
std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::LOD_TENSOR);
if (v == "weights") {
var->SetPersistable(true);
}
}
SetOp(&prog, "sigmoid", {"a"}, {"b"});
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
SetOp(&prog, "relu", {"e"}, {"f"});
return prog;
};
auto prog = build_program_desc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
IsReachable is_reachable;
EXPECT_TRUE(is_reachable(graph)("a", "f"));
auto pass = PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_TRUE(is_reachable(graph)("a", "f"));
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
// Assert that conv2d + elementwise_add were fused into a single conv op in the new graph
int conv_count = 0;
int elementwise_add_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "conv2d") {
++conv_count;
}
if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
++elementwise_add_count;
}
}
EXPECT_EQ(conv_count, 1);
EXPECT_EQ(elementwise_add_count, 0);
......