未验证 提交 47459e98 编写于 作者: S Sylwester Fraczek 提交者: GitHub

refactor conv+relementwise_add (residual) (#40005)

上级 c0e29233
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { ...@@ -135,157 +136,9 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
.End(); .End();
} }
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
get_node_from_conv_op,
const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass)
: fusion_stats{std::make_shared<int>(0)},
can_fuse_func{can_fuse_func},
get_node_from_conv_op{get_node_from_conv_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
pass_{pass} {}
void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_op;
Node* conv_input;
Node* conv_filter;
Node* conv_output;
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(subgraph);
std::tie(elementwise_add_op, elementwise_add_identity, elementwise_add_out) =
get_node_from_elementwise_add_op(subgraph);
if (!can_fuse_func(conv_op, elementwise_add_op)) return;
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
if (HasFusedActivation(conv_op)) return;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(graph, {conv_output, elementwise_add_op});
IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
(*fusion_stats)++;
}
ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
get_node_from_conv_x_op,
const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
get_node_from_conv_y_op,
const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass)
: fusion_stats{std::make_shared<int>(0)},
can_fuse_func{can_fuse_func},
get_node_from_conv_x_op{get_node_from_conv_x_op},
get_node_from_conv_y_op{get_node_from_conv_y_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
pass_{pass} {}
void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* conv_x_op;
Node* conv_x_input;
Node* conv_x_filter;
Node* conv_x_output;
Node* conv_y_op;
Node* conv_y_input;
Node* conv_y_filter;
Node* conv_y_output;
Node* elementwise_add_op;
Node* elementwise_add_out;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
get_node_from_conv_x_op(subgraph);
std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
get_node_from_conv_y_op(subgraph);
std::tie(elementwise_add_op, elementwise_add_out) =
get_node_from_elementwise_add_op(subgraph);
if (!can_fuse_func(conv_x_op, elementwise_add_op)) return;
if (!can_fuse_func(conv_y_op, elementwise_add_op)) return;
Node* projection_node;
Node* residual_conv_op;
Node* residual_conv_output;
if (IsReachable(graph, conv_x_input, conv_y_output)) {
projection_node = conv_x_output;
residual_conv_op = conv_y_op;
residual_conv_output = conv_y_output;
} else if (IsReachable(graph, conv_y_input, conv_x_output)) {
projection_node = conv_y_output;
residual_conv_op = conv_x_op;
residual_conv_output = conv_x_output;
} else {
return;
}
if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(graph, {residual_conv_output, elementwise_add_op});
IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
(*fusion_stats)++;
}
std::tuple<Node*, Node*, Node*, Node*>
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope, const std::string& name_scope,
const GraphWithStats& graph_with_stats) const { const GraphWithStats& graph_with_stats) const {
ir::Graph* graph;
int stats;
std::tie(graph, stats) = graph_with_stats;
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
...@@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -298,26 +151,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern]( int found_conv_as_x_count = 0;
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, Graph* g) {
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y, GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
return std::make_tuple(elementwise_add_op, elementwise_add_y, elementwise_add_pattern);
elementwise_add_out); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y,
}; elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
return ExecuteHandleOnGraph<IdentityFuseHandle>( elementwise_add_pattern);
&gpd, graph_with_stats,
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
return GetNodesFromConv(conv_pattern, subgraph);
}, if (!IsReachable(g, elementwise_add_identity, conv_output)) return;
get_node_from_elementwise_add, this);
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
found_conv_as_x_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_x_count
<< " conv (as x) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_x_count + graph_with_stats.second);
} }
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
...@@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -335,26 +218,56 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
conv_output); conv_output);
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern]( int found_conv_as_y_count = 0;
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, Graph* g) {
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x, GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
return std::make_tuple(elementwise_add_op, elementwise_add_x, elementwise_add_pattern);
elementwise_add_out); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
}; elementwise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
return ExecuteHandleOnGraph<IdentityFuseHandle>( elementwise_add_pattern);
&gpd, graph_with_stats,
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
return GetNodesFromConv(conv_pattern, subgraph);
}, if (!IsReachable(g, elementwise_add_x, conv_output)) return;
get_node_from_elementwise_add, this);
if (HasFusedActivation(conv_op)) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
IR_NODE_LINK_TO(elementwise_add_x, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
found_conv_as_y_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_conv_as_y_count
<< " conv (as y) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_conv_as_y_count + graph_with_stats.second);
} }
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
...@@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( ...@@ -374,39 +287,84 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
conv_x_output->AsIntermediate(); conv_x_output->AsIntermediate();
conv_y_output->AsIntermediate(); conv_y_output->AsIntermediate();
auto get_node_from_elementwise_add = [&elementwise_add_pattern]( int found_projection_conv_count = 0;
const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*> { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op, Graph* g) {
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_x_op, conv_op, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out, GET_IR_NODE_FROM_SUBGRAPH(conv_x_input, conv_input, conv_x_pattern);
elementwise_add_pattern); GET_IR_NODE_FROM_SUBGRAPH(conv_x_filter, conv_filter, conv_x_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_x_output, conv_output, conv_x_pattern);
return std::make_tuple(elementwise_add_op, elementwise_add_out);
}; GET_IR_NODE_FROM_SUBGRAPH(conv_y_op, conv_op, conv_y_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_y_input, conv_input, conv_y_pattern);
return ExecuteHandleOnGraph<ProjectionFuseHandle>( GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
&gpd, graph_with_stats, GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);
[this,
&conv_x_pattern](const GraphPatternDetector::subgraph_t& subgraph) { GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
return GetNodesFromConv(conv_x_pattern, subgraph); elementwise_add_pattern);
}, GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
[this, elementwise_add_pattern);
&conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_y_pattern, subgraph); if (!IsCompat(subgraph, g)) {
}, LOG(WARNING)
get_node_from_elementwise_add, this); << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return;
if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return;
Node* projection_node;
Node* residual_conv_op;
Node* residual_conv_output;
if (IsReachable(g, conv_x_input, conv_y_output)) {
projection_node = conv_x_output;
residual_conv_op = conv_y_op;
residual_conv_output = conv_y_output;
} else if (IsReachable(g, conv_y_input, conv_x_output)) {
projection_node = conv_y_output;
residual_conv_op = conv_x_op;
residual_conv_output = conv_x_output;
} else {
return;
}
if (HasFusedActivation(residual_conv_op)) return;
residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op});
IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
found_projection_conv_count++;
};
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_projection_conv_count
<< " projection conv (as y) + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
return std::make_pair(graph_with_stats.first,
found_projection_conv_count + graph_with_stats.second);
} }
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { void ResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
auto fused_graph_with_stats = FuseConvAsY( auto graph_with_stats =
name_scope_, FuseProjectionConv(name_scope_, std::make_pair(graph, 0));
FuseConvAsX(name_scope_, graph_with_stats = FuseConvAsX(name_scope_, graph_with_stats);
FuseProjectionConv(name_scope_, std::make_pair(graph, 0)))); graph_with_stats = FuseConvAsY(name_scope_, graph_with_stats);
LOG(INFO) << "Fused graph " << fused_graph_with_stats.second << "\n"; AddStatis(graph_with_stats.second);
AddStatis(fused_graph_with_stats.second);
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -28,19 +28,9 @@ namespace paddle { ...@@ -28,19 +28,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class Graph;
class GraphPatternDetector;
class Node;
namespace patterns {
struct Conv;
} // namespace patterns
using graph_ptr = ir::Graph*;
using GraphWithStats = std::pair<ir::Graph*, int>; using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to);
paddle::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
class ResidualConnectionMKLDNNFusePass : public FusePassBase { class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private: private:
...@@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ...@@ -52,91 +42,13 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const std::string& name_scope, const std::string& name_scope,
const GraphWithStats& graph_with_stats) const; const GraphWithStats& graph_with_stats) const;
template <typename RetType>
using GetNodeFunc =
std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
using IdentityConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
using IdentityElementwiseAddFunc =
GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
using ProjectionConvFunc = IdentityConvFunc;
using ProjectionElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*>>;
using CanFuseFunc = std::function<bool(Node*, Node*)>;
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const;
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromProjectionConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const;
template <typename HandleType, typename... OpFuncs>
GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd,
const GraphWithStats& graph_with_stats,
OpFuncs&&... op_funcs) const {
ir::Graph* graph;
int stats;
std::tie(graph, stats) = graph_with_stats;
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
};
auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};
(*gpd)(graph, fuse_handle);
return std::make_pair(graph, stats + fuse_handle.get_stats());
}
struct IdentityFuseHandle {
IdentityFuseHandle(
const CanFuseFunc& can_fuse_func,
const IdentityConvFunc& get_node_from_conv_op,
const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
int get_stats() const { return *fusion_stats; }
private:
std::shared_ptr<int> fusion_stats;
CanFuseFunc can_fuse_func;
IdentityConvFunc get_node_from_conv_op;
IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
const ResidualConnectionMKLDNNFusePass* pass_;
};
struct ProjectionFuseHandle {
ProjectionFuseHandle(
const CanFuseFunc& can_fuse_func,
const ProjectionConvFunc& get_node_from_conv_x_op,
const ProjectionConvFunc& get_node_from_conv_y_op,
const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
int get_stats() const { return *fusion_stats; }
private:
std::shared_ptr<int> fusion_stats;
CanFuseFunc can_fuse_func;
ProjectionConvFunc get_node_from_conv_x_op;
ProjectionConvFunc get_node_from_conv_y_op;
ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
const ResidualConnectionMKLDNNFusePass* pass_;
};
public: public:
ResidualConnectionMKLDNNFusePass(); ResidualConnectionMKLDNNFusePass();
virtual ~ResidualConnectionMKLDNNFusePass() {} virtual ~ResidualConnectionMKLDNNFusePass() {}
protected: protected:
void ApplyImpl(graph_ptr graph) const; void ApplyImpl(ir::Graph* graph) const;
static bool HasFusedActivation(Node* conv_node) { static bool HasFusedActivation(Node* conv_node) {
return !(conv_node->Op() return !(conv_node->Op()
->GetAttrIfExists<std::string>("fuse_activation") ->GetAttrIfExists<std::string>("fuse_activation")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册