From 86fd3b32bea089c519249a459414a15349ec57b0 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Wed, 7 Nov 2018 16:36:06 +0100 Subject: [PATCH] MKLDNN residual connections fuse pass: counting statistics added to the pass --- .../conv_elementwise_add_mkldnn_fuse_pass.h | 49 +++++++++++++++++-- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h index b614b5c523..de4d1075e2 100644 --- a/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass.h @@ -21,11 +21,45 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include + namespace paddle { namespace framework { namespace ir { +// poor replacement for C++17 std::optional and Boost.Optional +struct InPlace {}; +InPlace in_place; + +template +class Maybe { + private: + typename std::aligned_storage::type data; + bool is_initialized{false}; + + public: + template + explicit Maybe(InPlace, Args&&... args) { + new (&data) T(std::forward(args)...); + is_initialized = true; + } + + Maybe() {} + + operator bool() { return is_initialized; } + + T& value() { return *reinterpret_cast(&data); } + + ~Maybe() { reinterpret_cast(&data)->~T(); } +}; + +template +Maybe MakeMaybe(Args&&... args) { + return Maybe(in_place, std::forward(args)...); +} + using graph_ptr = std::unique_ptr; +using GraphWithStats = std::pair>; void CorrectGraphEdges(Graph* graph, Node* from, Node* to); bool IsReachable(ir::Graph* graph, Node* from, Node* to); @@ -33,8 +67,10 @@ std::pair HasBias(const Node& op, const std::string& bias_name); class ResidualConnectionMKLDNNFusePass : public FusePassBase { private: - graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const; - graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const; + GraphWithStats FuseConvAsX(const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; + GraphWithStats FuseConvAsY(const std::string& name_scope, + const GraphWithStats& graph_with_stats) const; template using GetNodeFunc = @@ -48,12 +84,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { const ElementwiseAddFunc& get_node_from_elementwise_add_op, const CanFuseFunc& can_fuse_func); + void operator()(const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph); + int get_stats() const { return *fusion_stats; } + + private: + std::shared_ptr fusion_stats; ConvFunc get_node_from_conv_op; ElementwiseAddFunc get_node_from_elementwise_add_op; CanFuseFunc can_fuse_func; - - void operator()(const GraphPatternDetector::subgraph_t& subgraph, - Graph* graph); }; public: -- GitLab