// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <sstream>
#include <string>
#include <vector>

///
/// \file paddle_pass_builder.h
///
/// \brief Class PaddlePassBuilder and its subclasses (pass strategies).
/// \section sec_intro Introduction
/// This class aims to build passes for paddle and define passes' strategies.
///
/// \author paddle-infer@baidu.com
/// \date 2020-3-23
/// \since 1.7

/// \namespace paddle
namespace paddle {
/// \class PaddlePassBuilder
/// \brief This class builds passes based on a vector<string> input. It is part
/// of the inference API. Users can build passes, insert new passes and delete
/// passes using this class and its functions.
///
/// Example Usage:
///     Build a new pass.
/// \code{cpp}
/// const vector<string> passes(1, "conv_relu_mkldnn_fuse_pass");
/// PaddlePassBuilder builder(passes);
/// \endcode
class PaddlePassBuilder {
 public:
  /// \brief Constructor of the class. It stores the input passes.
  /// \param[in] passes passes' types.
  explicit PaddlePassBuilder(const std::vector<std::string> &passes)
      : passes_(passes) {}

  /// \brief Replaces the stored passes with the given ones.
  /// \param[in] passes passes' types.
  void SetPasses(std::initializer_list<std::string> passes) {
    passes_ = passes;
  }

  /// \brief Append a pass to the end of the passes.
  /// \param[in] pass_type the type of the new pass.
  void AppendPass(const std::string &pass_type);

  /// \brief Insert a pass to a specific position.
  /// \param[in] idx the position to insert.
  /// \param[in] pass_type the type of insert pass.
  void InsertPass(size_t idx, const std::string &pass_type);

  /// \brief Delete the pass at certain position 'idx'.
  /// \param[in] idx the position to delete.
  void DeletePass(size_t idx);

  /// \brief Delete all passes that have a certain type 'pass_type'.
  /// \param[in] pass_type the certain pass type to be deleted.
  void DeletePass(const std::string &pass_type);

  /// \brief Delete all the passes.
  void ClearPasses();

  /// \brief Append an analysis pass.
  /// \param[in] pass the type of the new analysis pass.
  void AppendAnalysisPass(const std::string &pass);

  /// \brief Visualize the computation graph after each pass by generating a DOT
  /// language file, one can draw them with the Graphviz toolkit.
  void TurnOnDebug();

  /// \brief Human-readable information of the passes.
  std::string DebugString();

  /// \brief Get information of passes.
  /// \return Return list of the passes.
  const std::vector<std::string> &AllPasses() const { return passes_; }

  /// \brief Get information of analysis passes.
  /// \return A copy of the stored analysis passes with
  /// "ir_graph_to_program_pass" appended as the last entry.
  std::vector<std::string> AnalysisPasses() const {
    auto passes = analysis_passes_;
    // To make sure the ir_graph_to_program should be the last pass so any
    // modification of IR will persist to the program.
    passes.push_back("ir_graph_to_program_pass");
    return passes;
  }

 protected:
  /// \cond Protected
  // A flat braced list is used on purpose: a nested `{{...}}` initializer
  // with exactly two literals would select std::string's (first, last)
  // iterator constructor and produce a single invalid string.
  std::vector<std::string> analysis_passes_{
      "ir_graph_build_pass", "ir_graph_clean_pass", "ir_analysis_pass",
      "ir_params_sync_among_devices_pass", "adjust_cudnn_workspace_size_pass",
      "inference_op_replace_pass"};
  std::vector<std::string> passes_;
  /// \endcond
};

113 114 115
/// \class PassStrategy
/// \brief This class defines the pass strategies like whether to use gpu/cuDNN
/// kernel/MKLDNN.
116 117
class PassStrategy : public PaddlePassBuilder {
 public:
118 119
  /// \brief Constructor of PassStrategy class. It works the same as
  /// PaddlePassBuilder class. \param[in] passes passes' types.
120 121 122
  explicit PassStrategy(const std::vector<std::string> &passes)
      : PaddlePassBuilder(passes) {}

123
  /// \brief Enable the use of cuDNN kernel.
124 125
  virtual void EnableCUDNN() {}

126 127 128
  /// \brief Enable the use of MKLDNN.
  /// The MKLDNN control exists in both CPU and GPU mode, because there can
  /// still be some CPU kernels running in GPU mode.
Y
Yan Chunwei 已提交
129
  virtual void EnableMKLDNN() {}
130

131
  /// \brief Enable MKLDNN quantize optimization.
132
  virtual void EnableMkldnnQuantizer() {}
133

134 135
  /// \brief Check if we are using gpu.
  /// \return A bool variable implying whether we are in gpu mode.
136 137
  bool use_gpu() const { return use_gpu_; }

138
  /// \brief Default destructor.
139
  virtual ~PassStrategy() = default;
140 141

 protected:
142
  /// \cond Protected
143
  bool use_gpu_{false};
Y
Yan Chunwei 已提交
144
  bool use_mkldnn_{false};
145
  /// \endcond
146 147
};

148 149 150
/// \class CpuPassStrategy
/// \brief The CPU passes controller, it is used in AnalysisPredictor with CPU
/// mode.
151 152
class CpuPassStrategy : public PassStrategy {
 public:
153
  /// \brief Default constructor of CpuPassStrategy.
154
  CpuPassStrategy();
155

156 157
  /// \brief Construct by copying another CpuPassStrategy object.
  /// \param[in] other The CpuPassStrategy object we want to copy.
Y
Yan Chunwei 已提交
158
  explicit CpuPassStrategy(const CpuPassStrategy &other)
W
Wojciech Uss 已提交
159 160 161 162 163
      : PassStrategy(other.AllPasses()) {
    use_gpu_ = other.use_gpu_;
    use_mkldnn_ = other.use_mkldnn_;
    use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
  }
164
  /// \brief Default destructor.
165 166
  virtual ~CpuPassStrategy() = default;

167
  /// \brief Enable the use of cuDNN kernel.
168
  void EnableCUDNN() override;
169 170

  /// \brief Enable the use of MKLDNN.
W
Wojciech Uss 已提交
171
  void EnableMKLDNN() override;
172 173

  /// \brief Enable MKLDNN quantize optimization.
W
Wojciech Uss 已提交
174
  void EnableMkldnnQuantizer() override;
175 176

 protected:
177
  /// \cond Protected
178
  bool use_mkldnn_quantizer_{false};
179
  /// \endcond
180 181
};

182 183 184
/// \class GpuPassStrategy
/// \brief The GPU passes controller, it is used in AnalysisPredictor with GPU
/// mode.
185 186
class GpuPassStrategy : public PassStrategy {
 public:
187
  /// \brief Default constructor of GpuPassStrategy.
188
  GpuPassStrategy();
189

190 191
  /// \brief Construct by copying another GpuPassStrategy object.
  /// \param[in] other The GpuPassStrategy object we want to copy.
Y
Yan Chunwei 已提交
192
  explicit GpuPassStrategy(const GpuPassStrategy &other)
193 194
      : PassStrategy(other.AllPasses()) {
    use_gpu_ = true;
195
    use_cudnn_ = other.use_cudnn_;
196
  }
197

198
  /// \brief Enable the use of cuDNN kernel.
199
  void EnableCUDNN() override;
200 201

  /// \brief Not supported in GPU mode yet.
202
  void EnableMKLDNN() override;
203 204

  /// \brief Not supported in GPU mode yet.
205
  void EnableMkldnnQuantizer() override;
206

207
  /// \brief Default destructor.
208
  virtual ~GpuPassStrategy() = default;
209 210

 protected:
211
  /// \cond Protected
212
  bool use_cudnn_{false};
213
  /// \endcond
214
};
/// \brief List of tensorRT subgraph passes.
extern const std::vector<std::string> kTRTSubgraphPasses;

/// \brief List of lite subgraph passes.
extern const std::vector<std::string> kLiteSubgraphPasses;

}  // namespace paddle