// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
#pragma once

#include <string>
18
#include <tuple>
19
#include <utility>
20 21 22 23
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

24 25
#include <boost/optional.hpp>

26 27 28 29
namespace paddle {
namespace framework {
namespace ir {

30
// Owning (std::unique_ptr) handle to an IR graph; used as the parameter type
// of ResidualConnectionMKLDNNFusePass::ApplyImpl.
using graph_ptr = std::unique_ptr<ir::Graph>;
31
// A raw (non-owning) graph pointer paired with an int counter.
// ExecuteHandleOnGraph returns the same graph pointer with the counter
// increased by the value reported by the fuse handle's get_stats().
using GraphWithStats = std::pair<ir::Graph*, int>;
32 33 34

// Declared here, defined elsewhere.
// NOTE(review): presumably re-targets edges in `graph` that involve `from` so
// that they refer to `to` instead -- confirm against the definition.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to);
// Declared here, defined elsewhere.
// NOTE(review): presumably returns true when `to` can be reached from `from`
// by following edges of `graph` -- confirm against the definition.
bool IsReachable(ir::Graph* graph, Node* from, Node* to);
35
// Declared here, defined elsewhere. Returns an optional Node pointer.
// NOTE(review): presumably yields the input node of `op` whose name matches
// `bias_name`, and an empty optional when `op` has no such input -- confirm
// against the definition.
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name);
36 37 38

class ResidualConnectionMKLDNNFusePass : public FusePassBase {
 private:
39 40 41 42
  GraphWithStats FuseConvAsX(const std::string& name_scope,
                             const GraphWithStats& graph_with_stats) const;
  GraphWithStats FuseConvAsY(const std::string& name_scope,
                             const GraphWithStats& graph_with_stats) const;
43 44 45
  GraphWithStats FuseProjectionConv(
      const std::string& name_scope,
      const GraphWithStats& graph_with_stats) const;
46

47 48 49
  template <typename RetType>
  using GetNodeFunc =
      std::function<RetType(const GraphPatternDetector::subgraph_t& subgraph)>;
50 51 52 53 54 55 56
  using IdentityConvFunc = GetNodeFunc<std::tuple<Node*, Node*, Node*, Node*>>;
  using IdentityElementwiseAddFunc =
      GetNodeFunc<std::tuple<Node*, Node*, Node*>>;

  using ProjectionConvFunc = IdentityConvFunc;
  using ProjectionElementwiseAddFunc = GetNodeFunc<std::tuple<Node*, Node*>>;

57 58
  using CanFuseFunc = std::function<bool(Node*, Node*)>;

59 60 61 62
  std::tuple<Node*, Node*, Node*, Node*> GetNodesFromConv(
      const patterns::Conv& conv_pattern,
      const GraphPatternDetector::subgraph_t& subgraph) const;

63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
  std::tuple<Node*, Node*, Node*, Node*> GetNodesFromProjectionConv(
      const patterns::Conv& conv_pattern,
      const GraphPatternDetector::subgraph_t& subgraph) const;

  template <typename HandleType, typename... OpFuncs>
  GraphWithStats ExecuteHandleOnGraph(GraphPatternDetector* gpd,
                                      const GraphWithStats& graph_with_stats,
                                      OpFuncs&&... op_funcs) const {
    ir::Graph* graph;
    int stats;

    std::tie(graph, stats) = graph_with_stats;

    auto can_fuse = [this](Node* op1, Node* op2) -> bool {
      return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
    };

    auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};

    (*gpd)(graph, fuse_handle);

    return std::make_pair(graph, stats + fuse_handle.get_stats());
  }

  struct IdentityFuseHandle {
    IdentityFuseHandle(
        const CanFuseFunc& can_fuse_func,
        const IdentityConvFunc& get_node_from_conv_op,
        const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op);

    void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* graph);
    int get_stats() const { return *fusion_stats; }

   private:
    std::shared_ptr<int> fusion_stats;
    CanFuseFunc can_fuse_func;
    IdentityConvFunc get_node_from_conv_op;
    IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
  };
103

104 105 106 107 108 109
  struct ProjectionFuseHandle {
    ProjectionFuseHandle(
        const CanFuseFunc& can_fuse_func,
        const ProjectionConvFunc& get_node_from_conv_x_op,
        const ProjectionConvFunc& get_node_from_conv_y_op,
        const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op);
110

111 112 113 114 115 116
    void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* graph);
    int get_stats() const { return *fusion_stats; }

   private:
    std::shared_ptr<int> fusion_stats;
117
    CanFuseFunc can_fuse_func;
118 119 120
    ProjectionConvFunc get_node_from_conv_x_op;
    ProjectionConvFunc get_node_from_conv_y_op;
    ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
121
  };
122

123
 public:
124
  virtual ~ResidualConnectionMKLDNNFusePass() {}
125 126

 protected:
127
  std::unique_ptr<ir::Graph> ApplyImpl(graph_ptr graph) const;
128

129
  const std::string name_scope_{"residual_connection_fuse_pass"};
130 131 132 133
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle