提交 7423748e 编写于 作者: T Tomasz Patejko

MKLDNN residual connections fuse pass:

* implements reachability check between identity node and non-identity argument to elementwise_add
* implements handling identity node as x and as y argument to elementwise_add
上级 17226782
...@@ -14,14 +14,15 @@ ...@@ -14,14 +14,15 @@
#include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h"
#include <functional> #include <functional>
#include <utility> #include <list>
#include <map>
#include <tuple>
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
namespace {
// The function keeps the graph consistent by replacing // The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes // a node 'from' in the set of inputs nodes
...@@ -51,104 +52,179 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) { ...@@ -51,104 +52,179 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
} }
} }
} }
} // namespace
using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
FusePassBase::Init(name_scope_, graph.get()); auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
for (auto n : graph->Nodes()) {
if (n == node) {
return n;
}
}
GraphPatternDetector gpd; return nullptr;
auto pattern = gpd.mutable_pattern(); };
patterns::Conv conv_pattern{pattern, name_scope_}; if (from == to) {
auto conv_output = conv_pattern(); return true;
}
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; std::map<Node*, bool> visited;
elementwise_add_pattern(conv_output);
conv_output->AsIntermediate(); for (auto& node : GraphTraits::DFS(*graph)) {
visited[&node] = false;
}
visited[from] = true;
std::list<Node*> queue;
queue.push_back(from);
while (!queue.empty()) {
auto cur = find_node(graph, queue.front());
queue.pop_front();
if (!cur) return false;
for (auto n : cur->outputs) {
if (n == to) {
return true;
}
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> { if (!visited[n]) {
auto bias_input_names = conv_op.Op()->Inputs(); visited[n] = true;
queue.push_back(n);
}
}
}
return false;
}
std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
const Node& op) const {
auto bias_input_names = op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias"); auto bias_it = bias_input_names.find("Bias");
if (bias_it != std::end(bias_input_names)) { if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty(); bool has_bias = !bias_it->second.empty();
if (has_bias) { if (has_bias) {
auto conv_bias_names = bias_it->second; auto bias_names = bias_it->second;
auto conv_bias_names_it = auto bias_names_it =
std::find_if(std::begin(conv_op.inputs), std::end(conv_op.inputs), std::find_if(std::begin(op.inputs), std::end(op.inputs),
[&conv_bias_names](Node* n) -> bool { [&bias_names](Node* n) -> bool {
return n->Name() == conv_bias_names[0]; return n->Name() == bias_names[0];
}); });
return std::make_pair(has_bias, *conv_bias_names_it); return std::make_pair(has_bias, *bias_names_it);
} }
} }
return std::make_pair(false, nullptr); return std::make_pair(false, nullptr);
}; }
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope_, graph_ptr graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_};
auto conv_output = conv_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
Graph* g) { elementwise_add_pattern(
conv_output,
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
};
auto get_node_from_elementwise_add = [](
const patterns::ElementwiseAdd& elementwise_add_pattern,
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern); elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern); elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern); elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return; return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
OpDesc op_desc; };
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_x->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias; auto handler =
Node* conv_bias; GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
get_node_from_conv, get_node_from_elementwise_add);
gpd(graph.get(), handler);
std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op); return graph;
}
if (has_bias) { graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
op_desc.SetInput("Bias", {conv_bias->Name()}); const std::string& name_scope_, graph_ptr graph) const {
} GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
for (const auto& attr : conv_op->Op()->GetAttrMap()) { patterns::Conv conv_pattern{pattern, name_scope_};
op_desc.SetAttr(attr.first, attr.second); auto conv_output = conv_pattern();
}
op_desc.SetAttr("fuse_residual_connection", true); patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
elementwise_add_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
conv_output);
conv_output->AsIntermediate();
auto fused_conv_op = g->CreateOpNode(&op_desc); auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
IR_NODE_LINK_TO(conv_input, fused_conv_op); return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
IR_NODE_LINK_TO(conv_filter, fused_conv_op); };
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) { auto get_node_from_elementwise_add = [](
IR_NODE_LINK_TO(conv_bias, fused_conv_op); const patterns::ElementwiseAdd& elementwise_add_pattern,
} const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
CorrectGraphEdges(g, elementwise_add_out, conv_output); return std::make_tuple(elementwise_add_op, elementwise_add_x,
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); elementwise_add_out);
}; };
auto handler =
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
get_node_from_conv, get_node_from_elementwise_add);
gpd(graph.get(), handler); gpd(graph.get(), handler);
return graph; return graph;
} }
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get());
return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph)));
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass, REGISTER_PASS(conv_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::ConvElementwiseAddMKLDNNFusePass); paddle::framework::ir::ResidualConnectionMKLDNNFusePass);
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -23,16 +24,105 @@ namespace paddle { ...@@ -23,16 +24,105 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class ConvElementwiseAddMKLDNNFusePass : public FusePassBase { using graph_ptr = std::unique_ptr<ir::Graph>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
using handler_func = std::function<void(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g)>;
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
std::pair<bool, Node*> HasBias(const Node& op) const;
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC = handler_func>
HANDLER_FUNC GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const;
public: public:
virtual ~ConvElementwiseAddMKLDNNFusePass() {} virtual ~ResidualConnectionMKLDNNFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
const std::string name_scope_{"residual_connections_fuse_pass"}; const std::string name_scope_{"residual_connection_fuse_pass"};
}; };
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC>
HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const {
return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_op;
Node* conv_input;
Node* conv_filter;
Node* conv_output;
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(conv_pattern, subgraph);
std::tie(elementwise_add_op, elementwise_add_identity,
elementwise_add_out) =
get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph);
if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN)
return;
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = this->HasBias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
};
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1084,16 +1084,12 @@ PDNode *patterns::Conv::operator()() { ...@@ -1084,16 +1084,12 @@ PDNode *patterns::Conv::operator()() {
return output_var; return output_var;
} }
PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var) { PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr()) auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
->assert_is_op("elementwise_add"); ->assert_is_op("elementwise_add");
x_var->assert_is_op_input("elementwise_add", "X"); x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
y_var->AsInput()->assert_is_op_input("elementwise_add", "Y");
auto y_var = pattern->NewNode(elementwise_add_x_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto out_var = pattern->NewNode(elementwise_add_out_repr()) auto out_var = pattern->NewNode(elementwise_add_out_repr())
->AsOutput() ->AsOutput()
->assert_is_op_output("elementwise_add", "Out"); ->assert_is_op_output("elementwise_add", "Out");
......
...@@ -664,7 +664,7 @@ struct ElementwiseAdd : public PatternBase { ...@@ -664,7 +664,7 @@ struct ElementwiseAdd : public PatternBase {
ElementwiseAdd(PDPattern* pattern, const std::string& name_scope) ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add") {} : PatternBase(pattern, name_scope, "elementwise_add") {}
PDNode* operator()(PDNode* x_var); PDNode* operator()(PDNode* x_var, PDNode* y_var);
PATTERN_DECL_NODE(elementwise_add_op); PATTERN_DECL_NODE(elementwise_add_op);
PATTERN_DECL_NODE(elementwise_add_x); PATTERN_DECL_NODE(elementwise_add_x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册