提交 efd76614 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: implementation changed to conform with Paddle API

上级 347bf904
...@@ -22,6 +22,7 @@ namespace framework { ...@@ -22,6 +22,7 @@ namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace patterns {
/*
struct Pattern : public PatternBase { struct Pattern : public PatternBase {
Pattern(PDPattern* pattern, const std::string& name_scope) Pattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase{pattern, name_scope, ""} {} : PatternBase{pattern, name_scope, ""} {}
...@@ -45,7 +46,8 @@ struct Pattern : public PatternBase { ...@@ -45,7 +46,8 @@ struct Pattern : public PatternBase {
return node_pattern()->NewNode(node_name(op_name)); return node_pattern()->NewNode(node_name(op_name));
} }
}; };
*/
/*
struct Conv { struct Conv {
std::string op_name() const { return "conv2d"; } std::string op_name() const { return "conv2d"; }
std::string input_name() const { return "Input"; } std::string input_name() const { return "Input"; }
...@@ -105,7 +107,8 @@ struct ElementwiseAdd { ...@@ -105,7 +107,8 @@ struct ElementwiseAdd {
}; };
} }
}; };
*/
/*
Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph, Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
std::shared_ptr<patterns::Pattern> pattern, std::shared_ptr<patterns::Pattern> pattern,
const std::string& op_name) { const std::string& op_name) {
...@@ -116,6 +119,7 @@ Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph, ...@@ -116,6 +119,7 @@ Node* GetNodeFromSubgraph(const GraphPatternDetector::subgraph_t& subgraph,
return var; return var;
} }
*/
void LinkNodes(Node* from, Node* to) { void LinkNodes(Node* from, Node* to) {
from->outputs.push_back(to); from->outputs.push_back(to);
...@@ -172,64 +176,50 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -172,64 +176,50 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
auto pattern_ptr = std::make_shared<patterns::Pattern>(pattern, name_scope_);
patterns::Conv conv_pattern; patterns::Conv conv_pattern{pattern, "skip_connections_fusion"};
auto conv_output = conv_pattern(pattern_ptr)(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern; patterns::ElementwiseAdd elementwise_add_pattern{pattern,
elementwise_add_pattern(pattern_ptr)(conv_output); "skip_connections_fusion"};
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input, Node* conv_bias, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Node* conv_filter, Node* conv_output, Graph* g) {
Node* elementwise_add_x) { GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType(conv_pattern.op_name()); op_desc.SetType("conv2d");
op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()}); op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput(conv_pattern.bias_name(), {conv_bias->Name()}); op_desc.SetInput("Bias", {conv_bias->Name()});
op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()}); op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput(conv_pattern.residual_data_name(), op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
{elementwise_add_x->Name()}); op_desc.SetOutput("Output", {conv_output->Name()});
op_desc.SetOutput(conv_pattern.output_name(), {conv_output->Name()});
op_desc.SetAttr("use_mkldnn", true); op_desc.SetAttr("use_mkldnn", true);
op_desc.SetAttr("fuse_eltwise", true); op_desc.SetAttr("fuse_eltwise", true);
auto fused_conv_op = g->CreateOpNode(&op_desc); auto fused_conv_op = g->CreateOpNode(&op_desc);
patterns::LinkNodes(conv_input, fused_conv_op); IR_NODE_LINK_TO(conv_input, fused_conv_op);
patterns::LinkNodes(conv_bias, fused_conv_op); IR_NODE_LINK_TO(conv_bias, fused_conv_op);
patterns::LinkNodes(conv_filter, fused_conv_op); IR_NODE_LINK_TO(conv_filter, fused_conv_op);
patterns::LinkNodes(elementwise_add_x, fused_conv_op); IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
patterns::LinkNodes(fused_conv_op, conv_output); IR_NODE_LINK_TO(fused_conv_op, conv_output);
};
auto handler = [&conv_pattern, &elementwise_add_pattern, pattern_ptr,
fuse_conv](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto conv_op = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.op_name());
auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.input_name());
auto conv_bias = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
conv_pattern.bias_name());
auto conv_filter = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, conv_pattern.filter_name());
auto conv_output = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, conv_pattern.output_name());
auto elementwise_add_op = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.op_name());
auto elementwise_add_x = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.x_name());
auto elementwise_add_out = patterns::GetNodeFromSubgraph(
subgraph, pattern_ptr, elementwise_add_pattern.out_name());
fuse_conv(g, conv_input, conv_bias, conv_filter, conv_output,
elementwise_add_x);
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output); patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
}; };
......
...@@ -999,6 +999,45 @@ PDNode *patterns::ConvBias::operator()( ...@@ -999,6 +999,45 @@ PDNode *patterns::ConvBias::operator()(
return eltwise_out_var; return eltwise_out_var;
} }
PDNode *patterns::Conv::operator()() {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
auto input_var = pattern->NewNode(conv_input_repr())
->assert_is_op_input("conv2d", "Input");
auto bias_var =
pattern->NewNode(conv_bias_repr())->assert_is_op_input("conv2d", "Bias");
auto filter_var = pattern->NewNode(conv_filter_repr())
->assert_is_op_input("conv2d", "Filter");
auto output_var = pattern->NewNode(conv_output_repr())
->assert_is_op_output("conv2d", "Output");
conv_op->LinksFrom({input_var, bias_var, filter_var});
conv_op->LinksTo({output_var});
return output_var;
}
PDNode *patterns::ElementwiseAdd::operator()(PDNode *conv_output) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add");
auto x_var = pattern->NewNode(elementwise_add_x_repr())
->assert_is_op_input("elementwise_add", "X");
conv_output->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput()
->assert_is_op_output("elementwise_add", "Out");
elementwise_add_op->LinksFrom({x_var, conv_output});
elementwise_add_op->LinksTo({out_var});
return out_var;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -599,6 +599,32 @@ struct ConvBias : public PatternBase { ...@@ -599,6 +599,32 @@ struct ConvBias : public PatternBase {
PATTERN_DECL_NODE(eltwise_bias); PATTERN_DECL_NODE(eltwise_bias);
PATTERN_DECL_NODE(eltwise_out); PATTERN_DECL_NODE(eltwise_out);
}; };
struct Conv : public PatternBase {
Conv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "convolution") {}
PDNode* operator()();
PATTERN_DECL_NODE(conv_op);
PATTERN_DECL_NODE(conv_input);
PATTERN_DECL_NODE(conv_bias);
PATTERN_DECL_NODE(conv_filter);
PATTERN_DECL_NODE(conv_residual_data);
PATTERN_DECL_NODE(conv_output);
};
struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* conv_output);
PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x);
PATTERN_DECL_NODE(elementwise_add_y);
PATTERN_DECL_NODE(elementwise_add_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册