未验证 提交 47459e98 编写于 作者: S Sylwester Fraczek 提交者: GitHub

refactor conv+elementwise_add (residual) (#40005)

上级 c0e29233
......@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
......@@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.End();
}
// Constructs the handler for the "conv + elementwise_add with an identity
// operand" pattern.
//
// can_fuse_func                     predicate called with (conv op, add op).
// get_node_from_conv_op             extracts the conv nodes from a match.
// get_node_from_elementwise_add_op  extracts the add nodes from a match.
// pass                              owning pass; used for the IsCompat check.
//
// The fusion counter starts at zero and is held through a shared_ptr, so
// every copy of this handle observes and updates the same count.
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
    const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
    const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
        get_node_from_conv_op,
    const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
        get_node_from_elementwise_add_op,
    const ResidualConnectionMKLDNNFusePass* pass)
    : fusion_stats(std::make_shared<int>(0)),
      can_fuse_func(can_fuse_func),
      get_node_from_conv_op(get_node_from_conv_op),
      get_node_from_elementwise_add_op(get_node_from_elementwise_add_op),
      pass_(pass) {}
// Invoked for each matched subgraph. When every precondition holds, the
// elementwise_add is folded into the convolution: the add's other operand
// becomes the conv's "ResidualData" input, the conv writes straight to the
// add's output variable, and the conv's former output plus the add op are
// removed from the graph.
void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
    const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
  Node* conv = nullptr;
  Node* conv_out = nullptr;
  // The conv input and filter nodes are part of the tuple but unused here.
  std::tie(conv, std::ignore, std::ignore, conv_out) =
      get_node_from_conv_op(subgraph);

  Node* add_op = nullptr;
  Node* identity = nullptr;
  Node* add_out = nullptr;
  std::tie(add_op, identity, add_out) =
      get_node_from_elementwise_add_op(subgraph);

  // Evaluated left to right with short-circuiting: fuse option, then
  // reachability between the identity operand and the conv output, then
  // "conv has no fused activation yet".
  const bool preconditions_met = can_fuse_func(conv, add_op) &&
                                 IsReachable(graph, identity, conv_out) &&
                                 !HasFusedActivation(conv);
  if (!preconditions_met) return;

  if (!pass_->IsCompat(subgraph, graph)) {
    LOG(WARNING)
        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
    return;
  }

  // Rewrite the conv op description.
  auto* conv_desc = conv->Op();
  conv_desc->SetInput("ResidualData", {identity->Name()});
  conv_desc->SetOutput("Output", {add_out->Name()});
  conv_desc->SetAttr("fuse_residual_connection", true);

  // Drop the now-bypassed nodes, then reconnect the graph edges.
  GraphSafeRemoveNodes(graph, {conv_out, add_op});
  IR_NODE_LINK_TO(identity, conv);
  IR_NODE_LINK_TO(conv, add_out);

  ++(*fusion_stats);
}
// Constructs the handler for the "two convolutions feeding one
// elementwise_add" (projection) pattern.
//
// can_fuse_func                     predicate called with (conv op, add op).
// get_node_from_conv_x_op           extracts the nodes of the conv on X.
// get_node_from_conv_y_op           extracts the nodes of the conv on Y.
// get_node_from_elementwise_add_op  extracts the add op and its output.
// pass                              owning pass; used for the IsCompat check.
//
// The fusion counter starts at zero and is held through a shared_ptr, so
// every copy of this handle observes and updates the same count.
ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
    const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
    const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
        get_node_from_conv_x_op,
    const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
        get_node_from_conv_y_op,
    const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
        get_node_from_elementwise_add_op,
    const ResidualConnectionMKLDNNFusePass* pass)
    : fusion_stats(std::make_shared<int>(0)),
      can_fuse_func(can_fuse_func),
      get_node_from_conv_x_op(get_node_from_conv_x_op),
      get_node_from_conv_y_op(get_node_from_conv_y_op),
      get_node_from_elementwise_add_op(get_node_from_elementwise_add_op),
      pass_(pass) {}
// Invoked for each matched subgraph in which both operands of an
// elementwise_add come from convolutions. One conv is chosen as the
// "residual" conv and absorbs the add; the other conv's output becomes its
// "ResidualData" input.
void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
    const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
  // The compatibility check runs first, before any node is extracted.
  if (!pass_->IsCompat(subgraph, graph)) {
    LOG(WARNING)
        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
    return;
  }

  // Filter nodes are part of the returned tuples but unused here.
  Node* x_op = nullptr;
  Node* x_in = nullptr;
  Node* x_out = nullptr;
  std::tie(x_op, x_in, std::ignore, x_out) = get_node_from_conv_x_op(subgraph);

  Node* y_op = nullptr;
  Node* y_in = nullptr;
  Node* y_out = nullptr;
  std::tie(y_op, y_in, std::ignore, y_out) = get_node_from_conv_y_op(subgraph);

  Node* add_op = nullptr;
  Node* add_out = nullptr;
  std::tie(add_op, add_out) = get_node_from_elementwise_add_op(subgraph);

  // Both convolutions must be fusable with the add (X is tested first).
  if (!(can_fuse_func(x_op, add_op) && can_fuse_func(y_op, add_op))) return;

  // Decide which conv absorbs the add. The X-input -> Y-output reachability
  // is tested first; the mirrored test runs only when the first one fails.
  const bool y_is_residual = IsReachable(graph, x_in, y_out);
  if (!y_is_residual && !IsReachable(graph, y_in, x_out)) return;

  Node* const projection = y_is_residual ? x_out : y_out;
  Node* const residual_op = y_is_residual ? y_op : x_op;
  Node* const residual_out = y_is_residual ? y_out : x_out;

  if (HasFusedActivation(residual_op)) return;

  // Rewrite the chosen conv's op description.
  auto* residual_desc = residual_op->Op();
  residual_desc->SetInput("ResidualData", {projection->Name()});
  residual_desc->SetOutput("Output", {add_out->Name()});
  residual_desc->SetAttr("fuse_residual_connection", true);

  // Drop the bypassed nodes, then reconnect the graph edges.
  GraphSafeRemoveNodes(graph, {residual_out, add_op});
  IR_NODE_LINK_TO(projection, residual_op);
  IR_NODE_LINK_TO(residual_op, add_out);

  ++(*fusion_stats);
}
// Collects the four graph nodes of one matched convolution pattern.
//
// conv_pattern  the conv pattern whose nodes should be looked up.
// subgraph      the match produced by the pattern detector.
//
// Returns the tuple (conv_op, conv_input, conv_filter, conv_output).
//
// NOTE(review): GET_IR_NODE_FROM_SUBGRAPH is defined outside this file; it
// presumably declares a local Node* named after its first argument and reads
// the `subgraph` parameter by name -- confirm before renaming anything here.
std::tuple<Node*, Node*, Node*, Node*>
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
ir::Graph* graph;
int stats;
std::tie(graph, stats) = graph_with_stats;
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
......@@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_y,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
},
get_node_from_elementwise_add, this);
int found_conv_as_x_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_identity, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
found_conv_as_x_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_x_count
<< " conv (as x) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_x_count + graph_with_stats.second);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
......@@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output);
conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_x,
elementwise_add_out);
};
return ExecuteHandleOnGraph<IdentityFuseHandle>(
&gpd, graph_with_stats,
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
},
get_node_from_elementwise_add, this);
int found_conv_as_y_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (!IsReachable(g, elementwise_add_x, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
IR_NODE_LINK_TO(elementwise_add_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
found_conv_as_y_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_y_count
<< " conv (as y) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_y_count + graph_with_stats.second);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
......@@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_out);
};
return ExecuteHandleOnGraph<ProjectionFuseHandle>(
&gpd, graph_with_stats,
[this,
&conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_x_pattern, subgraph);
},
[this,
&conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_y_pattern, subgraph);
},
get_node_from_elementwise_add, this);
int found_projection_conv_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(conv_x_op, conv_op, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_x_input, conv_input, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_x_filter, conv_filter, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_x_output, conv_output, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_op, conv_op, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_input, conv_input, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
elementwise_add_pattern);
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return;
Node* projection_node;
Node* residual_conv_op;
Node* residual_conv_output;
if (IsReachable(g, conv_x_input, conv_y_output)) {
projection_node = conv_x_output;
residual_conv_op = conv_y_op;
residual_conv_output = conv_y_output;
} else if (IsReachable(g, conv_y_input, conv_x_output)) {
projection_node = conv_y_output;
residual_conv_op = conv_x_op;
residual_conv_output = conv_x_output;
} else {
return;
}
if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op});
IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
found_projection_conv_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_projection_conv_count
<< " projection conv (as y) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_projection_conv_count + graph_with_stats.second);
}
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto fused_graph_with_stats = FuseConvAsY(
name_scope_,
FuseConvAsX(name_scope_,
FuseProjectionConv(name_scope_, std::make_pair(graph, 0))));
auto graph_with_stats =
FuseProjectionConv(name_scope_, std::make_pair(graph, 0));
graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats);
graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats);
LOG(INFO) << "Fused graph " << fused_graph_with_stats.second << "\n";
AddStatis(fused_graph_with_stats.second);
AddStatis(graph_with_stats.second);
}
} // namespace ir
} // namespace framework
......
......@@ -28,19 +28,9 @@ namespace paddle {
namespace framework {
namespace ir {
class Graph;
class GraphPatternDetector;
class Node;
namespace patterns {
struct Conv;
} // namespace patterns
using graph_ptr = ir::Graph*;
using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
paddle::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
......@@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
// Callable that receives a matched subgraph and returns RetType.
template <typename RetType>
using GetNodeFunc =
std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
// Extractor for one convolution; GetNodesFromConv yields the tuple
// (conv op, conv input, conv filter, conv output).
using IdentityConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
// Extractor for the add in the identity pattern; the handle unpacks it as
// (elementwise_add op, identity operand, elementwise_add output).
using IdentityElementwiseAddFunc =
GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
// The projection pattern uses the same four-node conv extractor.
using ProjectionConvFunc = IdentityConvFunc;
// Extractor for the add in the projection pattern; the handle unpacks it as
// (elementwise_add op, elementwise_add output).
using ProjectionElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*>>;
// Predicate over two op nodes; the handles call it with
// (conv op, elementwise_add op).
using CanFuseFunc = std::function<bool(Node*, Node*)>;
// Returns (conv op, conv input, conv filter, conv output) for the given conv
// pattern within a matched subgraph.
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const;
// NOTE(review): same signature as GetNodesFromConv; no definition of this
// function is visible in this view -- confirm whether it is still used.
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromProjectionConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const;
// Builds a HandleType from a fuse-option predicate plus the forwarded
// op_funcs, runs the pattern detector with that handle on the graph, and
// returns the same graph paired with the incoming fusion count increased by
// the number of fusions the handle reports.
//
// gpd               detector that already holds the pattern to match.
// graph_with_stats  (graph, fusions counted so far).
// op_funcs          remaining constructor arguments for HandleType.
template <typename HandleType, typename... OpFuncs>
GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd,
                                    const GraphWithStats& graph_with_stats,
                                    OpFuncs&&... op_funcs) const {
  ir::Graph* const target_graph = graph_with_stats.first;
  const int fused_so_far = graph_with_stats.second;

  // Two ops qualify only when the pass resolves their option to FUSE_MKLDNN.
  const auto mkldnn_fusable = [this](Node* lhs_op, Node* rhs_op) -> bool {
    return this->FindFuseOption(*lhs_op, *rhs_op) == FUSE_MKLDNN;
  };

  HandleType handle(mkldnn_fusable, std::forward<OpFuncs>(op_funcs)...);
  (*gpd)(target_graph, handle);

  return std::make_pair(target_graph, fused_so_far + handle.get_stats());
}
// Callable handler for the pattern where a convolution output is added to
// another (identity) tensor by an elementwise_add.
struct IdentityFuseHandle {
// Stores the predicate, the two node extractors and the owning pass.
IdentityFuseHandle(
const CanFuseFunc& can_fuse_func,
const IdentityConvFunc& get_node_from_conv_op,
const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
// Processes one matched subgraph; may rewrite `graph`.
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
// Number of fusions performed so far through this handle or its copies.
int get_stats() const { return *fusion_stats; }
private:
// Held through a shared_ptr, so copies of the handle share one counter.
std::shared_ptr<int> fusion_stats;
CanFuseFunc can_fuse_func;
IdentityConvFunc get_node_from_conv_op;
IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
// Non-owning pointer to the pass; used to call IsCompat.
const ResidualConnectionMKLDNNFusePass* pass_;
};
// Callable handler for the pattern where both operands of an elementwise_add
// are produced by convolutions.
struct ProjectionFuseHandle {
// Stores the predicate, the node extractors (conv on X, conv on Y, add) and
// the owning pass.
ProjectionFuseHandle(
const CanFuseFunc& can_fuse_func,
const ProjectionConvFunc& get_node_from_conv_x_op,
const ProjectionConvFunc& get_node_from_conv_y_op,
const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
// Processes one matched subgraph; may rewrite `graph`.
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
// Number of fusions performed so far through this handle or its copies.
int get_stats() const { return *fusion_stats; }
private:
// Held through a shared_ptr, so copies of the handle share one counter.
std::shared_ptr<int> fusion_stats;
CanFuseFunc can_fuse_func;
ProjectionConvFunc get_node_from_conv_x_op;
ProjectionConvFunc get_node_from_conv_y_op;
ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
// Non-owning pointer to the pass; used to call IsCompat.
const ResidualConnectionMKLDNNFusePass* pass_;
};
public:
ResidualConnectionMKLDNNFusePass();
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
void ApplyImpl(graph_ptr graph) const;
void ApplyImpl(ir::Graph* graph) const;
static bool HasFusedActivation(Node* conv_node) {
return !(conv_node->Op()
->GetAttrIfExists<std::string>("fuse_activation")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册