提交 ee6f778b 编写于 作者: T Tomasz Patejko

MKLDNN residual connections fuse pass: further refactoring

上级 7423748e
......@@ -99,10 +99,9 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
return false;
}
std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
const Node& op) const {
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
auto bias_input_names = op.Op()->Inputs();
auto bias_it = bias_input_names.find("Bias");
auto bias_it = bias_input_names.find(bias_name);
if (bias_it != std::end(bias_input_names)) {
bool has_bias = !bias_it->second.empty();
......@@ -121,6 +120,74 @@ std::pair<bool, Node*> ResidualConnectionMKLDNNFusePass::HasBias(
return std::make_pair(false, nullptr);
}
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv_op,
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
: get_node_from_conv_op{get_node_from_conv_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
can_fuse_func{can_fuse_func} {}
void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_op;
Node* conv_input;
Node* conv_filter;
Node* conv_output;
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(subgraph);
std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) =
get_node_from_elementwise_add_op(subgraph);
if (!can_fuse_func(conv_op, elementwise_add_op)) return;
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias");
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
}
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope_, graph_ptr graph) const {
GraphPatternDetector gpd;
......@@ -135,8 +202,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_conv =
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
......@@ -146,8 +213,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
};
auto get_node_from_elementwise_add = [](
const patterns::ElementwiseAdd& elementwise_add_pattern,
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
......@@ -161,10 +227,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
elementwise_add_out);
};
auto handler =
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
get_node_from_conv, get_node_from_elementwise_add);
gpd(graph.get(), handler);
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
};
auto fuse_handler =
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
gpd(graph.get(), fuse_handler);
return graph;
}
......@@ -183,8 +253,8 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output);
conv_output->AsIntermediate();
auto get_node_from_conv = [](const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph)
auto get_node_from_conv =
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
......@@ -194,8 +264,7 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
};
auto get_node_from_elementwise_add = [](
const patterns::ElementwiseAdd& elementwise_add_pattern,
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
......@@ -209,10 +278,14 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
elementwise_add_out);
};
auto handler =
GenerateFuseHandler(conv_pattern, elementwise_add_pattern,
get_node_from_conv, get_node_from_elementwise_add);
gpd(graph.get(), handler);
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
};
auto fuse_handler =
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
gpd(graph.get(), fuse_handler);
return graph;
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <tuple>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
......@@ -28,24 +29,32 @@ using graph_ptr = std::unique_ptr<ir::Graph>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
using handler_func = std::function<void(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g)>;
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
std::pair<bool, Node*> HasBias(const Node& op) const;
template <typename RetType>
using GetNodeFunc =
std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
using ConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
using CanFuseFunc = std::function<bool(Node*, Node*)>;
struct FuseHandler {
FuseHandler(const ConvFunc& get_node_from_conv_op,
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
const CanFuseFunc& can_fuse_func);
ConvFunc get_node_from_conv_op;
ElementwiseAddFunc get_node_from_elementwise_add_op;
CanFuseFunc can_fuse_func;
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC = handler_func>
HANDLER_FUNC GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const;
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
};
public:
virtual ~ResidualConnectionMKLDNNFusePass() {}
......@@ -55,74 +64,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const std::string name_scope_{"residual_connection_fuse_pass"};
};
template <typename CONV_FUNC, typename ELEMENTWISE_ADD_FUNC,
typename HANDLER_FUNC>
HANDLER_FUNC ResidualConnectionMKLDNNFusePass::GenerateFuseHandler(
const patterns::Conv& conv_pattern,
const patterns::ElementwiseAdd& elementwise_add_pattern,
CONV_FUNC get_node_from_conv_op,
ELEMENTWISE_ADD_FUNC get_node_from_elementwise_add_op) const {
return [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_op;
Node* conv_input;
Node* conv_filter;
Node* conv_output;
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(conv_pattern, subgraph);
std::tie(elementwise_add_op, elementwise_add_identity,
elementwise_add_out) =
get_node_from_elementwise_add_op(elementwise_add_pattern, subgraph);
if (this->FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN)
return;
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc;
op_desc.SetType("conv2d");
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias;
Node* conv_bias;
std::tie(has_bias, conv_bias) = this->HasBias(*conv_op);
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
};
}
} // namespace ir
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册