// analysis_config.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/table_printer.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"

#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

namespace paddle {
// Forward declaration; the definition is not visible in this file.
struct MkldnnQuantizerConfig;

// Pass lists defined elsewhere. They are used below when (re)building the pass
// pipeline for the TensorRT, DLNNE and Lite subgraph engines.
extern const std::vector<std::string> kTRTSubgraphPasses;
extern const std::vector<std::string> kDlnneSubgraphPasses;
extern const std::vector<std::string> kLiteSubgraphPasses;

// Returns the pass strategy of this config. On first use, it lazily creates
// one matching the selected device (GPU, XPU, NPU, IPU, otherwise CPU).
// If an existing strategy disagrees with the config's use_gpu flag, a warning
// is logged and the existing strategy is still returned unchanged.
PassStrategy *AnalysisConfig::pass_builder() const {
  if (!pass_builder_.get()) {
    if (use_gpu_) {
      LOG(INFO) << "Create GPU IR passes";
      pass_builder_.reset(new GpuPassStrategy);
    } else if (use_xpu_) {
      pass_builder_.reset(new XpuPassStrategy);
    } else if (use_npu_) {
      pass_builder_.reset(new NpuPassStrategy);
    } else if (use_ipu_) {
      LOG(INFO) << "Create IPU IR passes";
      pass_builder_.reset(new IpuPassStrategy);
    } else {
      LOG(INFO) << "Create CPU IR passes";
      pass_builder_.reset(new CpuPassStrategy);
    }
  } else if (pass_builder_->use_gpu() ^ use_gpu()) {
    LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                    "PassBuilder, the flags are "
                 << use_gpu() << " " << pass_builder_->use_gpu();
    LOG(WARNING) << "Please make them compatible, still use the existing "
                    "PassBuilder.";
  }

  return pass_builder_.get();
}

// Creates a config that loads the model from `model_dir`, then builds the
// pass pipeline via Update().
AnalysisConfig::AnalysisConfig(const std::string &model_dir) {
  model_dir_ = model_dir;

  Update();
}

// Creates a config from separate program and parameter files, then builds the
// pass pipeline via Update().
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file) {
  prog_file_ = prog_file;
  params_file_ = params_file;

  Update();
}

void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
81 82
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
83 84

  Update();
85
}
86 87
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                  int device_id) {
88
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
89 90
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
91
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
92
  gpu_device_id_ = device_id;
93
#else
Y
Yan Chunwei 已提交
94
  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
95 96
  use_gpu_ = false;
#endif
Y
Yan Chunwei 已提交
97 98 99

  Update();
}
100
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
101 102 103
  use_gpu_ = false;

  Update();
104 105
}

106 107 108 109 110 111
// Turns FC padding off; the pass pipeline is refreshed afterwards.
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;
  Update();
}

W
Wilber 已提交
112 113 114 115
void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
                               bool autotune, const std::string &autotune_file,
                               const std::string &precision,
                               bool adaptive_seqlen) {
116 117
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
118 119 120 121 122
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
123 124 125
  Update();
}

126 127 128 129 130 131 132 133
// Selects the XPU device to run on. EnableXpu() must have been called first;
// otherwise this fails with PreconditionNotMet.
void AnalysisConfig::SetXpuDeviceId(int device_id) {
  PADDLE_ENFORCE_EQ(use_xpu_, true,
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

// Enables NPU execution on `device_id`. Without PADDLE_WITH_ASCEND_CL compiled
// in, this only logs an error and keeps the NPU disabled.
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
  use_npu_ = true;
  npu_device_id_ = device_id;
#else
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
  use_npu_ = false;
#endif

  Update();
}

void AnalysisConfig::EnableIpu(int ipu_device_num, int ipu_micro_batch_size,
                               bool ipu_enable_pipelining,
                               int ipu_batches_per_step) {
J
jianghaicheng 已提交
149 150 151
  enable_ir_optim_ = true;

  use_ipu_ = true;
152 153
  ipu_device_num_ = ipu_device_num;
  ipu_micro_batch_size_ = ipu_micro_batch_size;
J
jianghaicheng 已提交
154 155
  ipu_enable_pipelining_ = ipu_enable_pipelining;
  ipu_batches_per_step_ = ipu_batches_per_step;
156 157 158 159 160 161 162 163 164 165 166

  Update();
}

void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16, int ipu_replica_num,
                                  float ipu_available_memory_proportion,
                                  bool ipu_enable_half_partial) {
  ipu_enable_fp16_ = ipu_enable_fp16;
  ipu_replica_num_ = ipu_replica_num;
  ipu_available_memory_proportion_ = ipu_available_memory_proportion;
  ipu_enable_half_partial_ = ipu_enable_half_partial;
J
jianghaicheng 已提交
167 168 169

  Update();
}
W
Wilber 已提交
170

171
// Copy constructor. It copies every option from `other` and clones `other`'s
// pass strategy using the concrete strategy type for the selected device. It
// then re-runs Update() and reconciles the TensorRT / DLNNE pass lists with
// the ones in `other`.
// NOTE(review): mkldnn_quantizer_config_ is copy-assigned, so it is presumably
// a shared pointer and both configs then share one quantizer config — confirm.
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.

  CP_MEMBER(opt_cache_dir_);
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);

  CP_MEMBER(use_fc_padding_);
  // GPU related.
  CP_MEMBER(use_gpu_);
  CP_MEMBER(use_cudnn_);
  CP_MEMBER(gpu_device_id_);
  CP_MEMBER(memory_pool_init_size_mb_);

  CP_MEMBER(enable_memory_optim_);
  // TensorRT related.
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
  CP_MEMBER(tensorrt_precision_mode_);
  CP_MEMBER(trt_disabled_ops_);
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
  CP_MEMBER(trt_use_static_engine_);
  CP_MEMBER(trt_use_calib_mode_);
  CP_MEMBER(trt_use_oss_);
  CP_MEMBER(trt_with_interleaved_);
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
  CP_MEMBER(trt_use_inspector_);
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
  // MKLDNN related.
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
  CP_MEMBER(mkldnn_cache_capacity_);
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
  // Quantization related.
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
  CP_MEMBER(disable_trt_plugin_fp16_);

  // Lite related.
  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
  CP_MEMBER(lite_zero_copy_);

  // XPU related.
  CP_MEMBER(use_xpu_);
  CP_MEMBER(xpu_device_id_);
  CP_MEMBER(xpu_l3_workspace_size_);
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);

  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
  CP_MEMBER(nnadapter_config_);

  // profile related.
  CP_MEMBER(with_profile_);

  // glog related.
  CP_MEMBER(with_glog_info_);

  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

  CP_MEMBER(thread_local_stream_);

  // ipu related
  CP_MEMBER(use_ipu_);
  CP_MEMBER(ipu_device_num_);
  CP_MEMBER(ipu_micro_batch_size_);
  CP_MEMBER(ipu_enable_pipelining_);
  CP_MEMBER(ipu_batches_per_step_);
  CP_MEMBER(ipu_enable_fp16_);
  CP_MEMBER(ipu_replica_num_);
  CP_MEMBER(ipu_available_memory_proportion_);
  CP_MEMBER(ipu_enable_half_partial_);

  // Clone other's strategy. The static_casts assume other's strategy has the
  // concrete type matching the copied device flags.
  if (use_gpu_) {
    PADDLE_ENFORCE_EQ(use_xpu_, false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between GPU and XPU."));
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
  } else if (use_ipu_) {
    pass_builder_.reset(new IpuPassStrategy(
        *static_cast<IpuPassStrategy *>(other.pass_builder())));
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

#undef CP_MEMBER

  Update();
  if (use_tensorrt_) {
    // Update() rebuilds the TensorRT pass list from kTRTSubgraphPasses. That
    // would bring back passes the user deleted from `other`, so restore
    // `other`'s exact pass list instead.
    pass_builder_->ClearPasses();
    const auto &other_passes = other.pass_builder()->AllPasses();
    for (const auto &pass : other_passes) {
      pass_builder_->AppendPass(pass);
    }
  }
  if (use_dlnne_) {
    auto all_passes = kDlnneSubgraphPasses;
    auto other_passes = other.pass_builder()->AllPasses();
    // We should sort them, because the user may call the SwitchIrDebug
    // interface, which will change the pass.
    std::sort(all_passes.begin(), all_passes.end());
    std::sort(other_passes.begin(), other_passes.end());
    // Passes present in the default DLNNE list but missing from `other` were
    // deleted by the user; delete them here too.
    std::vector<std::string> deleted_passes;
    std::set_difference(all_passes.begin(), all_passes.end(),
                        other_passes.begin(), other_passes.end(),
                        std::inserter(deleted_passes, deleted_passes.begin()));
    for (const auto &ps : deleted_passes) {
      pass_builder_->DeletePass(ps);
    }
  }
}

void AnalysisConfig::EnableCUDNN() {
328
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
329 330 331 332 333 334 335 336 337
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

338
void AnalysisConfig::EnableMKLDNN() {
339 340 341 342 343 344
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
345 346

  Update();
347 348
}

349 350 351 352 353 354 355 356 357
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

358 359 360 361 362 363 364 365 366 367 368 369 370
// Enables the MKLDNN quantizer and creates its config on first use. Without
// MKLDNN support compiled in, this logs an error and keeps the quantizer
// disabled.
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (!mkldnn_quantizer_config_) {
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  }
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif

  Update();
}

// Enables MKLDNN bfloat16 when the CPU supports avx512_core. It logs whether
// native avx512_bf16 support is present; if not, simulation is used. On CPUs
// without avx512_core, or without MKLDNN compiled in, bfloat16 stays disabled.
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
  if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) {
    use_mkldnn_bfloat16_ = true;
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (platform::MayIUse(platform::cpu_isa_t::avx512_bf16)
                      ? " is enabled"
                      : " is disabled. Simulation will be used");
  } else {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
    use_mkldnn_bfloat16_ = false;
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif

  Update();
}

// Returns the quantizer config. Fails with PreconditionNotMet if
// EnableMkldnnQuantizer() has not created it yet.
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
  PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
                          platform::errors::PreconditionNotMet(
                              "MkldnnQuantizer was not enabled yet."));
  return mkldnn_quantizer_config_.get();
}

void AnalysisConfig::EnableTensorRtEngine(
N
nhzlx 已提交
399
    int workspace_size, int max_batch_size, int min_subgraph_size,
400
    AnalysisConfig::Precision precision_mode, bool use_static,
401
    bool use_calib_mode) {
402
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
403 404 405 406 407
  if (!use_gpu()) {
    LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
    return;
  }

408 409 410
  use_tensorrt_ = true;
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
411
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
412
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
413
  trt_use_static_engine_ = use_static;
414
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
415

416
  Update();
Y
Yan Chunwei 已提交
417 418 419 420
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
421 422
}

D
denglin-github 已提交
423 424 425 426 427 428
// Turns on the DLNNE subgraph engine, recording `min_subgraph_size` as its
// minimum subgraph size, and refreshes the pass pipeline.
void AnalysisConfig::EnableDlnne(int min_subgraph_size) {
  dlnne_min_subgraph_size_ = min_subgraph_size;
  use_dlnne_ = true;
  Update();
}

429 430 431 432 433 434 435 436 437 438 439
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  min_input_shape_ = min_input_shape;
  max_input_shape_ = max_input_shape;
  optim_input_shape_ = optim_input_shape;
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

440 441 442 443 444
// Requests that TensorRT use DLA core `dla_core`. Does not call Update().
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_use_dla_ = true;
  trt_dla_core_ = dla_core;
}

// Sets the TensorRT inspector flag. Does not call Update().
void AnalysisConfig::EnableTensorRtInspector() { trt_use_inspector_ = true; }

// Experimental: appends `ops` to trt_disabled_ops_, the list of op types that
// should not be handled by TensorRT. Does not call Update().
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

// Sets the TensorRT OSS flag. Does not call Update().
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }

Y
Yan Chunwei 已提交
454
// TODO(Superjomn) refactor this, buggy.
455
void AnalysisConfig::Update() {
456 457 458
  auto info = SerializeInfoCache();
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
459
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
460 461
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
J
jianghaicheng 已提交
462 463
      ((use_npu() ^ pass_builder_->use_npu())) ||
      ((use_ipu() ^ pass_builder_->use_ipu()))) {
Y
Yan Chunwei 已提交
464 465 466 467 468 469 470
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);

      if (use_tensorrt_) {
        // Append after the Affine_channel_conv_fuse pass.
        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
      }
J
jianghaicheng 已提交
471 472 473
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used for new.";
      pass_builder_.reset(new IpuPassStrategy);
474 475 476 477 478 479
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
480 481 482 483 484 485
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
Y
Yan Chunwei 已提交
486 487 488
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
489

490
  } else {
Y
Yan Chunwei 已提交
491 492 493
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
J
jianghaicheng 已提交
494 495 496 497
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used.";
      pass_builder_.reset(new IpuPassStrategy(
          *static_cast<IpuPassStrategy *>(pass_builder_.get())));
498 499 500 501 502 503 504
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
505 506 507 508 509 510 511
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
512 513 514 515
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
516 517 518
  }

  if (use_tensorrt_) {
519 520
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
521
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
522
          (pass == "conv_bn_fuse_pass")) {
523 524
        continue;
      }
525
      pass_builder()->AppendPass(pass);
526 527
    }
  }
528

D
denglin-github 已提交
529 530 531 532 533 534 535
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

536
  if (use_gpu() && use_cudnn_) {
537
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
538 539 540 541 542 543 544 545
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

546
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
547
#ifdef PADDLE_WITH_MKLDNN
548 549 550
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
551 552
    } else {
      pass_builder()->EnableMKLDNN();
553 554 555 556
    }
#endif
  }

557 558 559 560 561
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
562 563
    }
#ifdef PADDLE_WITH_MKLDNN
564
    pass_builder()->EnableMkldnnQuantizer();
565 566 567
#endif
  }

568 569 570 571 572 573
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

574
#ifdef PADDLE_WITH_MKLDNN
575 576
  // Do not optimize when mkldnn is on
  if (enable_memory_optim_ && !use_mkldnn_) {
577
#else
Y
Yan Chunwei 已提交
578
  if (enable_memory_optim_) {
579 580
#endif
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
Y
Yan Chunwei 已提交
581 582
  }

石晓伟 已提交
583 584 585 586 587 588 589 590 591 592 593 594 595 596
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
      if (std::find(lite_passes_filter_.begin(), lite_passes_filter_.end(),
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

597
  if (use_xpu_) {
598
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
599 600 601 602
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
603 604 605 606 607
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
608 609
  }

W
Wilber 已提交
610
  if (use_npu_) {
611
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
612 613 614 615 616 617 618 619 620 621
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }
J
jianghaicheng 已提交
622 623 624 625 626 627 628
  if (use_ipu_) {
#ifndef PADDLE_WITH_IPU
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the ipu "
        "but did not have the option -DWITH_IPU compiled."));
#endif
  }
W
Wilber 已提交
629

630 631 632 633 634
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
}

635
std::string AnalysisConfig::SerializeInfoCache() {
636
  std::stringstream ss;
Y
Yan Chunwei 已提交
637 638 639 640
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

641
  ss << use_gpu_;
642
  ss << use_fc_padding_;
643 644
  ss << gpu_device_id_;
  ss << xpu_device_id_;
645 646 647 648 649
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
650 651
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
652 653 654
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

655 656 657
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

658 659 660
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
661
  ss << enable_memory_optim_;
662 663

  ss << use_mkldnn_;
664
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
665 666 667
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

668
  ss << use_mkldnn_quantizer_;
669
  ss << use_mkldnn_bfloat16_;
670 671
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
  ss << ";";
Y
Yan Chunwei 已提交
672 673
  ss << model_from_memory_;

674 675
  ss << with_profile_;

676 677
  ss << with_glog_info_;

678 679 680 681
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
682 683
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
684 685

  ss << use_lite_;
686 687
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
688 689 690 691 692
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
693

W
Wilber 已提交
694 695 696
  ss << use_npu_;
  ss << npu_device_id_;

697 698
  ss << thread_local_stream_;

J
jianghaicheng 已提交
699 700
  ss << use_ipu_;
  ss << ipu_device_num_;
701
  ss << ipu_micro_batch_size_;
J
jianghaicheng 已提交
702 703
  ss << ipu_enable_pipelining_;
  ss << ipu_batches_per_step_;
704 705 706 707
  ss << ipu_enable_fp16_;
  ss << ipu_replica_num_;
  ss << ipu_available_memory_proportion_;
  ss << ipu_enable_half_partial_;
J
jianghaicheng 已提交
708

709 710 711
  return ss.str();
}

712
void AnalysisConfig::SetCpuMathLibraryNumThreads(
713 714
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
715 716

  Update();
717 718
}

719
// Returns memory_pool_init_size_mb() as a fraction of the total memory
// reported for GPU gpu_device_id_, or 0 when built without CUDA/HIP.
// Note: calls platform::SetDeviceId(gpu_device_id_) before querying memory.
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Get the GPU memory details and calculate the fraction of memory for the
  // GPU memory pool.
  size_t gpu_total = 0;
  size_t gpu_available = 0;
  platform::SetDeviceId(gpu_device_id_);
  platform::GpuMemoryUsage(&gpu_available, &gpu_total);
  double total_gpu_memory = gpu_total / 1024. / 1024.;
  float fraction_of_gpu_memory = static_cast<float>(
      static_cast<double>(memory_pool_init_size_mb()) / total_gpu_memory);
  VLOG(3) << "total_gpu_memory is " << total_gpu_memory
          << "M, gpu_available is " << gpu_available / 1024. / 1024.
          << "M, memory_pool_init_size is " << memory_pool_init_size_mb()
          << "M.";
  return fraction_of_gpu_memory;
#else
  return 0.;
#endif
}

void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
741 742 743
  Update();
}

744
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
745 746 747
  return enable_memory_optim_;
}

748 749 750 751
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
752 753
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
754
  model_from_memory_ = true;
T
Tao Luo 已提交
755 756
}

757
// Converts this config into a NativeConfig. The GPU memory fraction comes from
// fraction_of_gpu_memory_for_pool().
NativeConfig AnalysisConfig::ToNativeConfig() const {
  NativeConfig config;
  config.model_dir = model_dir_;
  config.prog_file = prog_file_;
  config.param_file = params_file_;
  config.use_gpu = use_gpu_;
  config.device = gpu_device_id_;
  config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
  config.specify_input_name = specify_input_name_;
  return config;
}

// Turns IR debugging on (non-zero `x`) or off. When on, Update() calls
// TurnOnDebug() on the pass builder.
void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}

// Enables profiling and refreshes the pass pipeline.
void AnalysisConfig::EnableProfile() {
  with_profile_ = true;
  Update();
}

// Clears the glog-info flag; the pass pipeline is refreshed afterwards.
void AnalysisConfig::DisableGlogInfo() {
  with_glog_info_ = false;

  Update();
}

石晓伟 已提交
784
void AnalysisConfig::EnableLiteEngine(
785
    AnalysisConfig::Precision precision_mode, bool zero_copy,
石晓伟 已提交
786 787 788 789 790 791
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  use_lite_ = true;
  lite_precision_mode_ = precision_mode;
  lite_passes_filter_ = passes_filter;
  lite_ops_filter_ = ops_filter;
792
  lite_zero_copy_ = zero_copy;
石晓伟 已提交
793 794 795
  Update();
}

796 797 798 799 800 801 802
// Drops the contents of prog_file_ and params_file_ (which hold the model
// bytes when loaded from memory) and releases their capacity.
void AnalysisConfig::PartiallyRelease() {
  auto release = [](std::string &buffer) {
    buffer.clear();
    buffer.shrink_to_fit();
  };
  release(prog_file_);
  release(params_file_);
}

803 804
void AnalysisConfig::EnableGpuMultiStream() { thread_local_stream_ = true; }

805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
824
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
825 826 827 828 829 830 831 832 833 834 835 836 837 838 839
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866
#ifdef PADDLE_WITH_TENSORRT
      auto Precision2String =
          [](paddle::AnalysisConfig::Precision prec) -> std::string {
        if (prec == Precision::kFloat32)
          return "fp32";
        else if (prec == Precision::kHalf)
          return "fp16";
        else if (prec == Precision::kInt8)
          return "int8";
        else
          return "None";
      };
      auto version2string =
          [](const std::tuple<int, int, int> &ver) -> std::string {
        std::ostringstream os;
        int major = std::get<0>(ver);
        int minor = std::get<1>(ver);
        int patch = std::get<2>(ver);
        os << major << "." << minor << "." << patch;
        return os.str();
      };
      os.InsertRow(
          {"trt_compile_version",
           version2string(inference::tensorrt::GetTrtCompileVersion())});
      os.InsertRow(
          {"trt_runtime_version",
           version2string(inference::tensorrt::GetTrtRuntimeVersion())});
867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
883 884 885
      os.InsertRow({"tensorrt_tuned_dynamic_shape", trt_tuned_dynamic_shape_
                                                        ? shape_range_info_path_
                                                        : "false"});
886 887

      os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"});
888 889
      os.InsertRow({"tensorrt_with_interleaved",
                    trt_with_interleaved_ ? "true" : "false"});
890 891 892 893
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
894
#endif
895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
918 919
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
920 921 922 923

  return os.PrintTable();
}

924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978
// Replaces the configured NNAdapter device name list with `names`.
// Returns *this so setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  this->nnadapter_device_names = names;
  return *this;
}

// Stores the NNAdapter context properties string verbatim.
// Returns *this so setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  this->nnadapter_context_properties = properties;
  return *this;
}

// Records `dir` as the NNAdapter model cache directory (no validation is
// performed here). Returns *this so setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  this->nnadapter_model_cache_dir = dir;
  return *this;
}

// Registers an in-memory model cache buffer under `model_cache_token`.
//
// Raises InvalidArgument when:
//   - `model_cache_token` is empty,
//   - `model_cache_buffer` is empty, or
//   - a buffer has already been registered for `model_cache_token`.
// On success a copy of `model_cache_buffer` is stored; returns *this so
// setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  // NOTE: the enforce arguments are kept exactly as written because the
  // macro embeds their source text in the error it reports.
  PADDLE_ENFORCE_EQ(model_cache_token.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  PADDLE_ENFORCE_EQ(nnadapter_model_cache_buffers.count(model_cache_token),
                    false, platform::errors::InvalidArgument(
                               "model_cache_token has already been set."));

  // The token is known to be absent at this point, so emplace always
  // inserts a fresh entry.
  this->nnadapter_model_cache_buffers.emplace(model_cache_token,
                                              model_cache_buffer);
  return *this;
}

// Records the file path of the NNAdapter subgraph partition config.
// Returns *this so setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  this->nnadapter_subgraph_partition_config_path = path;
  return *this;
}

// Stores the NNAdapter subgraph partition config contents supplied in
// memory as `buffer`. Returns *this so setter calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  this->nnadapter_subgraph_partition_config_buffer = buffer;
  return *this;
}
// Turns the `use_nnadapter` switch on; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  this->use_nnadapter = true;
  return *this;
}
// Turns the `use_nnadapter` switch off; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  this->use_nnadapter = false;
  return *this;
}

979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014
// Puts the config into shape-range collection mode.
//
// shape_range_info_path: must be non-empty, otherwise an InvalidArgument
//   error is raised. It is stored in shape_range_info_path_ (presumably the
//   file the collected shape ranges are written to; confirm with the
//   predictor side).
//
// The argument is validated before any member is modified. A rejected call
// therefore leaves the config untouched. In particular it cannot end up with
// collection enabled while shape_range_info_path_ still holds an old value.
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(), false,
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
  shape_range_info_path_ = shape_range_info_path;
}

const std::string &AnalysisConfig::shape_range_info_path() {
  return shape_range_info_path_;
}

bool AnalysisConfig::shape_range_info_collected() {
  return collect_shape_range_info_;
}

// Enables TensorRT tuned dynamic shape.
//
// shape_range_info_path: stored in shape_range_info_path_ (presumably the
//   shape-range file produced by a CollectShapeRangeInfo run; confirm).
// allow_build_at_runtime: stored in trt_allow_build_at_runtime_.
//
// NOTE(review): unlike CollectShapeRangeInfo, the path is not checked for
// emptiness here.
void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  trt_tuned_dynamic_shape_ = true;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  shape_range_info_path_ = shape_range_info_path;
}

bool AnalysisConfig::tuned_tensorrt_dynamic_shape() {
  return trt_tuned_dynamic_shape_;
}

bool AnalysisConfig::trt_allow_build_at_runtime() {
  return trt_allow_build_at_runtime_;
}
1015
}  // namespace paddle