提交 42240893 编写于 作者: T Tomasz Patejko

MKLDNN residual connections fuse pass: remove Maybe and use boost::optional where it makes sense

上级 86fd3b32
...@@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) { ...@@ -99,7 +99,7 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
return false; return false;
} }
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) { boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) {
auto bias_input_names = op.Op()->Inputs(); auto bias_input_names = op.Op()->Inputs();
auto bias_it = bias_input_names.find(bias_name); auto bias_it = bias_input_names.find(bias_name);
...@@ -113,11 +113,11 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) { ...@@ -113,11 +113,11 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name) {
[&bias_names](Node* n) -> bool { [&bias_names](Node* n) -> bool {
return n->Name() == bias_names[0]; return n->Name() == bias_names[0];
}); });
return std::make_pair(has_bias, *bias_names_it); return *bias_names_it;
} }
} }
return std::make_pair(false, nullptr); return boost::none;
} }
ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
...@@ -125,7 +125,8 @@ ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler( ...@@ -125,7 +125,8 @@ ResidualConnectionMKLDNNFusePass::FuseHandler::FuseHandler(
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc& const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
get_node_from_elementwise_add_op, get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func) const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func)
: get_node_from_conv_op{get_node_from_conv_op}, : fusion_stats{std::make_shared<int>(0)},
get_node_from_conv_op{get_node_from_conv_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
can_fuse_func{can_fuse_func} {} can_fuse_func{can_fuse_func} {}
...@@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( ...@@ -157,13 +158,10 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()}); op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()}); op_desc.SetOutput("Output", {conv_output->Name()});
bool has_bias; auto conv_bias = HasBias(*conv_op, "Bias");
Node* conv_bias;
std::tie(has_bias, conv_bias) = HasBias(*conv_op, "Bias"); if (conv_bias) {
op_desc.SetInput("Bias", {(*conv_bias)->Name()});
if (has_bias) {
op_desc.SetInput("Bias", {conv_bias->Name()});
} }
for (const auto& attr : conv_op->Op()->GetAttrMap()) { for (const auto& attr : conv_op->Op()->GetAttrMap()) {
...@@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()( ...@@ -179,40 +177,48 @@ void ResidualConnectionMKLDNNFusePass::FuseHandler::operator()(
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op); IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output); IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (has_bias) { if (conv_bias) {
IR_NODE_LINK_TO(conv_bias, fused_conv_op); IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
} }
CorrectGraphEdges(graph, elementwise_add_out, conv_output); CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph, GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op}); {elementwise_add_out, conv_op, elementwise_add_op});
(*fusion_stats)++;
}
std::tuple<Node*, Node*, Node*, Node*>
ResidualConnectionMKLDNNFusePass::GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
} }
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
const std::string& name_scope_, graph_ptr graph) const { const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
ir::Graph* graph;
int stats;
std::tie(graph, stats) = graph_with_stats;
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_add_pattern(
conv_output, conv_output,
pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr())); pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto get_node_from_conv =
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
};
auto get_node_from_elementwise_add = [&elementwise_add_pattern]( auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph) const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> { -> std::tuple<Node*, Node*, Node*> {
...@@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX( ...@@ -227,43 +233,29 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsX(
elementwise_add_out); elementwise_add_out);
}; };
auto can_fuse = [this](Node* op1, Node* op2) -> bool { return ExecuteHandlerOnGraph(
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; &gpd, graph_with_stats,
}; [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
auto fuse_handler = },
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; get_node_from_elementwise_add);
gpd(graph.get(), fuse_handler);
return graph;
} }
graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
const std::string& name_scope_, graph_ptr graph) const { const std::string& name_scope,
const GraphWithStats& graph_with_stats) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::Conv conv_pattern{pattern, name_scope_}; patterns::Conv conv_pattern{pattern, name_scope};
auto conv_output = conv_pattern(); auto conv_output = conv_pattern();
patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_}; patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
elementwise_add_pattern( elementwise_add_pattern(
pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()), pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
conv_output); conv_output);
conv_output->AsIntermediate(); conv_output->AsIntermediate();
auto get_node_from_conv =
[&conv_pattern](const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*, Node*> {
GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
return std::make_tuple(conv_op, conv_input, conv_filter, conv_output);
};
auto get_node_from_elementwise_add = [&elementwise_add_pattern]( auto get_node_from_elementwise_add = [&elementwise_add_pattern](
const GraphPatternDetector::subgraph_t& subgraph) const GraphPatternDetector::subgraph_t& subgraph)
-> std::tuple<Node*, Node*, Node*> { -> std::tuple<Node*, Node*, Node*> {
...@@ -278,6 +270,24 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -278,6 +270,24 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
elementwise_add_out); elementwise_add_out);
}; };
return ExecuteHandlerOnGraph(
&gpd, graph_with_stats,
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
},
get_node_from_elementwise_add);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::ExecuteHandlerOnGraph(
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
const ResidualConnectionMKLDNNFusePass::ConvFunc& get_node_from_conv,
const ResidualConnectionMKLDNNFusePass::ElementwiseAddFunc&
get_node_from_elementwise_add) const {
ir::Graph* graph;
int stats;
std::tie(graph, stats) = graph_with_stats;
auto can_fuse = [this](Node* op1, Node* op2) -> bool { auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
}; };
...@@ -285,15 +295,20 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY( ...@@ -285,15 +295,20 @@ graph_ptr ResidualConnectionMKLDNNFusePass::FuseConvAsY(
auto fuse_handler = auto fuse_handler =
FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse}; FuseHandler{get_node_from_conv, get_node_from_elementwise_add, can_fuse};
gpd(graph.get(), fuse_handler); (*gpd)(graph, fuse_handler);
return graph; return std::make_pair(graph, stats + fuse_handler.get_stats());
} }
graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { graph_ptr ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
FusePassBase::Init(name_scope_, graph.get()); FusePassBase::Init(name_scope_, graph.get());
return FuseConvAsY(name_scope_, FuseConvAsX(name_scope_, std::move(graph))); auto fused_graph_with_stats = FuseConvAsY(
name_scope_, FuseConvAsX(name_scope_, std::make_pair(graph.get(), 0)));
std::cout << "Fused graph " << fused_graph_with_stats.second << std::endl;
AddStatis(fused_graph_with_stats.second);
return graph;
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -27,43 +27,12 @@ namespace paddle { ...@@ -27,43 +27,12 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Minimal stand-in for C++17 std::optional / Boost.Optional: a value of type T
// that may or may not be present, stored in place (no heap allocation).

// Tag type selecting the in-place constructing constructor of Maybe.
struct InPlace {};
// `const` gives the tag object internal linkage, so including this header in
// several translation units does not produce multiple definitions.
const InPlace in_place{};

template <typename T>
class Maybe {
 private:
  // Raw, suitably aligned storage; holds a live T only while is_initialized.
  typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
  bool is_initialized{false};

  T* ptr() { return reinterpret_cast<T*>(&data); }
  const T* ptr() const { return reinterpret_cast<const T*>(&data); }

  // Constructs the contained T from args; the storage must currently be empty.
  template <typename... Args>
  void construct(Args&&... args) {
    new (&data) T(std::forward<Args>(args)...);
    is_initialized = true;
  }

  // Destroys the contained T, if there is one, leaving the object empty.
  void reset() {
    if (is_initialized) {
      // Clear the flag first so the object stays consistent even if ~T throws.
      is_initialized = false;
      ptr()->~T();
    }
  }

 public:
  // Constructs the contained value in place from args.
  template <typename... Args>
  explicit Maybe(InPlace, Args&&... args) {
    construct(std::forward<Args>(args)...);
  }

  // Constructs an empty Maybe.
  Maybe() {}

  // Copy/move go through T's own constructors instead of copying raw bytes,
  // so non-trivial T is neither shared bitwise nor destroyed twice.
  Maybe(const Maybe& other) {
    if (other.is_initialized) {
      construct(*other.ptr());
    }
  }

  Maybe(Maybe&& other) noexcept(std::is_nothrow_move_constructible<T>::value) {
    if (other.is_initialized) {
      construct(std::move(*other.ptr()));
    }
  }

  Maybe& operator=(const Maybe& other) {
    if (this != &other) {
      reset();
      if (other.is_initialized) {
        construct(*other.ptr());
      }
    }
    return *this;
  }

  Maybe& operator=(Maybe&& other) {
    if (this != &other) {
      reset();
      if (other.is_initialized) {
        construct(std::move(*other.ptr()));
      }
    }
    return *this;
  }

  // True when a value is present. Kept implicit so existing callers that rely
  // on the implicit conversion keep compiling.
  operator bool() const { return is_initialized; }

  // Returns the contained value. Precondition: a value is present; calling
  // this on an empty Maybe is undefined behaviour.
  T& value() { return *ptr(); }
  const T& value() const { return *ptr(); }

  // Only destroys the contained value when one was actually constructed.
  ~Maybe() { reset(); }
};

// Convenience factory: builds a Maybe<T> holding T(args...).
template <typename T, typename... Args>
Maybe<T> MakeMaybe(Args&&... args) {
  return Maybe<T>(in_place, std::forward<Args>(args)...);
}
using graph_ptr = std::unique_ptr<ir::Graph>; using graph_ptr = std::unique_ptr<ir::Graph>;
using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>; using GraphWithStats = std::pair<ir::Graph*, int>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to); void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to);
std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name); boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
class ResidualConnectionMKLDNNFusePass : public FusePassBase { class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private: private:
...@@ -79,6 +48,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ...@@ -79,6 +48,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>; using ElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*>>;
using CanFuseFunc = std::function<bool(Node*, Node*)>; using CanFuseFunc = std::function<bool(Node*, Node*)>;
std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
const patterns::Conv& conv_pattern,
const GraphPatternDetector::subgraph_t& subgraph) const;
GraphWithStats ExecuteHandlerOnGraph(
GraphPatternDetector* gpd, const GraphWithStats& graph_with_stats,
const ConvFunc& get_node_from_conv,
const ElementwiseAddFunc& get_node_from_elementwise_add) const;
struct FuseHandler { struct FuseHandler {
FuseHandler(const ConvFunc& get_node_from_conv_op, FuseHandler(const ConvFunc& get_node_from_conv_op,
const ElementwiseAddFunc& get_node_from_elementwise_add_op, const ElementwiseAddFunc& get_node_from_elementwise_add_op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册