提交 604bad08 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: implementation of patterns refactored,...

MKLDNN conv + elementwise_add fusion: implementation of patterns refactored, applied to graph. UTs added
上级 9ce343f8
......@@ -44,6 +44,9 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
# Register the MKL-DNN conv + elementwise_add fuse pass (inference-only pass).
if(WITH_MKLDNN)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......@@ -57,4 +60,5 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
# Unit tests for the MKL-DNN fuse passes; built only when MKL-DNN is enabled.
if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif ()
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Thin wrapper over PatternBase that centralizes construction and lookup of
// PDNodes under this pattern's uniquely-prefixed name scope.
struct Pattern : public PatternBase {
  Pattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase{pattern, name_scope, ""} {}

 private:
  std::string name_scope() const { return name_scope_; }
  std::string repr() const { return repr_; }
  size_t id() const { return id_; }
  PDPattern* node_pattern() const { return pattern; }

 public:
  // Returns the fully-qualified PDNode name for a short op/var name.
  std::string node_name(const std::string& op_name) const {
    return PDNodeName(name_scope(), repr(), id(), op_name);
  }

  // Looks up a previously registered PDNode by its short name.
  PDNode* retrieve_node(const std::string& op_name) const {
    return node_pattern()->RetrieveNode(node_name(op_name));
  }

  // Registers a new PDNode under the fully-qualified name.
  PDNode* new_node(const std::string& op_name) const {
    return node_pattern()->NewNode(node_name(op_name));
  }
};
// Subgraph pattern for a standalone convolution:
//   {Input, Filter} -> conv2d -> Output.
struct Conv {
  std::string conv_name() { return "conv2d"; }
  std::string input_name() { return "Input"; }
  std::string filter_name() { return "Filter"; }
  std::string output_name() { return "Output"; }

  // Returns a callable that registers the conv2d pattern in *pattern and
  // yields the conv output PDNode, so it can be chained with other patterns.
  std::function<PDNode* ()> operator()(std::shared_ptr<Pattern> pattern) {
    // NOTE: capture by value. The original [&] captured the `pattern`
    // parameter by reference, which dangles as soon as operator() returns
    // and the closure is invoked later.
    return [this, pattern]() -> PDNode* {
      auto conv_op = pattern->new_node(conv_name())
                         ->assert_is_op("conv2d");

      auto input_var = pattern->new_node(input_name())
                           ->AsInput()
                           ->assert_is_op_input(conv_name());

      auto filter_var = pattern->new_node(filter_name())
                            ->AsInput()
                            ->assert_is_persistable_var()
                            ->assert_is_op_input(conv_name());

      auto output_var = pattern->new_node(output_name())
                            ->AsOutput()
                            ->assert_is_op_output(conv_name());

      conv_op->LinksFrom({input_var, filter_var});
      conv_op->LinksTo({output_var});

      return output_var;
    };
  }
};
// Subgraph pattern for elementwise_add consuming a convolution result:
//   {X = conv output, Y} -> elementwise_add -> Out.
struct ElementwiseAdd {
  std::string elementwise_add_name() { return "elementwise_add"; }
  std::string x_name() { return "X"; }
  std::string y_name() { return "Y"; }
  std::string out_name() { return "Out"; }

  // Returns a callable that extends the pattern with elementwise_add, wiring
  // the given conv output PDNode as its X operand, and yields the Out node.
  std::function<PDNode* (PDNode*)> operator()(std::shared_ptr<Pattern> pattern) {
    // NOTE: capture by value — a by-reference capture of the `pattern`
    // parameter would dangle once operator() returns.
    return [this, pattern](PDNode* conv_output) -> PDNode* {
      auto elementwise_add_op = pattern->new_node(elementwise_add_name())
                                    ->assert_is_op("elementwise_add");

      auto y_var = pattern->new_node(y_name())
                       ->AsInput()
                       ->assert_is_op_input(elementwise_add_name());

      // Fixed: the assertions take the raw op type and argument name
      // ("elementwise_add"/"X"), not the decorated PDNode names — matching
      // how the Conv pattern above calls assert_is_op_input.
      conv_output->assert_is_op_input(elementwise_add_name(), x_name());

      auto out_var = pattern->new_node(out_name())
                         ->AsOutput()
                         ->assert_is_op_output(elementwise_add_name());

      elementwise_add_op->LinksFrom({y_var, conv_output});
      elementwise_add_op->LinksTo({out_var});

      return out_var;
    };
  }
};
} // namespace patterns
// Retrieves the concrete graph Node that was matched for the PDNode
// registered under op_name in the given pattern. Fails loudly when the
// pattern node was not part of the detected subgraph.
Node* node_from_subgraph(const GraphPatternDetector::subgraph_t& subgraph,
                         std::shared_ptr<patterns::Pattern> pattern,
                         const std::string& op_name) {
  PADDLE_ENFORCE(subgraph.count(pattern->retrieve_node(op_name)),
                 "Node not found for PDNode %s", pattern->node_name(op_name));

  Node* var = subgraph.at(pattern->retrieve_node(op_name));
  // Fixed: the message used %s without supplying an argument.
  PADDLE_ENFORCE(var, "node %s not exists in the sub-graph",
                 pattern->node_name(op_name));

  return var;
}
using graph_ptr = std::unique_ptr<ir::Graph>;

// Fuses conv2d + elementwise_add into a single conv2d carrying a "fuse_sum"
// attribute (MKL-DNN sum post-op): the convolution accumulates into the
// elementwise_add's Y operand, which becomes the fused op's output.
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init("conv_elementwise_add_mkldnn_fuse_pass", graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();
  auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);

  patterns::Conv conv_pattern;
  auto conv_output = conv_pattern(pattern_ptr)();
  conv_output->AsIntermediate();

  patterns::ElementwiseAdd elementwise_add_pattern;
  elementwise_add_pattern(pattern_ptr)(conv_output);

  // Wires a -> b in the mutable graph.
  auto link_nodes_to = [](Node* a, Node* b) {
    a->outputs.push_back(b);
    b->inputs.push_back(a);
  };

  // Creates the fused conv2d op reusing the original conv inputs and writing
  // into elementwise_add's Y operand.
  auto fuse_conv = [&](Graph* g, Node* conv_input, Node* conv_filter, Node* y) {
    OpDesc op_desc;
    op_desc.SetType("conv2d");

    op_desc.SetInput("Input", {conv_input->Name()});
    op_desc.SetInput("Filter", {conv_filter->Name()});
    // Fixed typo: the conv2d output slot is "Output", not "Ouput".
    op_desc.SetOutput("Output", {y->Name()});

    op_desc.SetAttr("fuse_sum", true);

    auto fused_conv_op = g->CreateOpNode(&op_desc);

    link_nodes_to(conv_input, fused_conv_op);
    link_nodes_to(conv_filter, fused_conv_op);
    link_nodes_to(fused_conv_op, y);
  };

  auto remove_unused_nodes = [](Graph* g,
                                const std::unordered_set<const Node*>& removed_nodes) {
    GraphSafeRemoveNodes(g, removed_nodes);
  };

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
    auto conv_op = node_from_subgraph(subgraph, pattern_ptr,
                                      conv_pattern.conv_name());
    auto conv_input = node_from_subgraph(subgraph, pattern_ptr,
                                         conv_pattern.input_name());
    auto conv_filter = node_from_subgraph(subgraph, pattern_ptr,
                                          conv_pattern.filter_name());
    auto conv_output = node_from_subgraph(subgraph, pattern_ptr,
                                          conv_pattern.output_name());

    // Fixed: the pattern never registers a PDNode named "X" — the conv
    // output plays that role — so retrieving x_name() here always tripped
    // the PADDLE_ENFORCE in node_from_subgraph. Only Y and Out exist.
    auto elementwise_add_op = node_from_subgraph(
        subgraph, pattern_ptr, elementwise_add_pattern.elementwise_add_name());
    auto elementwise_add_y = node_from_subgraph(
        subgraph, pattern_ptr, elementwise_add_pattern.y_name());
    auto elementwise_add_out = node_from_subgraph(
        subgraph, pattern_ptr, elementwise_add_pattern.out_name());

    fuse_conv(g, conv_input, conv_filter, elementwise_add_y);

    // Fixed: also remove the matched op nodes, not just the dead variables;
    // otherwise the old conv2d/elementwise_add ops linger with dangling
    // in/out variable links.
    remove_unused_nodes(
        g, {conv_op, conv_output, elementwise_add_op, elementwise_add_out});
  };

  gpd(graph.get(), handler);

  return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass);
......@@ -9,14 +9,14 @@ namespace paddle {
namespace framework {
namespace ir {
class MKLDNNConvElementwiseAddFusePass : public FusePassBase {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase {
public:
virtual ~FCGRUFusePass() {}
virtual ~ConvElementwiseAddMKLDNNFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"mkldnn_conv_elementwise_add_fuse"};
const std::string name_scope_{"conv_elementwise_add_mkldnn_fuse_pass"};
};
} // namespace ir
......
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
// Appends an operator of the given type to the program's main block, wiring
// the provided input/output variable names into the op's expected slots.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);

  if (type == "conv2d") {
    op->SetAttr("use_mkldnn", true);
    op->SetInput("Input", {inputs[0]});
    op->SetInput("Filter", {inputs[1]});
    // Fixed: "Output" is an output slot and must be set with SetOutput, and
    // the vector is passed directly rather than wrapped in another list.
    op->SetOutput("Output", outputs);
  } else if (type == "elementwise_add") {
    op->SetInput("X", {inputs[0]});
    op->SetInput("Y", {inputs[1]});
    op->SetOutput("Out", outputs);
  }
}
// Builds a small program: two generic ops followed by a conv2d whose output
// feeds elementwise_add — the subgraph the pass is expected to fuse.
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;

  for (auto& v :
       std::vector<std::string>({"a", "b", "c", "d", "weights", "f", "g"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::LOD_TENSOR);
    if (v == "weights" || v == "bias") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "OP0", {"a"}, {"b"});
  SetOp(&prog, "OP1", {"c"}, {"d"});
  SetOp(&prog, "conv2d", {"d", "weights"}, {"f"});
  // Fixed op-type typo ("elemenwise_add") and operand order: the pattern
  // requires the conv output ("f") to be elementwise_add's X input.
  SetOp(&prog, "elementwise_add", {"f", "d"}, {"g"});

  return prog;
}
TEST(ConvElementwiseAddMKLDNNFusePass, basic) {
  auto prog = BuildProgramDesc();
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass =
      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");

  int original_nodes_num = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  int current_nodes_num = graph->Nodes().size();

  // The fusion is expected to shrink the graph by two nodes.
  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);

  // Assert exactly one fused conv2d (marked with "fuse_sum") remains in the
  // graph. The original always-true placeholder branches and the stale
  // "conv_relu" comment are removed.
  int conv_elementwise_add_count = 0;

  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "conv2d") {
      if (node->Op()->HasAttr("fuse_sum") &&
          boost::get<bool>(node->Op()->GetAttr("fuse_sum"))) {
        ++conv_elementwise_add_count;
      }
    }
  }

  EXPECT_EQ(conv_elementwise_add_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(conv_elementwise_add_mkldnn_fuse_pass);
#include "paddle/fluid/framework/ir/mkldnn_conv_elementwise_add_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PatternNode {
PatternNode(PDPattern* pattern,
const std::string& name,
const std::string& name_scope,
const std::string& repr,
size_t id)
: nodeName{PDNodeName(name_scope, repr, id, name)}
, node{pattern->RetrieveNode(nodeName)
{ }
std::string name() { return nodeName };
PDNode* node() { return node };
private:
std::string nodeName;
PDNode* node;
};
/*
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, "conv"}
, conv{pattern, "conv", name_scope_, repr_, id_}
, input{pattern, "Input", name_scope_, repr_, id_}
, filter{pattern, "Filter", name_scope_, repr_, id_}
, output{pattern, "Output", node_scope_, repr_ id_}
{ }
private:
PatternNode conv;
PatternNode input;
PatternNode filter;
PatternNode output;
public:
PDNode* operator()() {
auto conv_op = pattern->NewNode(conv.name())
->assert_is_op("conv2d");
auto input_var = pattern->NewNode(input.name())
->AsInput()
->assert_is_op_input(conv.name());
auto filter_var = pattern->NewNode(filter.name())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input(conv.name());
auto output_var = patterh->NewNode(output.name())
->AsOutput()
->assert_is_op_output(conv.name());
conv_op->LinksFrom({input_var, filter_var});
conv_op->LinksTo({output_var};
return output_var;
}
};
*/
// Pattern for a standalone conv2d: {Input, Filter} -> conv2d -> Output.
// Fixed: `name_scope` -> `name_scope_` in two accessors, `patterh` typo,
// and a missing closing parenthesis in LinksTo.
struct Conv : public PatternBase {
  Conv(PDPattern* pattern, const std::string& name_scope)
      : PatternBase{pattern, name_scope, "conv"} {}

  std::string conv_name() { return PDNodeName(name_scope_, repr_, id_, "conv2d"); }
  PDNode* conv_node() { return pattern->RetrieveNode(conv_name()); }

  std::string input_name() { return PDNodeName(name_scope_, repr_, id_, "Input"); }
  PDNode* input_node() { return pattern->RetrieveNode(input_name()); }

  std::string filter_name() { return PDNodeName(name_scope_, repr_, id_, "Filter"); }
  PDNode* filter_node() { return pattern->RetrieveNode(filter_name()); }

  std::string output_name() { return PDNodeName(name_scope_, repr_, id_, "Output"); }
  PDNode* output_node() { return pattern->RetrieveNode(output_name()); }

  // Registers the conv2d pattern and returns its output PDNode.
  PDNode* operator()() {
    auto conv_op = pattern->NewNode(conv_name())
                       ->assert_is_op("conv2d");

    auto input_var = pattern->NewNode(input_name())
                         ->AsInput()
                         ->assert_is_op_input(conv_name());

    auto filter_var = pattern->NewNode(filter_name())
                          ->AsInput()
                          ->assert_is_persistable_var()
                          ->assert_is_op_input(conv_name());

    auto output_var = pattern->NewNode(output_name())
                          ->AsOutput()
                          ->assert_is_op_output(conv_name());

    conv_op->LinksFrom({input_var, filter_var});
    conv_op->LinksTo({output_var});

    return output_var;
  }
};
struct ElementwiseAdd : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, "elementwise_add"}
{ }
std::string elementwise_add_name() { return PDNodeName(name_scope_, repr_, id_, "elementwise_add"); }
PDNode* elementwise_add_node() { return pattern->RetrieveNode(elementwise_add_name()); }
std::string x_name() { return PDNodeName(name_scope, repr_, id_, "X"); }
PDNode* x_node() { return pattern->RetrieveNode(x_name()); }
std::string y_name() { return PDNodeName(name_scope_, repr_, id_, "Y"); }
PDNode* y_node() { return pattern->RetrieveNode(y_name()); }
std::string out_name() { return PDNodeName(name_scope, repr_, id_, "Out"); }
PDNode* out_node() { return pattern->RetrieveNode(out_name()); }
PDNode* operator()(PDNode* conv_output) {
auto elementwise_add_op = pattern->NewNode(conv_name())
->assert_is_op("elementwise_add");
auto x_var = pattern->NewNode(x_name())
->AsInput()
->assert_is_op_input(elementwise_add_name());
conv_output->assert_is_op_input(elementwise_add_name(), y_name());
// auto y_var = pattern->NewNode(y_name())
// ->AsInput()
// ->assert_is_op_input(elementwise_add_name());
auto out_var = pattern->NewNode(out_name())
->AsOutput()
->assert_is_op_output(elementwise_add_name());
conv_op->LinksFrom({x_var, conv_output});
conv_op->LinksTo({out_var};
return out_var;
}
};
} // namespace patterns
using graph_ptr = std::unique_ptr<ir::Graph>;

// Builds and registers the conv2d -> elementwise_add detection pattern.
// Fixed: the parameter was unnamed yet referenced as `graph`, and the
// function had no return statement (undefined behavior for a function
// returning by value). The unused local for the pattern result is dropped.
graph_ptr MKLDNNConvElementwiseAddFusePass::ApplyImpl(graph_ptr graph) const {
  FusePassBase::Init("mkldnn_conv_elementwise_add_fuse", graph.get());

  GraphPatternDetector gpd;
  auto pattern = gpd.mutable_pattern();

  patterns::Conv conv_pattern(pattern, name_scope_);
  auto conv_output = conv_pattern();
  conv_output->AsIntermediate();

  patterns::ElementwiseAdd elementwise_add_pattern(pattern, name_scope_);
  elementwise_add_pattern(conv_output);

  return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册