提交 86fd3b32 编写于 作者: T Tomasz Patejko

MKLDNN residual connections fuse pass: counting statistics added to the pass

上级 ee6f778b
......@@ -21,11 +21,45 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <boost/optional.hpp>
namespace paddle {
namespace framework {
namespace ir {
// poor replacement for C++17 std::optional and Boost.Optional
struct InPlace {};
InPlace in_place;
template <typename T>
class Maybe {
private:
typename std::aligned_storage<sizeof(T), alignof(T)>::type data;
bool is_initialized{false};
public:
template <typename... Args>
explicit Maybe(InPlace, Args&&... args) {
new (&data) T(std::forward<Args>(args)...);
is_initialized = true;
}
Maybe() {}
operator bool() { return is_initialized; }
T& value() { return *reinterpret_cast<T*>(&data); }
~Maybe() { reinterpret_cast<T*>(&data)->~T(); }
};
template <typename T, typename... Args>
Maybe<T> MakeMaybe(Args&&... args) {
return Maybe<T>(in_place, std::forward<Args>(args)...);
}
using graph_ptr = std::unique_ptr<ir::Graph>;
using GraphWithStats = std::pair<ir::Graph*, Maybe<int>>;
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
......@@ -33,8 +67,10 @@ std::pair<bool, Node*> HasBias(const Node& op, const std::string& bias_name);
class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
graph_ptr FuseConvAsX(const std::string& name_scope_, graph_ptr graph) const;
graph_ptr FuseConvAsY(const std::string& name_scope_, graph_ptr graph) const;
GraphWithStats FuseConvAsX(const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
GraphWithStats FuseConvAsY(const std::string& name_scope,
const GraphWithStats& graph_with_stats) const;
template <typename RetType>
using GetNodeFunc =
......@@ -48,12 +84,15 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const ElementwiseAddFunc& get_node_from_elementwise_add_op,
const CanFuseFunc& can_fuse_func);
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
int get_stats() const { return *fusion_stats; }
private:
std::shared_ptr<int> fusion_stats;
ConvFunc get_node_from_conv_op;
ElementwiseAddFunc get_node_from_elementwise_add_op;
CanFuseFunc can_fuse_func;
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
};
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册