analysis_config.cc 35.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <sstream>
16
#include <string>
17
#include <tuple>
18 19
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
20
#include "paddle/fluid/inference/utils/table_printer.h"
21
#include "paddle/fluid/platform/cpu_info.h"
22
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
23 24
#include "paddle/fluid/platform/enforce.h"

25 26 27 28
#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif

29
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
30 31 32
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

33
namespace paddle {
W
wanghuancoder 已提交
34 35
struct MkldnnQuantizerConfig;

36
extern const std::vector<std::string> kTRTSubgraphPasses;
D
denglin-github 已提交
37
extern const std::vector<std::string> kDlnneSubgraphPasses;
石晓伟 已提交
38
extern const std::vector<std::string> kLiteSubgraphPasses;
39

40
PassStrategy *AnalysisConfig::pass_builder() const {
41 42 43 44
  if (!pass_builder_.get()) {
    if (use_gpu_) {
      LOG(INFO) << "Create GPU IR passes";
      pass_builder_.reset(new GpuPassStrategy);
45 46
    } else if (use_xpu_) {
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
47 48
    } else if (use_npu_) {
      pass_builder_.reset(new NpuPassStrategy);
J
jianghaicheng 已提交
49 50 51
    } else if (use_ipu_) {
      LOG(INFO) << "Create IPU IR passes";
      pass_builder_.reset(new IpuPassStrategy);
52 53 54 55 56 57 58 59 60 61 62 63
    } else {
      LOG(INFO) << "Create CPU IR passes";
      pass_builder_.reset(new CpuPassStrategy);
    }
  } else if (pass_builder_->use_gpu() ^ use_gpu()) {
    LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                    "PassBuilder, the flags are "
                 << use_gpu() << " " << pass_builder_->use_gpu();
    LOG(WARNING) << "Please make them compatible, still use the existing "
                    "PassBuilder.";
  }

64 65 66
  return pass_builder_.get();
}

67
AnalysisConfig::AnalysisConfig(const std::string &model_dir) {
68
  model_dir_ = model_dir;
Y
Yan Chunwei 已提交
69 70

  Update();
71
}
72 73
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file) {
74 75
  prog_file_ = prog_file;
  params_file_ = params_file;
Y
Yan Chunwei 已提交
76 77

  Update();
78
}
79 80
void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
81 82
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
83 84

  Update();
85
}
86

87 88
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                  int device_id) {
89
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
90 91
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
92
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
93
  gpu_device_id_ = device_id;
94
#else
Y
Yan Chunwei 已提交
95
  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
96 97
  use_gpu_ = false;
#endif
Y
Yan Chunwei 已提交
98 99 100

  Update();
}
101

102
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
103 104 105
  use_gpu_ = false;

  Update();
106 107
}

108 109 110 111 112 113 114 115 116 117 118 119 120
void AnalysisConfig::Exp_EnableUseGpuFp16(
    std::unordered_set<std::string> op_list) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  use_gpu_fp16_ = true;
  gpu_fp16_disabled_op_types_.insert(op_list.begin(), op_list.end());
#else
  LOG(ERROR) << "Please compile with gpu to Exp_EnableUseGpuFp16()";
  use_gpu_fp16_ = false;
#endif

  Update();
}

121 122 123 124 125 126
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;

  Update();
}

W
Wilber 已提交
127 128 129 130
void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
                               bool autotune, const std::string &autotune_file,
                               const std::string &precision,
                               bool adaptive_seqlen) {
131 132
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
133 134 135 136 137
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
138 139 140
  Update();
}

141 142 143 144 145 146 147 148
void AnalysisConfig::SetXpuDeviceId(int device_id) {
  PADDLE_ENFORCE_EQ(use_xpu_, true,
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

W
Wilber 已提交
149 150 151 152 153 154 155 156 157 158 159
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
  use_npu_ = true;
  npu_device_id_ = device_id;
#else
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
  use_npu_ = false;
#endif

  Update();
}
160

161 162 163 164 165 166 167 168 169 170 171 172 173
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                        int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  use_custom_device_ = true;
  custom_device_id_ = device_id;
  custom_device_type_ = device_type;
#else
  LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
  use_custom_device_ = false;
#endif
  Update();
}

174 175 176
void AnalysisConfig::EnableIpu(int ipu_device_num, int ipu_micro_batch_size,
                               bool ipu_enable_pipelining,
                               int ipu_batches_per_step) {
J
jianghaicheng 已提交
177 178 179
  enable_ir_optim_ = true;

  use_ipu_ = true;
180 181
  ipu_device_num_ = ipu_device_num;
  ipu_micro_batch_size_ = ipu_micro_batch_size;
J
jianghaicheng 已提交
182 183
  ipu_enable_pipelining_ = ipu_enable_pipelining;
  ipu_batches_per_step_ = ipu_batches_per_step;
184 185 186 187 188 189 190 191 192 193 194

  Update();
}

void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16, int ipu_replica_num,
                                  float ipu_available_memory_proportion,
                                  bool ipu_enable_half_partial) {
  ipu_enable_fp16_ = ipu_enable_fp16;
  ipu_replica_num_ = ipu_replica_num;
  ipu_available_memory_proportion_ = ipu_available_memory_proportion;
  ipu_enable_half_partial_ = ipu_enable_half_partial;
J
jianghaicheng 已提交
195 196 197

  Update();
}
W
Wilber 已提交
198

199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
void AnalysisConfig::EnableONNXRuntime() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  use_onnxruntime_ = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableONNXRuntime()";
  use_onnxruntime_ = false;
#endif

  Update();
}

void AnalysisConfig::DisableONNXRuntime() {
  use_onnxruntime_ = false;
  Update();
}

void AnalysisConfig::EnableORTOptimization() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  enable_ort_optimization_ = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableORTOptimization()";
  enable_ort_optimization_ = false;
#endif

  Update();
}

226
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
227 228 229 230 231 232
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.
233

234
  CP_MEMBER(opt_cache_dir_);
W
Wilber 已提交
235 236
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);
237

238
  CP_MEMBER(use_fc_padding_);
239
  // GPU related.
240
  CP_MEMBER(use_gpu_);
241
  CP_MEMBER(use_cudnn_);
242
  CP_MEMBER(gpu_device_id_);
243
  CP_MEMBER(memory_pool_init_size_mb_);
244 245
  CP_MEMBER(use_gpu_fp16_);
  CP_MEMBER(gpu_fp16_disabled_op_types_);
Y
Yan Chunwei 已提交
246 247

  CP_MEMBER(enable_memory_optim_);
S
Sylwester Fraczek 已提交
248
  // TensorRT related.
249 250 251 252
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
N
nhzlx 已提交
253
  CP_MEMBER(tensorrt_precision_mode_);
254
  CP_MEMBER(trt_disabled_ops_);
255 256
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
N
nhzlx 已提交
257
  CP_MEMBER(trt_use_static_engine_);
258
  CP_MEMBER(trt_use_calib_mode_);
259
  CP_MEMBER(trt_use_oss_);
260
  CP_MEMBER(trt_with_interleaved_);
261 262 263 264
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
265
  CP_MEMBER(trt_use_inspector_);
D
denglin-github 已提交
266 267 268
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
S
Sylwester Fraczek 已提交
269
  // MKLDNN related.
270 271
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
272
  CP_MEMBER(mkldnn_cache_capacity_);
273 274 275
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
276
  // Quantization related.
B
baoachun 已提交
277 278 279
  CP_MEMBER(use_mkldnn_int8_);
  CP_MEMBER(quantize_enabled_op_types_);
  CP_MEMBER(quantize_excluded_op_ids_);
280 281
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
282 283 284
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
285
  CP_MEMBER(disable_trt_plugin_fp16_);
286

石晓伟 已提交
287 288 289 290
  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
291 292
  CP_MEMBER(lite_zero_copy_);

W
Wilber 已提交
293
  // XPU related.
294
  CP_MEMBER(use_xpu_);
W
Wilber 已提交
295
  CP_MEMBER(xpu_device_id_);
296
  CP_MEMBER(xpu_l3_workspace_size_);
W
Wilber 已提交
297 298 299 300 301
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);
石晓伟 已提交
302

W
Wilber 已提交
303 304 305
  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
306
  CP_MEMBER(nnadapter_config_);
W
Wilber 已提交
307

308 309 310
  // profile related.
  CP_MEMBER(with_profile_);

311 312 313
  // glog related.
  CP_MEMBER(with_glog_info_);

314 315 316 317 318 319 320 321 322 323
  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

324 325
  CP_MEMBER(thread_local_stream_);

J
jianghaicheng 已提交
326 327 328
  // ipu related
  CP_MEMBER(use_ipu_);
  CP_MEMBER(ipu_device_num_);
329
  CP_MEMBER(ipu_micro_batch_size_);
J
jianghaicheng 已提交
330 331
  CP_MEMBER(ipu_enable_pipelining_);
  CP_MEMBER(ipu_batches_per_step_);
332 333 334 335
  CP_MEMBER(ipu_enable_fp16_);
  CP_MEMBER(ipu_replica_num_);
  CP_MEMBER(ipu_available_memory_proportion_);
  CP_MEMBER(ipu_enable_half_partial_);
J
jianghaicheng 已提交
336

337 338 339
  // fleet exe related
  CP_MEMBER(dist_config_);

340 341 342 343 344
  // custom device related.
  CP_MEMBER(use_custom_device_);
  CP_MEMBER(custom_device_type_);
  CP_MEMBER(custom_device_id_);

345
  if (use_gpu_) {
346 347 348
    PADDLE_ENFORCE_EQ(use_xpu_, false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between CPU and XPU."));
349 350
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
J
jianghaicheng 已提交
351 352 353
  } else if (use_ipu_) {
    pass_builder_.reset(new IpuPassStrategy(
        *static_cast<IpuPassStrategy *>(other.pass_builder())));
354 355 356
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
W
Wilber 已提交
357 358 359
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
360 361 362 363 364
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

365
#undef CP_MEMBER
Y
Yan Chunwei 已提交
366

W
Wilber 已提交
367 368 369 370 371
  Update();
  if (use_tensorrt_) {
    // Update() will reset all the passes, when some tensorRT pass is deleted in
    // other.pass_builder(), it will set again, so we just remove the
    // deleted_pass.
372
    pass_builder_->ClearPasses();
W
Wilber 已提交
373
    auto other_passes = other.pass_builder()->AllPasses();
374 375
    for (auto pass : other_passes) {
      pass_builder_->AppendPass(pass);
W
Wilber 已提交
376
    }
377
  }
D
denglin-github 已提交
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
  if (use_dlnne_) {
    auto all_passes = kDlnneSubgraphPasses;
    auto other_passes = other.pass_builder()->AllPasses();
    // We should sort them, because the user may call the SwitchIrDebug
    // interface, which will change the pass.
    std::sort(all_passes.begin(), all_passes.end());
    std::sort(other_passes.begin(), other_passes.end());
    std::vector<std::string> deleted_passes;
    std::set_difference(all_passes.begin(), all_passes.end(),
                        other_passes.begin(), other_passes.end(),
                        std::inserter(deleted_passes, deleted_passes.begin()));
    for (auto ps : deleted_passes) {
      pass_builder_->DeletePass(ps);
    }
  }
393 394
}

395
void AnalysisConfig::EnableCUDNN() {
396
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
397 398 399 400 401 402 403 404 405
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

406
void AnalysisConfig::EnableMKLDNN() {
407 408 409 410 411 412
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
413 414

  Update();
415 416
}

417 418 419 420 421 422 423 424 425
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

426 427 428 429 430 431 432 433 434 435 436 437 438
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (!mkldnn_quantizer_config_)
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif

  Update();
}

439 440
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
441 442
  if (platform::MayIUse(platform::cpu_isa_t::avx512_core)) {
    use_mkldnn_bfloat16_ = true;
443 444 445 446
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (platform::MayIUse(platform::cpu_isa_t::avx512_bf16)
                      ? " is enabled"
                      : " is disabled. Simulation will be used");
447 448 449 450
  } else {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
    use_mkldnn_bfloat16_ = false;
  }
451 452 453 454 455 456 457 458
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif

  Update();
}

B
baoachun 已提交
459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487
void AnalysisConfig::EnableMkldnnInt8(
    const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_int8_ = true;
  use_fc_padding_ = false;
  if (!op_list.empty()) {
    for (auto &type : op_list) {
      if (!quantize_enabled_op_types_.count(type)) {
        LOG(ERROR) << "There are unsupported operators in the configured "
                      "quantization operator list. The unsupported operator "
                      "is: "
                   << type;
        use_mkldnn_int8_ = false;
        break;
      }
    }
    if (use_mkldnn_int8_) {
      quantize_enabled_op_types_.clear();
      quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
    }
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
  use_mkldnn_int8_ = false;
#endif

  Update();
}

488
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
489
  PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
490 491
                          platform::errors::PreconditionNotMet(
                              "MkldnnQuantizer was not enabled yet."));
492
  return mkldnn_quantizer_config_.get();
493 494
}

495
void AnalysisConfig::EnableTensorRtEngine(
N
nhzlx 已提交
496
    int workspace_size, int max_batch_size, int min_subgraph_size,
497
    AnalysisConfig::Precision precision_mode, bool use_static,
498
    bool use_calib_mode) {
499
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
500 501 502 503 504
  if (!use_gpu()) {
    LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
    return;
  }

505 506 507
  use_tensorrt_ = true;
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
508
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
509
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
510
  trt_use_static_engine_ = use_static;
511
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
512

513
  Update();
Y
Yan Chunwei 已提交
514 515 516 517
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
518 519
}

D
denglin-github 已提交
520 521 522 523 524 525
void AnalysisConfig::EnableDlnne(int min_subgraph_size) {
  use_dlnne_ = true;
  dlnne_min_subgraph_size_ = min_subgraph_size;
  Update();
}

526 527 528 529 530 531 532 533 534 535 536
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  min_input_shape_ = min_input_shape;
  max_input_shape_ = max_input_shape;
  optim_input_shape_ = optim_input_shape;
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

537 538 539 540 541
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_use_dla_ = true;
  trt_dla_core_ = dla_core;
}

542 543
void AnalysisConfig::EnableTensorRtInspector() { trt_use_inspector_ = true; }

544 545 546 547 548
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

549
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
550

Y
Yan Chunwei 已提交
551
// TODO(Superjomn) refactor this, buggy.
552
void AnalysisConfig::Update() {
553 554 555
  auto info = SerializeInfoCache();
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
556
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
557 558
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
J
jianghaicheng 已提交
559
      ((use_npu() ^ pass_builder_->use_npu())) ||
560 561
      ((use_ipu() ^ pass_builder_->use_ipu())) ||
      ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
Y
Yan Chunwei 已提交
562 563 564 565 566 567 568
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);

      if (use_tensorrt_) {
        // Append after the Affine_channel_conv_fuse pass.
        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
      }
J
jianghaicheng 已提交
569 570 571
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used for new.";
      pass_builder_.reset(new IpuPassStrategy);
572 573 574 575 576 577
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
578 579 580 581 582 583
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
584 585 586 587 588 589
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy);
Y
Yan Chunwei 已提交
590 591 592
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
593

594
  } else {
Y
Yan Chunwei 已提交
595 596 597
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
J
jianghaicheng 已提交
598 599 600 601
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used.";
      pass_builder_.reset(new IpuPassStrategy(
          *static_cast<IpuPassStrategy *>(pass_builder_.get())));
602 603 604 605 606 607 608
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
609 610 611 612 613 614 615
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
616 617 618 619 620 621 622
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy(
          *static_cast<CustomDevicePassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
623 624 625 626
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
627 628 629
  }

  if (use_tensorrt_) {
630 631
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
632
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
633
          (pass == "conv_bn_fuse_pass")) {
634 635
        continue;
      }
636
      pass_builder()->AppendPass(pass);
637 638
    }
  }
639

D
denglin-github 已提交
640 641 642 643 644 645 646
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

647
  if (use_gpu() && use_cudnn_) {
648
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
649 650 651 652 653 654 655 656
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

657 658 659 660 661 662 663 664 665 666 667 668 669 670
  if (use_gpu_fp16_) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (!enable_ir_optim_) {
      LOG(ERROR) << "Exp_EnableUseGpuFp16() only works when IR optimization is "
                    "enabled.";
    } else if (!use_gpu()) {
      LOG(ERROR)
          << "Exp_EnableUseGpuFp16() only works when use_gpu is enabled.";
    } else {
      pass_builder()->Exp_EnableUseGpuFp16();
    }
#endif
  }

671
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
672
#ifdef PADDLE_WITH_MKLDNN
673 674 675
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
676 677
    } else {
      pass_builder()->EnableMKLDNN();
678 679 680 681
    }
#endif
  }

682 683 684 685 686
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
687 688
    }
#ifdef PADDLE_WITH_MKLDNN
689
    pass_builder()->EnableMkldnnQuantizer();
690 691 692
#endif
  }

693 694 695 696 697 698
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

B
baoachun 已提交
699 700 701 702 703 704 705 706 707 708 709 710 711 712
  if (use_mkldnn_int8_) {
#ifdef PADDLE_WITH_MKLDNN
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
                    "is enabled.";
    } else if (!use_mkldnn_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
                    "is enabled.";
    } else {
      pass_builder()->EnableMkldnnInt8();
    }
#endif
  }

713
#ifdef PADDLE_WITH_MKLDNN
714 715
  // Do not optimize when mkldnn is on
  if (enable_memory_optim_ && !use_mkldnn_) {
716
#else
Y
Yan Chunwei 已提交
717
  if (enable_memory_optim_) {
718 719
#endif
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
Y
Yan Chunwei 已提交
720 721
  }

石晓伟 已提交
722 723 724 725 726 727 728 729 730 731 732 733 734 735
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
      if (std::find(lite_passes_filter_.begin(), lite_passes_filter_.end(),
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

736
  if (use_xpu_) {
737
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
738 739 740 741
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
742 743 744 745 746
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
747 748
  }

W
Wilber 已提交
749
  if (use_npu_) {
750
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
751 752 753 754 755 756 757 758 759 760
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }
J
jianghaicheng 已提交
761 762 763 764 765 766 767
  if (use_ipu_) {
#ifndef PADDLE_WITH_IPU
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the ipu "
        "but did not have the option -DWITH_IPU compiled."));
#endif
  }
768 769 770 771 772 773 774
  if (use_custom_device_) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the custom device "
        "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif
  }
775 776 777 778 779
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
}

780
std::string AnalysisConfig::SerializeInfoCache() {
781
  std::stringstream ss;
Y
Yan Chunwei 已提交
782 783 784 785
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

786
  ss << use_gpu_;
787 788
  ss << use_gpu_fp16_;
  for (auto &item : gpu_fp16_disabled_op_types_) ss << item;
789
  ss << use_fc_padding_;
790 791
  ss << gpu_device_id_;
  ss << xpu_device_id_;
792 793 794 795 796
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
797 798
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
799 800 801
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

802 803 804
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

805 806 807
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
808
  ss << enable_memory_optim_;
809 810

  ss << use_mkldnn_;
811
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
812 813 814
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

815
  ss << use_mkldnn_quantizer_;
816
  ss << use_mkldnn_bfloat16_;
817
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
B
baoachun 已提交
818 819 820
  ss << use_mkldnn_int8_;
  for (auto &item : quantize_enabled_op_types_) ss << item;
  for (auto &item : quantize_excluded_op_ids_) ss << item;
821
  ss << ";";
Y
Yan Chunwei 已提交
822 823
  ss << model_from_memory_;

824 825
  ss << with_profile_;

826 827
  ss << with_glog_info_;

828 829 830 831
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
832 833
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
834 835

  ss << use_lite_;
836 837
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
838 839 840 841 842
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
843

W
Wilber 已提交
844 845 846
  ss << use_npu_;
  ss << npu_device_id_;

847 848
  ss << thread_local_stream_;

J
jianghaicheng 已提交
849 850
  ss << use_ipu_;
  ss << ipu_device_num_;
851
  ss << ipu_micro_batch_size_;
J
jianghaicheng 已提交
852 853
  ss << ipu_enable_pipelining_;
  ss << ipu_batches_per_step_;
854 855 856 857
  ss << ipu_enable_fp16_;
  ss << ipu_replica_num_;
  ss << ipu_available_memory_proportion_;
  ss << ipu_enable_half_partial_;
J
jianghaicheng 已提交
858

859 860 861
  return ss.str();
}

862
void AnalysisConfig::SetCpuMathLibraryNumThreads(
863 864
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
865 866

  Update();
867 868
}

869
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
870
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
871 872
  // Get the GPU memory details and calculate the fraction of memory for the
  // GPU memory pool.
873
  size_t gpu_total, gpu_available;
874
  platform::SetDeviceId(gpu_device_id_);
875 876
  platform::GpuMemoryUsage(&gpu_available, &gpu_total);
  double total_gpu_memory = gpu_total / 1024. / 1024.;
877 878
  float fraction_of_gpu_memory =
      static_cast<double>(memory_pool_init_size_mb()) / total_gpu_memory;
879 880 881 882
  VLOG(3) << "total_gpu_memory is " << total_gpu_memory
          << "M, gpu_available is " << gpu_available / 1024. / 1024.
          << "M, memory_pool_init_size is " << memory_pool_init_size_mb()
          << "M.";
883 884 885 886
  return fraction_of_gpu_memory;
#else
  return 0.;
#endif
887 888
}

889 890
void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
891 892 893
  Update();
}

894
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
895 896 897
  return enable_memory_optim_;
}

898 899 900 901
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
902 903
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
904
  model_from_memory_ = true;
T
Tao Luo 已提交
905 906
}

907
NativeConfig AnalysisConfig::ToNativeConfig() const {
Y
Yan Chunwei 已提交
908 909 910 911 912
  NativeConfig config;
  config.model_dir = model_dir_;
  config.prog_file = prog_file_;
  config.param_file = params_file_;
  config.use_gpu = use_gpu_;
913
  config.device = gpu_device_id_;
Y
Yan Chunwei 已提交
914 915 916 917 918
  config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
  config.specify_input_name = specify_input_name_;
  return config;
}

Y
Yan Chunwei 已提交
919 920 921 922
void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}
923 924 925 926 927 928

void AnalysisConfig::EnableProfile() {
  with_profile_ = true;
  Update();
}

929 930 931 932 933
void AnalysisConfig::DisableGlogInfo() {
  with_glog_info_ = false;
  Update();
}

石晓伟 已提交
934
void AnalysisConfig::EnableLiteEngine(
935
    AnalysisConfig::Precision precision_mode, bool zero_copy,
石晓伟 已提交
936 937 938 939 940 941
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  use_lite_ = true;
  lite_precision_mode_ = precision_mode;
  lite_passes_filter_ = passes_filter;
  lite_ops_filter_ = ops_filter;
942
  lite_zero_copy_ = zero_copy;
石晓伟 已提交
943 944 945
  Update();
}

946 947 948 949 950 951 952
void AnalysisConfig::PartiallyRelease() {
  prog_file_.clear();
  prog_file_.shrink_to_fit();
  params_file_.clear();
  params_file_.shrink_to_fit();
}

953 954
void AnalysisConfig::EnableGpuMultiStream() { thread_local_stream_ = true; }

955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
974
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
975 976 977 978 979 980 981 982 983 984 985 986 987 988 989
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016
#ifdef PADDLE_WITH_TENSORRT
      auto Precision2String =
          [](paddle::AnalysisConfig::Precision prec) -> std::string {
        if (prec == Precision::kFloat32)
          return "fp32";
        else if (prec == Precision::kHalf)
          return "fp16";
        else if (prec == Precision::kInt8)
          return "int8";
        else
          return "None";
      };
      auto version2string =
          [](const std::tuple<int, int, int> &ver) -> std::string {
        std::ostringstream os;
        int major = std::get<0>(ver);
        int minor = std::get<1>(ver);
        int patch = std::get<2>(ver);
        os << major << "." << minor << "." << patch;
        return os.str();
      };
      os.InsertRow(
          {"trt_compile_version",
           version2string(inference::tensorrt::GetTrtCompileVersion())});
      os.InsertRow(
          {"trt_runtime_version",
           version2string(inference::tensorrt::GetTrtRuntimeVersion())});
1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
1033 1034 1035
      os.InsertRow({"tensorrt_tuned_dynamic_shape", trt_tuned_dynamic_shape_
                                                        ? shape_range_info_path_
                                                        : "false"});
1036 1037

      os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"});
1038 1039
      os.InsertRow({"tensorrt_with_interleaved",
                    trt_with_interleaved_ ? "true" : "false"});
1040 1041 1042 1043
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
1044
#endif
1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
1068 1069
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
1070 1071 1072 1073

  return os.PrintTable();
}

1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128
LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  nnadapter_device_names = names;
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  nnadapter_context_properties = properties;
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  nnadapter_model_cache_dir = dir;
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  PADDLE_ENFORCE_EQ(model_cache_token.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  PADDLE_ENFORCE_EQ(nnadapter_model_cache_buffers.count(model_cache_token),
                    false, platform::errors::InvalidArgument(
                               "model_cache_token has already been set."));

  nnadapter_model_cache_buffers[model_cache_token] = model_cache_buffer;
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  nnadapter_subgraph_partition_config_path = path;
  return *this;
}

LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  nnadapter_subgraph_partition_config_buffer = buffer;
  return *this;
}
LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  use_nnadapter = true;
  return *this;
}
LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  use_nnadapter = false;
  return *this;
}

1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(), false,
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  shape_range_info_path_ = shape_range_info_path;
}

const std::string &AnalysisConfig::shape_range_info_path() {
  return shape_range_info_path_;
}

bool AnalysisConfig::shape_range_info_collected() {
  return collect_shape_range_info_;
}

void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  shape_range_info_path_ = shape_range_info_path;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  trt_tuned_dynamic_shape_ = true;
}

bool AnalysisConfig::tuned_tensorrt_dynamic_shape() {
  return trt_tuned_dynamic_shape_;
}

bool AnalysisConfig::trt_allow_build_at_runtime() {
  return trt_allow_build_at_runtime_;
}
1165
}  // namespace paddle