// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#ifdef PADDLE_WITH_CUDA
#include <cudnn.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <miopen/miopen.h>
#endif
#include <glog/logging.h>
#include <sstream>

namespace paddle {

// Schedule one more IR pass to run after all currently registered passes.
void PaddlePassBuilder::AppendPass(const std::string &pass_type) {
  passes_.emplace_back(pass_type);
}

void PaddlePassBuilder::TurnOnDebug() {
  std::vector<std::string> passes;
  auto it = std::begin(passes_);
  while (it != std::end(passes_)) {
    if (*it != "graph_viz_pass") {
      it = passes_.insert(it + 1, "graph_viz_pass");
    } else {
      ++it;
    }
  }
}

// Render the registered passes as a human-readable bullet list.
std::string PaddlePassBuilder::DebugString() {
  std::string repr("Passes to apply:\n");
  for (const auto &pass : passes_) {
    repr += "  - ";
    repr += pass;
    repr += '\n';
  }
  return repr;
}

// Remove every occurrence of pass_type from the pass list.
void PaddlePassBuilder::DeletePass(const std::string &pass_type) {
  for (auto it = passes_.begin(); it != passes_.end();) {
    if (*it == pass_type) {
      // erase() already yields the iterator following the removed
      // element, so no manual advance is needed on this branch.
      it = passes_.erase(it);
    } else {
      ++it;
    }
  }
}

// Insert pass_type so that it runs just before the pass currently at idx.
// The caller is responsible for idx being <= passes_.size().
void PaddlePassBuilder::InsertPass(size_t idx, const std::string &pass_type) {
  auto pos = passes_.begin() + idx;
  passes_.insert(pos, pass_type);
}

// Drop the single pass at position idx; idx must be a valid index.
void PaddlePassBuilder::DeletePass(size_t idx) {
  auto pos = passes_.begin() + idx;
  passes_.erase(pos);
}

// Analysis passes form a separate pipeline from the IR passes_ list.
void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
  analysis_passes_.emplace_back(pass);
}

// Forget every scheduled IR pass (analysis passes are untouched).
void PaddlePassBuilder::ClearPasses() {
  passes_.clear();
}

// IR passes applied before handing subgraphs to the TensorRT engine.
// Larger fusions are listed first so smaller ones cannot break them up.
const std::vector<std::string> kTRTSubgraphPasses({
    "conv_affine_channel_fuse_pass",             //
    "adaptive_pool2d_convert_global_pass",       //
    "conv_eltwiseadd_affine_channel_fuse_pass",  //
    "shuffle_channel_detect_pass",               //
    "quant_conv2d_dequant_fuse_pass",            //
    "delete_quant_dequant_op_pass",              //
    "delete_quant_dequant_filter_op_pass",       //
    // "fc_fuse_pass",                           //
    "simplify_with_basic_ops_pass",              //
    "embedding_eltwise_layernorm_fuse_pass",     //
    "multihead_matmul_fuse_pass_v2",             //
    "multihead_matmul_fuse_pass_v3",             //
    "skip_layernorm_fuse_pass",                  //
    "conv_bn_fuse_pass",                         //
    "unsqueeze2_eltwise_fuse_pass",              //
    "squeeze2_matmul_fuse_pass",                 //
    "reshape2_matmul_fuse_pass",                 //
    "flatten2_matmul_fuse_pass",                 //
    "map_matmul_v2_to_mul_pass",                 //
    "map_matmul_v2_to_matmul_pass",              //
    "map_matmul_to_mul_pass",                    //
    "fc_fuse_pass",                              //
    "conv_elementwise_add_fuse_pass",            //
    "tensorrt_subgraph_pass",                    //
    "conv_bn_fuse_pass",                         //
#if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
// disable the pass.
#if !(CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8100)
    "conv_elementwise_add_act_fuse_pass",   //
    "conv_elementwise_add2_act_fuse_pass",  //
#endif
#endif
    "transpose_flatten_concat_fuse_pass",
});

// IR passes applied before handing subgraphs to the DLNNE engine.
// NOTE: a comma was missing after "delete_dropout_op_pass"; C++ concatenates
// adjacent string literals, so the list silently contained a single bogus
// entry "delete_dropout_op_passsimplify_with_basic_ops_pass" and neither of
// the two intended passes ever ran.
const std::vector<std::string> kDlnneSubgraphPasses({
    "is_test_pass",                  //
    "delete_dropout_op_pass",        //
    "simplify_with_basic_ops_pass",  //
    "conv_bn_fuse_pass",             //
    "depthwise_conv_bn_fuse_pass",   //
    "shuffle_channel_detect_pass",   //
    "dlnne_subgraph_pass",           //
});

// Passes that offload subgraphs to the Paddle-Lite engine; the list is empty
// unless the library was built with Lite support (PADDLE_WITH_LITE).
const std::vector<std::string> kLiteSubgraphPasses({
#ifdef PADDLE_WITH_LITE
    "lite_subgraph_pass",
#endif
});

// Default IR-pass pipeline for GPU inference.  Bigger fusions appear first
// so that smaller fusions cannot destroy the patterns they match.
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
  passes_.assign({
      // "identity_scale_op_clean_pass",          //
      "is_test_pass",                              //
      "simplify_with_basic_ops_pass",              //
      "conv_affine_channel_fuse_pass",             //
      "conv_eltwiseadd_affine_channel_fuse_pass",  //
      "conv_bn_fuse_pass",                         //
      "conv_eltwiseadd_bn_fuse_pass",              //
      "embedding_eltwise_layernorm_fuse_pass",     //
      "multihead_matmul_fuse_pass_v2",             //
      "squeeze2_matmul_fuse_pass",                 //
      "reshape2_matmul_fuse_pass",                 //
      "flatten2_matmul_fuse_pass",                 //
      "map_matmul_v2_to_mul_pass",                 //
      "map_matmul_v2_to_matmul_pass",              //
      "map_matmul_to_mul_pass",                    //
      "fc_fuse_pass",                              //
      "fc_elementwise_layernorm_fuse_pass",        //
#if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
// disable the pass.
#if !(CUDNN_VERSION >= 8000 && CUDNN_VERSION < 8100)
      "conv_elementwise_add_act_fuse_pass",   //
      "conv_elementwise_add2_act_fuse_pass",  //
#endif
      "conv_elementwise_add_fuse_pass",      //
#endif
      "transpose_flatten_concat_fuse_pass",  //
      // following pass should be located in the last, since it will
      // work on all fused ops.
      "runtime_context_cache_pass"
  });

  use_gpu_ = true;
}

// Prepend the cuDNN placement pass exactly once; later calls are no-ops.
void GpuPassStrategy::EnableCUDNN() {
  if (use_cudnn_) return;
  passes_.insert(passes_.begin(), "cudnn_placement_pass");
  use_cudnn_ = true;
}

// MKL-DNN is a CPU library; on the GPU strategy this only reports misuse.
void GpuPassStrategy::EnableMKLDNN() {
  LOG(ERROR) << "GPU not support MKLDNN yet";
}

// MKL-DNN quantization is CPU-only; on the GPU strategy this only reports
// misuse and changes no state.
void GpuPassStrategy::EnableMkldnnQuantizer() {
  LOG(ERROR) << "GPU not support MKL-DNN quantization";
}

// MKL-DNN bfloat16 inference is CPU-only; on the GPU strategy this only
// reports misuse and changes no state.
void GpuPassStrategy::EnableMkldnnBfloat16() {
  LOG(ERROR) << "GPU not support MKL-DNN bfloat16";
}

// Default IR-pass pipeline for CPU inference.
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
  // NOTE the large fusions should be located in the front, so that they will
  // not be damaged by smaller ones.
  passes_.assign({
      "simplify_with_basic_ops_pass",   //
      "layer_norm_fuse_pass",           //
      "attention_lstm_fuse_pass",       //
      "seqconv_eltadd_relu_fuse_pass",  //
      // "seqpool_concat_fuse_pass",    //
      "seqpool_cvm_concat_fuse_pass",  //
      // "embedding_fc_lstm_fuse_pass", //
      // TODO(wilber): fix correctness problem.
      // "fc_lstm_fuse_pass",                    //
      "mul_lstm_fuse_pass",                      //
      "fc_gru_fuse_pass",                        //
      "mul_gru_fuse_pass",                       //
      "seq_concat_fc_fuse_pass",                 //
      "squeeze2_matmul_fuse_pass",               //
      "reshape2_matmul_fuse_pass",               //
      "flatten2_matmul_fuse_pass",               //
      "map_matmul_v2_to_mul_pass",               //
      "map_matmul_v2_to_matmul_pass",            //
      "map_matmul_to_mul_pass",                  //
      "fc_fuse_pass",                            //
      "repeated_fc_relu_fuse_pass",              //
      "squared_mat_sub_fuse_pass",               //
      "conv_bn_fuse_pass",                       //
      "conv_eltwiseadd_bn_fuse_pass",            //
      "conv_transpose_bn_fuse_pass",             //
      "conv_transpose_eltwiseadd_bn_fuse_pass",  //
      "is_test_pass",                            //
      // following pass should be located in the last, since
      // it will work on all fused ops.
      "runtime_context_cache_pass"});

  use_gpu_ = false;
}
void CpuPassStrategy::EnableCUDNN() { LOG(ERROR) << "CPU not support cuDNN"; }

// Switch the CPU pipeline to MKL-DNN: place ops on MKL-DNN kernels first,
// then append the MKL-DNN-specific fuse passes.  Idempotent.
void CpuPassStrategy::EnableMKLDNN() {
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
  if (!use_mkldnn_) {
    // Placement must come first so the fuse passes below see the ops
    // already assigned to MKL-DNN.
    passes_.insert(passes_.begin(), "mkldnn_placement_pass");

    const std::vector<std::string> mkldnn_passes{
        "depthwise_conv_mkldnn_pass",     //
        "conv_bn_fuse_pass",              // Execute BN passes again to
        "conv_eltwiseadd_bn_fuse_pass",   // preserve correct pass order
        "conv_affine_channel_fuse_pass",  //
        "conv_eltwiseadd_affine_channel_fuse_pass",  //
        "conv_transpose_bn_fuse_pass",               //
        "conv_transpose_eltwiseadd_bn_fuse_pass",    //
        "conv_bias_mkldnn_fuse_pass",                //
        "conv_transpose_bias_mkldnn_fuse_pass",      //
        "conv3d_bias_mkldnn_fuse_pass",              //
        "conv_elementwise_add_mkldnn_fuse_pass",     //
        "conv_concat_relu_mkldnn_fuse_pass",         //
        "conv_relu_mkldnn_fuse_pass",                 //
        "conv_leaky_relu_mkldnn_fuse_pass",           //
        "conv_relu6_mkldnn_fuse_pass",                //
        "conv_swish_mkldnn_fuse_pass",                //
        "conv_hard_swish_mkldnn_fuse_pass",           //
        "scale_matmul_fuse_pass",                     //
        "reshape_transpose_matmul_mkldnn_fuse_pass",  //
        "matmul_transpose_reshape_fuse_pass",         //
        // Disabled due to topology-dependent speed-up
        // "fc_mkldnn_pass",
        // "fc_act_mkldnn_fuse_pass",
        "batch_norm_act_fuse_pass",
        // TODO(intel): Please fix the bug on windows.
        // https://github.com/PaddlePaddle/Paddle/issues/29710
        // "mkldnn_inplace_pass",  // This pass should be activated after
        // fuses. Disabled by default due to
        // little gain and lots of problems
    };
    for (const auto &pass : mkldnn_passes) {
      passes_.push_back(pass);
    }
  }
  use_mkldnn_ = true;
#else
  use_mkldnn_ = false;
#endif
}

// Enable INT8 quantization with MKL-DNN by appending the placement pass.
// Idempotent; without MKL-DNN support the flag is forced off.
void CpuPassStrategy::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (!use_mkldnn_quantizer_) {
    passes_.emplace_back("cpu_quantize_placement_pass");
    use_mkldnn_quantizer_ = true;
  }
#else
  use_mkldnn_quantizer_ = false;
#endif
}

// Enable bfloat16 inference with MKL-DNN by appending the placement,
// conversion and squash passes.  Idempotent; without MKL-DNN support the
// flag is forced off.
void CpuPassStrategy::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
  if (!use_mkldnn_bfloat16_) {
    for (const char *pass : {"cpu_bfloat16_placement_pass",  //
                             "cpu_bfloat16_pass",            //
                             "cpu_quantize_squash_pass"}) {
      passes_.push_back(pass);
    }
  }
  use_mkldnn_bfloat16_ = true;
#else
  use_mkldnn_bfloat16_ = false;
#endif
}

}  // namespace paddle