// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <iterator>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/table_printer.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"

#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif

30
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
31 32 33
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

34
namespace paddle {
W
wanghuancoder 已提交
35 36
struct MkldnnQuantizerConfig;

37
extern const std::vector<std::string> kTRTSubgraphPasses;
D
denglin-github 已提交
38
extern const std::vector<std::string> kDlnneSubgraphPasses;
石晓伟 已提交
39
extern const std::vector<std::string> kLiteSubgraphPasses;
40

41
// Returns the pass strategy for this config. If none exists yet, one matching
// the selected device is created lazily (same device priority as Update():
// GPU, XPU, NPU, IPU, custom device, then CPU). If an existing strategy
// disagrees with the config's use_gpu flag, a warning is logged and the
// existing strategy is kept.
PassStrategy *AnalysisConfig::pass_builder() const {
  if (!pass_builder_.get()) {
    if (use_gpu_) {
      LOG(INFO) << "Create GPU IR passes";
      pass_builder_.reset(new GpuPassStrategy);
    } else if (use_xpu_) {
      pass_builder_.reset(new XpuPassStrategy);
    } else if (use_npu_) {
      pass_builder_.reset(new NpuPassStrategy);
    } else if (use_ipu_) {
      LOG(INFO) << "Create IPU IR passes";
      pass_builder_.reset(new IpuPassStrategy);
    } else if (use_custom_device_) {
      // Update() builds a CustomDevicePassStrategy for custom devices; without
      // this branch the lazy path would silently fall back to CPU passes.
      LOG(INFO) << "Create CUSTOM DEVICE IR passes";
      pass_builder_.reset(new CustomDevicePassStrategy);
    } else {
      LOG(INFO) << "Create CPU IR passes";
      pass_builder_.reset(new CpuPassStrategy);
    }
  } else if (pass_builder_->use_gpu() ^ use_gpu()) {
    LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                    "PassBuilder, the flags are "
                 << use_gpu() << " " << pass_builder_->use_gpu();
    LOG(WARNING) << "Please make them compatible, still use the existing "
                    "PassBuilder.";
  }

  return pass_builder_.get();
}

68
AnalysisConfig::AnalysisConfig(const std::string &model_dir) {
69
  model_dir_ = model_dir;
Y
Yan Chunwei 已提交
70 71

  Update();
72
}
73 74
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file) {
75 76
  prog_file_ = prog_file;
  params_file_ = params_file;
Y
Yan Chunwei 已提交
77 78

  Update();
79
}
80 81
void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
82 83
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
84 85

  Update();
86
}
87

88 89
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                  int device_id) {
90
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
91 92
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
93
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
94
  gpu_device_id_ = device_id;
95
#else
Y
Yan Chunwei 已提交
96
  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
97 98
  use_gpu_ = false;
#endif
Y
Yan Chunwei 已提交
99 100 101

  Update();
}
102

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
void AnalysisConfig::SetExecStream(void *stream) {
  PADDLE_ENFORCE_NOT_NULL(stream, platform::errors::InvalidArgument(
                                      "`stream` should not be nullptr"));
  exec_stream_ = stream;
  use_external_stream_ = true;
  Update();
}

void *AnalysisConfig::GetExecStream() const {
  PADDLE_ENFORCE_NOT_NULL(exec_stream_, platform::errors::InvalidArgument(
                                            "`stream` should not be nullptr"));
  return exec_stream_;
}

bool AnalysisConfig::external_stream_enabled() const {
  return use_external_stream_;
}

121
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
122 123 124
  use_gpu_ = false;

  Update();
125 126
}

127 128 129 130 131 132 133 134 135 136 137 138 139
void AnalysisConfig::Exp_EnableUseGpuFp16(
    std::unordered_set<std::string> op_list) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  use_gpu_fp16_ = true;
  gpu_fp16_disabled_op_types_.insert(op_list.begin(), op_list.end());
#else
  LOG(ERROR) << "Please compile with gpu to Exp_EnableUseGpuFp16()";
  use_gpu_fp16_ = false;
#endif

  Update();
}

140 141 142 143 144 145
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;

  Update();
}

W
Wilber 已提交
146 147 148 149
void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
                               bool autotune, const std::string &autotune_file,
                               const std::string &precision,
                               bool adaptive_seqlen) {
150 151
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
152 153 154 155 156
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
157 158 159
  Update();
}

160 161 162 163 164 165 166 167
void AnalysisConfig::SetXpuDeviceId(int device_id) {
  PADDLE_ENFORCE_EQ(use_xpu_, true,
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

W
Wilber 已提交
168 169 170 171 172 173 174 175 176 177 178
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
  use_npu_ = true;
  npu_device_id_ = device_id;
#else
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
  use_npu_ = false;
#endif

  Update();
}
179

180 181 182 183 184 185 186 187 188 189 190 191 192
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                        int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  use_custom_device_ = true;
  custom_device_id_ = device_id;
  custom_device_type_ = device_type;
#else
  LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
  use_custom_device_ = false;
#endif
  Update();
}

193 194 195
void AnalysisConfig::EnableIpu(int ipu_device_num, int ipu_micro_batch_size,
                               bool ipu_enable_pipelining,
                               int ipu_batches_per_step) {
J
jianghaicheng 已提交
196 197 198
  enable_ir_optim_ = true;

  use_ipu_ = true;
199 200
  ipu_device_num_ = ipu_device_num;
  ipu_micro_batch_size_ = ipu_micro_batch_size;
J
jianghaicheng 已提交
201 202
  ipu_enable_pipelining_ = ipu_enable_pipelining;
  ipu_batches_per_step_ = ipu_batches_per_step;
203 204 205 206 207 208 209 210 211 212 213

  Update();
}

void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16, int ipu_replica_num,
                                  float ipu_available_memory_proportion,
                                  bool ipu_enable_half_partial) {
  ipu_enable_fp16_ = ipu_enable_fp16;
  ipu_replica_num_ = ipu_replica_num;
  ipu_available_memory_proportion_ = ipu_available_memory_proportion;
  ipu_enable_half_partial_ = ipu_enable_half_partial;
J
jianghaicheng 已提交
214 215 216

  Update();
}
W
Wilber 已提交
217

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
void AnalysisConfig::EnableONNXRuntime() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  use_onnxruntime_ = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableONNXRuntime()";
  use_onnxruntime_ = false;
#endif

  Update();
}

void AnalysisConfig::DisableONNXRuntime() {
  use_onnxruntime_ = false;
  Update();
}

void AnalysisConfig::EnableORTOptimization() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  enable_ort_optimization_ = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableORTOptimization()";
  enable_ort_optimization_ = false;
#endif

  Update();
}

245
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
246 247 248 249 250 251
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.
252

253
  CP_MEMBER(opt_cache_dir_);
W
Wilber 已提交
254 255
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);
256

257
  CP_MEMBER(use_fc_padding_);
258
  // GPU related.
259
  CP_MEMBER(use_gpu_);
260 261
  CP_MEMBER(use_external_stream_);
  CP_MEMBER(exec_stream_);
262
  CP_MEMBER(use_cudnn_);
263
  CP_MEMBER(gpu_device_id_);
264
  CP_MEMBER(memory_pool_init_size_mb_);
265 266
  CP_MEMBER(use_gpu_fp16_);
  CP_MEMBER(gpu_fp16_disabled_op_types_);
Y
Yan Chunwei 已提交
267 268

  CP_MEMBER(enable_memory_optim_);
S
Sylwester Fraczek 已提交
269
  // TensorRT related.
270 271 272 273
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
N
nhzlx 已提交
274
  CP_MEMBER(tensorrt_precision_mode_);
275
  CP_MEMBER(trt_disabled_ops_);
276 277
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
N
nhzlx 已提交
278
  CP_MEMBER(trt_use_static_engine_);
279
  CP_MEMBER(trt_use_calib_mode_);
280
  CP_MEMBER(trt_use_varseqlen_);
281
  CP_MEMBER(trt_with_interleaved_);
282 283
  CP_MEMBER(tensorrt_transformer_posid_);
  CP_MEMBER(tensorrt_transformer_maskid_);
284 285 286 287
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
288
  CP_MEMBER(trt_use_inspector_);
D
denglin-github 已提交
289 290 291
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
S
Sylwester Fraczek 已提交
292
  // MKLDNN related.
293 294
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
295
  CP_MEMBER(mkldnn_cache_capacity_);
296 297 298
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
299
  // Quantization related.
B
baoachun 已提交
300 301 302
  CP_MEMBER(use_mkldnn_int8_);
  CP_MEMBER(quantize_enabled_op_types_);
  CP_MEMBER(quantize_excluded_op_ids_);
303 304
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
305 306 307
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
308
  CP_MEMBER(disable_trt_plugin_fp16_);
309

石晓伟 已提交
310 311 312 313
  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
314 315
  CP_MEMBER(lite_zero_copy_);

W
Wilber 已提交
316
  // XPU related.
317
  CP_MEMBER(use_xpu_);
W
Wilber 已提交
318
  CP_MEMBER(xpu_device_id_);
319
  CP_MEMBER(xpu_l3_workspace_size_);
W
Wilber 已提交
320 321 322 323 324
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);
石晓伟 已提交
325

W
Wilber 已提交
326 327 328
  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
329
  CP_MEMBER(nnadapter_config_);
W
Wilber 已提交
330

331 332 333
  // profile related.
  CP_MEMBER(with_profile_);

334 335 336
  // glog related.
  CP_MEMBER(with_glog_info_);

337 338 339 340 341 342 343 344 345 346
  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

347 348
  CP_MEMBER(thread_local_stream_);

J
jianghaicheng 已提交
349 350 351
  // ipu related
  CP_MEMBER(use_ipu_);
  CP_MEMBER(ipu_device_num_);
352
  CP_MEMBER(ipu_micro_batch_size_);
J
jianghaicheng 已提交
353 354
  CP_MEMBER(ipu_enable_pipelining_);
  CP_MEMBER(ipu_batches_per_step_);
355 356 357 358
  CP_MEMBER(ipu_enable_fp16_);
  CP_MEMBER(ipu_replica_num_);
  CP_MEMBER(ipu_available_memory_proportion_);
  CP_MEMBER(ipu_enable_half_partial_);
J
jianghaicheng 已提交
359

360 361 362
  // fleet exe related
  CP_MEMBER(dist_config_);

363 364 365 366 367
  // custom device related.
  CP_MEMBER(use_custom_device_);
  CP_MEMBER(custom_device_type_);
  CP_MEMBER(custom_device_id_);

368
  if (use_gpu_) {
369 370 371
    PADDLE_ENFORCE_EQ(use_xpu_, false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between CPU and XPU."));
372 373
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
J
jianghaicheng 已提交
374 375 376
  } else if (use_ipu_) {
    pass_builder_.reset(new IpuPassStrategy(
        *static_cast<IpuPassStrategy *>(other.pass_builder())));
377 378 379
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
W
Wilber 已提交
380 381 382
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
383 384 385 386 387
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

388
#undef CP_MEMBER
Y
Yan Chunwei 已提交
389

W
Wilber 已提交
390 391 392 393 394
  Update();
  if (use_tensorrt_) {
    // Update() will reset all the passes, when some tensorRT pass is deleted in
    // other.pass_builder(), it will set again, so we just remove the
    // deleted_pass.
395
    pass_builder_->ClearPasses();
W
Wilber 已提交
396
    auto other_passes = other.pass_builder()->AllPasses();
397 398
    for (auto pass : other_passes) {
      pass_builder_->AppendPass(pass);
W
Wilber 已提交
399
    }
400
  }
D
denglin-github 已提交
401 402 403 404 405 406 407 408 409 410 411 412 413 414 415
  if (use_dlnne_) {
    auto all_passes = kDlnneSubgraphPasses;
    auto other_passes = other.pass_builder()->AllPasses();
    // We should sort them, because the user may call the SwitchIrDebug
    // interface, which will change the pass.
    std::sort(all_passes.begin(), all_passes.end());
    std::sort(other_passes.begin(), other_passes.end());
    std::vector<std::string> deleted_passes;
    std::set_difference(all_passes.begin(), all_passes.end(),
                        other_passes.begin(), other_passes.end(),
                        std::inserter(deleted_passes, deleted_passes.begin()));
    for (auto ps : deleted_passes) {
      pass_builder_->DeletePass(ps);
    }
  }
416 417
}

418
void AnalysisConfig::EnableCUDNN() {
419
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
420 421 422 423 424 425 426 427 428
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

429
void AnalysisConfig::EnableMKLDNN() {
430 431 432 433 434 435
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
436 437

  Update();
438 439
}

440 441 442 443 444 445 446 447 448
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

449 450 451 452 453 454 455 456 457 458 459 460 461
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (!mkldnn_quantizer_config_)
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif

  Update();
}

462 463
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
464 465
  if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) {
    use_mkldnn_bfloat16_ = true;
466 467 468 469
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (platform::MayIUse(platform::cpu_isa_t::avx512_bf16)
                      ? " is enabled"
                      : " is disabled. Simulation will be used");
470 471 472 473
  } else {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
    use_mkldnn_bfloat16_ = false;
  }
474 475 476 477 478 479 480 481
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif

  Update();
}

B
baoachun 已提交
482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510
void AnalysisConfig::EnableMkldnnInt8(
    const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_int8_ = true;
  use_fc_padding_ = false;
  if (!op_list.empty()) {
    for (auto &type : op_list) {
      if (!quantize_enabled_op_types_.count(type)) {
        LOG(ERROR) << "There are unsupported operators in the configured "
                      "quantization operator list. The unsupported operator "
                      "is: "
                   << type;
        use_mkldnn_int8_ = false;
        break;
      }
    }
    if (use_mkldnn_int8_) {
      quantize_enabled_op_types_.clear();
      quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
    }
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
  use_mkldnn_int8_ = false;
#endif

  Update();
}

511
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
512
  PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
513 514
                          platform::errors::PreconditionNotMet(
                              "MkldnnQuantizer was not enabled yet."));
515
  return mkldnn_quantizer_config_.get();
516 517
}

518
void AnalysisConfig::EnableTensorRtEngine(
N
nhzlx 已提交
519
    int workspace_size, int max_batch_size, int min_subgraph_size,
520
    AnalysisConfig::Precision precision_mode, bool use_static,
521
    bool use_calib_mode) {
522
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
523 524 525 526 527
  if (!use_gpu()) {
    LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
    return;
  }

528 529 530
  use_tensorrt_ = true;
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
531
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
532
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
533
  trt_use_static_engine_ = use_static;
534
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
535

536
  Update();
Y
Yan Chunwei 已提交
537 538 539 540
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
541 542
}

D
denglin-github 已提交
543 544 545 546 547 548
void AnalysisConfig::EnableDlnne(int min_subgraph_size) {
  use_dlnne_ = true;
  dlnne_min_subgraph_size_ = min_subgraph_size;
  Update();
}

549 550 551 552 553 554 555 556 557 558 559
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  min_input_shape_ = min_input_shape;
  max_input_shape_ = max_input_shape;
  optim_input_shape_ = optim_input_shape;
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

560 561 562 563 564
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_use_dla_ = true;
  trt_dla_core_ = dla_core;
}

565 566
void AnalysisConfig::EnableTensorRtInspector() { trt_use_inspector_ = true; }

567 568 569 570 571
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

572
void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }
573

Y
Yan Chunwei 已提交
574
// TODO(Superjomn) refactor this, buggy.
575
void AnalysisConfig::Update() {
576 577 578
  auto info = SerializeInfoCache();
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
579
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
580 581
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
J
jianghaicheng 已提交
582
      ((use_npu() ^ pass_builder_->use_npu())) ||
583 584
      ((use_ipu() ^ pass_builder_->use_ipu())) ||
      ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
Y
Yan Chunwei 已提交
585 586 587 588 589 590 591
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);

      if (use_tensorrt_) {
        // Append after the Affine_channel_conv_fuse pass.
        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
      }
J
jianghaicheng 已提交
592 593 594
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used for new.";
      pass_builder_.reset(new IpuPassStrategy);
595 596 597 598 599 600
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
601 602 603 604 605 606
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
607 608 609 610 611 612
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy);
Y
Yan Chunwei 已提交
613 614 615
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
616

617
  } else {
Y
Yan Chunwei 已提交
618 619 620
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
J
jianghaicheng 已提交
621 622 623 624
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used.";
      pass_builder_.reset(new IpuPassStrategy(
          *static_cast<IpuPassStrategy *>(pass_builder_.get())));
625 626 627 628 629 630 631
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
632 633 634 635 636 637 638
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
639 640 641 642 643 644 645
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy(
          *static_cast<CustomDevicePassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
646 647 648 649
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
650 651 652
  }

  if (use_tensorrt_) {
653 654
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
655
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
656
          (pass == "conv_bn_fuse_pass")) {
657 658
        continue;
      }
659
      pass_builder()->AppendPass(pass);
660 661
    }
  }
662

D
denglin-github 已提交
663 664 665 666 667 668 669
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

670
  if (use_gpu() && use_cudnn_) {
671
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
672 673 674 675 676 677 678 679
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

680 681 682 683 684 685 686 687 688 689 690 691 692 693
  if (use_gpu_fp16_) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (!enable_ir_optim_) {
      LOG(ERROR) << "Exp_EnableUseGpuFp16() only works when IR optimization is "
                    "enabled.";
    } else if (!use_gpu()) {
      LOG(ERROR)
          << "Exp_EnableUseGpuFp16() only works when use_gpu is enabled.";
    } else {
      pass_builder()->Exp_EnableUseGpuFp16();
    }
#endif
  }

694
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
695
#ifdef PADDLE_WITH_MKLDNN
696 697 698
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
699 700
    } else {
      pass_builder()->EnableMKLDNN();
701 702 703 704
    }
#endif
  }

705 706 707 708 709
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
710 711
    }
#ifdef PADDLE_WITH_MKLDNN
712
    pass_builder()->EnableMkldnnQuantizer();
713 714 715
#endif
  }

716 717 718 719 720 721
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

B
baoachun 已提交
722 723 724 725 726 727 728 729 730 731 732 733 734 735
  if (use_mkldnn_int8_) {
#ifdef PADDLE_WITH_MKLDNN
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
                    "is enabled.";
    } else if (!use_mkldnn_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
                    "is enabled.";
    } else {
      pass_builder()->EnableMkldnnInt8();
    }
#endif
  }

736
#ifdef PADDLE_WITH_MKLDNN
737 738
  // Do not optimize when mkldnn is on
  if (enable_memory_optim_ && !use_mkldnn_) {
739
#else
Y
Yan Chunwei 已提交
740
  if (enable_memory_optim_) {
741 742
#endif
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
Y
Yan Chunwei 已提交
743 744
  }

石晓伟 已提交
745 746 747 748 749 750 751 752 753 754 755 756 757 758
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
      if (std::find(lite_passes_filter_.begin(), lite_passes_filter_.end(),
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

759
  if (use_xpu_) {
760
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
761 762 763 764
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
765 766 767 768 769
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
770 771
  }

W
Wilber 已提交
772
  if (use_npu_) {
773
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
774 775 776 777 778 779 780 781 782 783
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }
J
jianghaicheng 已提交
784 785 786 787 788 789 790
  if (use_ipu_) {
#ifndef PADDLE_WITH_IPU
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the ipu "
        "but did not have the option -DWITH_IPU compiled."));
#endif
  }
791 792 793 794 795 796 797
  if (use_custom_device_) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the custom device "
        "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif
  }
798 799 800 801 802
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
}

803
std::string AnalysisConfig::SerializeInfoCache() {
804
  std::stringstream ss;
Y
Yan Chunwei 已提交
805 806 807 808
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

809
  ss << use_gpu_;
810 811
  ss << use_external_stream_;
  ss << exec_stream_;
812 813
  ss << use_gpu_fp16_;
  for (auto &item : gpu_fp16_disabled_op_types_) ss << item;
814
  ss << use_fc_padding_;
815 816
  ss << gpu_device_id_;
  ss << xpu_device_id_;
817 818 819 820 821
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
822 823
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
824 825 826
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

827 828 829
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

830 831 832
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
833
  ss << enable_memory_optim_;
834 835

  ss << use_mkldnn_;
836
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
837 838 839
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

840
  ss << use_mkldnn_quantizer_;
841
  ss << use_mkldnn_bfloat16_;
842
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
B
baoachun 已提交
843 844 845
  ss << use_mkldnn_int8_;
  for (auto &item : quantize_enabled_op_types_) ss << item;
  for (auto &item : quantize_excluded_op_ids_) ss << item;
846
  ss << ";";
Y
Yan Chunwei 已提交
847 848
  ss << model_from_memory_;

849 850
  ss << with_profile_;

851 852
  ss << with_glog_info_;

853 854 855 856
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
857 858
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
859 860

  ss << use_lite_;
861 862
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
863 864 865 866 867
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
868

W
Wilber 已提交
869 870 871
  ss << use_npu_;
  ss << npu_device_id_;

872 873
  ss << thread_local_stream_;

J
jianghaicheng 已提交
874 875
  ss << use_ipu_;
  ss << ipu_device_num_;
876
  ss << ipu_micro_batch_size_;
J
jianghaicheng 已提交
877 878
  ss << ipu_enable_pipelining_;
  ss << ipu_batches_per_step_;
879 880 881 882
  ss << ipu_enable_fp16_;
  ss << ipu_replica_num_;
  ss << ipu_available_memory_proportion_;
  ss << ipu_enable_half_partial_;
J
jianghaicheng 已提交
883

884 885 886
  return ss.str();
}

887
void AnalysisConfig::SetCpuMathLibraryNumThreads(
888 889
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
890 891

  Update();
892 893
}

894
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Express the requested memory-pool size (in MB) as a fraction of the
  // selected device's total memory. The device is selected first so the
  // usage query reports on gpu_device_id_.
  size_t gpu_available = 0;
  size_t gpu_total = 0;
  platform::SetDeviceId(gpu_device_id_);
  platform::GpuMemoryUsage(&gpu_available, &gpu_total);

  // Bytes -> MB.
  const double total_gpu_memory = gpu_total / 1024. / 1024.;
  const float fraction = static_cast<float>(
      static_cast<double>(memory_pool_init_size_mb()) / total_gpu_memory);

  VLOG(3) << "total_gpu_memory is " << total_gpu_memory
          << "M, gpu_available is " << gpu_available / 1024. / 1024.
          << "M, memory_pool_init_size is " << memory_pool_init_size_mb()
          << "M.";
  return fraction;
#else
  // Without a GPU build there is no device pool to size.
  return 0.;
#endif
}

914 915
void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
916 917 918
  Update();
}

919
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
920 921 922
  return enable_memory_optim_;
}

923 924 925 926
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
927 928
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
929
  model_from_memory_ = true;
T
Tao Luo 已提交
930 931
}

932
NativeConfig AnalysisConfig::ToNativeConfig() const {
  // Project the subset of options that NativeConfig understands.
  NativeConfig native;
  // Model location.
  native.model_dir = model_dir_;
  native.prog_file = prog_file_;
  native.param_file = params_file_;
  // Device selection and memory budget.
  native.use_gpu = use_gpu_;
  native.device = gpu_device_id_;
  native.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
  // Input handling.
  native.specify_input_name = specify_input_name_;
  return native;
}

Y
Yan Chunwei 已提交
944 945 946 947
void AnalysisConfig::SwitchIrDebug(int x) {
  // Store the IR-debug switch, then call Update().
  ir_debug_ = x;
  Update();
}
948 949 950 951 952 953

void AnalysisConfig::EnableProfile() {
  // Turn profiling on (there is no matching disable here), then Update().
  with_profile_ = true;
  Update();
}

954 955 956 957 958
void AnalysisConfig::DisableGlogInfo() {
  // Clear the glog-info flag, then Update().
  with_glog_info_ = false;
  Update();
}

石晓伟 已提交
959
void AnalysisConfig::EnableLiteEngine(
    AnalysisConfig::Precision precision_mode, bool zero_copy,
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  // Switch on the Lite subgraph engine and record its options.
  use_lite_ = true;
  lite_precision_mode_ = precision_mode;
  lite_zero_copy_ = zero_copy;
  // Both filters are copied into the config.
  lite_passes_filter_ = passes_filter;
  lite_ops_filter_ = ops_filter;
  Update();
}

971 972 973 974 975 976 977
void AnalysisConfig::PartiallyRelease() {
  prog_file_.clear();
  prog_file_.shrink_to_fit();
  params_file_.clear();
  params_file_.shrink_to_fit();
}

978 979
void AnalysisConfig::EnableGpuMultiStream() {
  // Sets thread_local_stream_; unlike most setters here, no Update() call.
  // NOTE(review): presumably gives each thread its own GPU stream — confirm.
  thread_local_stream_ = true;
}

980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
999
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
1000 1001 1002 1003 1004 1005 1006 1007 1008 1009
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
1010 1011
    os.InsertRow(
        {"use_external_stream", use_external_stream_ ? "true" : "false"});
1012 1013 1014 1015 1016
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043
#ifdef PADDLE_WITH_TENSORRT
      auto Precision2String =
          [](paddle::AnalysisConfig::Precision prec) -> std::string {
        if (prec == Precision::kFloat32)
          return "fp32";
        else if (prec == Precision::kHalf)
          return "fp16";
        else if (prec == Precision::kInt8)
          return "int8";
        else
          return "None";
      };
      auto version2string =
          [](const std::tuple<int, int, int> &ver) -> std::string {
        std::ostringstream os;
        int major = std::get<0>(ver);
        int minor = std::get<1>(ver);
        int patch = std::get<2>(ver);
        os << major << "." << minor << "." << patch;
        return os.str();
      };
      os.InsertRow(
          {"trt_compile_version",
           version2string(inference::tensorrt::GetTrtCompileVersion())});
      os.InsertRow(
          {"trt_runtime_version",
           version2string(inference::tensorrt::GetTrtRuntimeVersion())});
1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
1060 1061 1062
      os.InsertRow({"tensorrt_tuned_dynamic_shape", trt_tuned_dynamic_shape_
                                                        ? shape_range_info_path_
                                                        : "false"});
1063

1064 1065
      os.InsertRow(
          {"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
1066 1067
      os.InsertRow({"tensorrt_with_interleaved",
                    trt_with_interleaved_ ? "true" : "false"});
1068 1069 1070
      os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
      os.InsertRow(
          {"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
1071 1072 1073 1074
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
1075
#endif
1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
1099 1100
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
1101 1102 1103 1104

  return os.PrintTable();
}

1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132
// Fluent setters: each stores its argument and returns *this for chaining.

LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  LiteNNAdapterConfig &self = *this;
  self.nnadapter_device_names = names;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  LiteNNAdapterConfig &self = *this;
  self.nnadapter_context_properties = properties;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  LiteNNAdapterConfig &self = *this;
  self.nnadapter_model_cache_dir = dir;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  // Register one cached model buffer under a unique, non-empty token.
  // Checks run in this order: empty token, empty buffer, duplicate token.
  PADDLE_ENFORCE_EQ(model_cache_token.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  PADDLE_ENFORCE_EQ(nnadapter_model_cache_buffers.count(model_cache_token),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_token has already been set."));

  // The token is known to be absent here, so emplace always inserts.
  nnadapter_model_cache_buffers.emplace(model_cache_token, model_cache_buffer);
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  // Store the partition-config file path; returns *this for chaining.
  LiteNNAdapterConfig &self = *this;
  self.nnadapter_subgraph_partition_config_path = path;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  // Store the partition-config contents directly; returns *this.
  LiteNNAdapterConfig &self = *this;
  self.nnadapter_subgraph_partition_config_buffer = buffer;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  // Set use_nnadapter; returns *this for chaining.
  LiteNNAdapterConfig &self = *this;
  self.use_nnadapter = true;
  return self;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  // Clear use_nnadapter; returns *this for chaining.
  LiteNNAdapterConfig &self = *this;
  self.use_nnadapter = false;
  return self;
}

1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196
// Switch the config into shape-collection mode, recording range info to
// `shape_range_info_path`.
//
// Throws (via PADDLE_ENFORCE_EQ / InvalidArgument) if the path is empty.
// Validation happens before any member is modified, so a rejected call
// leaves the config unchanged. Previously collect_shape_range_info_ was set
// to true first and stayed set even when the path check failed.
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(), false,
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
  shape_range_info_path_ = shape_range_info_path;
}

const std::string &AnalysisConfig::shape_range_info_path() {
  // Path set by CollectShapeRangeInfo() or EnableTunedTensorRtDynamicShape().
  const std::string &path = shape_range_info_path_;
  return path;
}

bool AnalysisConfig::shape_range_info_collected() {
  // True once CollectShapeRangeInfo() has been called.
  const bool collecting = collect_shape_range_info_;
  return collecting;
}

void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  // Turn on tuned dynamic shapes and record where the range info lives and
  // whether runtime engine rebuilds are allowed.
  // NOTE(review): unlike CollectShapeRangeInfo, the path is not validated.
  trt_tuned_dynamic_shape_ = true;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  shape_range_info_path_ = shape_range_info_path;
}

bool AnalysisConfig::tuned_tensorrt_dynamic_shape() {
  // True once EnableTunedTensorRtDynamicShape() has been called.
  const bool tuned = trt_tuned_dynamic_shape_;
  return tuned;
}

bool AnalysisConfig::trt_allow_build_at_runtime() {
  // Flag recorded by EnableTunedTensorRtDynamicShape().
  const bool allowed = trt_allow_build_at_runtime_;
  return allowed;
}
1197
}  // namespace paddle