// paddle_pass_builder.h
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <sstream>
#include <string>
#include <vector>

/*! \file */

/*! \namespace paddle */
namespace paddle {
/** This is a pass builder based on string. It is part of inference API.
 *
 * It holds an ordered list of pass type names (`passes_`); the editing
 * methods below modify that list.
 */
class PaddlePassBuilder {
 public:
  /** Construct the builder from an initial, ordered list of pass types. */
  explicit PaddlePassBuilder(const std::vector<std::string> &passes)
      : passes_(passes) {}

  /** Virtual so that derived strategies (e.g. PassStrategy) can be safely
   * destroyed through a PaddlePassBuilder pointer.
   */
  virtual ~PaddlePassBuilder() = default;

  // Declaring the destructor would otherwise suppress the implicit move
  // operations; keep the class copyable and movable as before.
  PaddlePassBuilder(const PaddlePassBuilder &) = default;
  PaddlePassBuilder(PaddlePassBuilder &&) = default;
  PaddlePassBuilder &operator=(const PaddlePassBuilder &) = default;
  PaddlePassBuilder &operator=(PaddlePassBuilder &&) = default;

  /** Append a pass to the end of the passes. */
  void AppendPass(const std::string &pass_type);

  /** Insert a pass at a specific position.
   * @param idx the position to insert.
   * @param pass_type the pass key.
   */
  void InsertPass(size_t idx, const std::string &pass_type);

  /** Delete the `idx`-th pass. */
  void DeletePass(size_t idx);

  /** Delete all the passes that have type `pass_type`. */
  void DeletePass(const std::string &pass_type);

  /** Visualize the computation graph after each pass by generating a DOT
   * language file, one can draw them with the Graphviz toolkit.
   */
  void TurnOnDebug();

  /** Human-readable information. */
  std::string DebugString();

  /** The current ordered list of pass types. */
  const std::vector<std::string> &AllPasses() const { return passes_; }

 protected:
  std::vector<std::string> passes_;
};

62
/**Pass strategy to help control the IR passes.
63 64 65 66 67 68
 */
class PassStrategy : public PaddlePassBuilder {
 public:
  explicit PassStrategy(const std::vector<std::string> &passes)
      : PaddlePassBuilder(passes) {}

69 70 71
  /** The MKLDNN control exists in both CPU and GPU mode, because there can be
   * still some CPU kernels running in CPU mode.
   */
72 73
  virtual void EnableMKLDNN() = 0;

74 75
  bool use_gpu() const { return use_gpu_; }

76
  virtual ~PassStrategy() = default;
77 78 79

 protected:
  bool use_gpu_{false};
80 81
};

82
/** The CPU passes controller, it is used in AnalysisPredictor with CPU mode.
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
 */
class CpuPassStrategy : public PassStrategy {
 public:
  CpuPassStrategy() : PassStrategy({}) {
    // NOTE the large fusions should be located in the front, so that they will
    // not be damaged by smaller ones.
    passes_.assign({
        "infer_clean_graph_pass",         //
        "attention_lstm_fuse_pass",       //
        "seqconv_eltadd_relu_fuse_pass",  //
        // "embedding_fc_lstm_fuse_pass", //
        "fc_lstm_fuse_pass",             //
        "mul_lstm_fuse_pass",            //
        "fc_gru_fuse_pass",              //
        "mul_gru_fuse_pass",             //
        "seq_concat_fc_fuse_pass",       //
        "fc_fuse_pass",                  //
        "conv_bn_fuse_pass",             //
        "conv_eltwiseadd_bn_fuse_pass",  //
102
        "is_test_pass",                  //
103
    });
104
    use_gpu_ = false;
105 106 107 108
  }

  virtual ~CpuPassStrategy() = default;

109
  void EnableMKLDNN() override {
110 111 112 113 114
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
    passes_.insert(passes_.begin(), "mkldnn_placement_pass");

    for (auto &pass :
115 116 117 118
         std::vector<std::string>({"depthwise_conv_mkldnn_pass",    //
                                   "conv_bias_mkldnn_fuse_pass",    //
                                   "conv3d_bias_mkldnn_fuse_pass",  //
                                   "conv_relu_mkldnn_fuse_pass",    //
119 120 121 122 123 124 125 126 127
                                   "conv_elementwise_add_mkldnn_fuse_pass"})) {
      passes_.push_back(pass);
    }
#endif
  }

  CpuPassStrategy(const CpuPassStrategy &other) : PassStrategy(other.passes_) {}
};

128
/** The GPU passes strategy, it is used in AnalysisPredictor with GPU mode.
129 130 131 132 133
 */
class GpuPassStrategy : public PassStrategy {
 public:
  GpuPassStrategy() : PassStrategy({}) {
    passes_.assign({
N
nhzlx 已提交
134 135 136 137 138 139 140
        "infer_clean_graph_pass",                    //
        "conv_affine_channel_fuse_pass",             //
        "conv_eltwiseadd_affine_channel_fuse_pass",  //
        "conv_bn_fuse_pass",                         //
        "conv_elementwise_add_act_fuse_pass",        //
        "conv_elementwise_add2_act_fuse_pass",       //
        "conv_elementwise_add_fuse_pass",            //
141
    });
142 143

    use_gpu_ = true;
144 145 146
  }

  GpuPassStrategy(const GpuPassStrategy &other)
147 148 149
      : PassStrategy(other.AllPasses()) {
    use_gpu_ = true;
  }
150

151
  void EnableMKLDNN() override;
152 153 154 155 156

  virtual ~GpuPassStrategy() = default;
};

}  // namespace paddle