analysis_config.cc 38.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <sstream>
16
#include <string>
17
#include <tuple>
18

19 20
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
21
#include "paddle/fluid/inference/utils/table_printer.h"
22
#include "paddle/fluid/platform/cpu_info.h"
23
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
24 25
#include "paddle/fluid/platform/enforce.h"

26 27 28 29
#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif

30
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
31 32 33
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

34
namespace paddle {
W
wanghuancoder 已提交
35 36
struct MkldnnQuantizerConfig;

37
extern const std::vector<std::string> kTRTSubgraphPasses;
D
denglin-github 已提交
38
extern const std::vector<std::string> kDlnneSubgraphPasses;
石晓伟 已提交
39
extern const std::vector<std::string> kLiteSubgraphPasses;
40

41
// Returns the pass strategy owned by this config, creating one lazily when
// none exists yet.
//
// The strategy kind is chosen from the device flags: GPU, XPU, NPU, IPU,
// custom device, otherwise CPU. The custom-device case mirrors the
// CustomDevicePassStrategy that Update() creates, so a lazily built strategy
// never silently degrades to a CPU one for a custom-device config.
//
// If a strategy already exists but its GPU flag disagrees with this config,
// a warning is logged and the existing strategy is still returned.
PassStrategy *AnalysisConfig::pass_builder() const {
  if (!pass_builder_.get()) {
    if (use_gpu_) {
      LOG(INFO) << "Create GPU IR passes";
      pass_builder_.reset(new GpuPassStrategy);
    } else if (use_xpu_) {
      pass_builder_.reset(new XpuPassStrategy);
    } else if (use_npu_) {
      pass_builder_.reset(new NpuPassStrategy);
    } else if (use_ipu_) {
      LOG(INFO) << "Create IPU IR passes";
      pass_builder_.reset(new IpuPassStrategy);
    } else if (use_custom_device_) {
      LOG(INFO) << "Create CustomDevice IR passes";
      pass_builder_.reset(new CustomDevicePassStrategy);
    } else {
      LOG(INFO) << "Create CPU IR passes";
      pass_builder_.reset(new CpuPassStrategy);
    }
  } else if (pass_builder_->use_gpu() ^ use_gpu()) {
    LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                    "PassBuilder, the flags are "
                 << use_gpu() << " " << pass_builder_->use_gpu();
    LOG(WARNING) << "Please make them compatible, still use the existing "
                    "PassBuilder.";
  }

  return pass_builder_.get();
}

68
// Creates a config whose model is located via `model_dir`, then runs Update()
// so the pass builder matches the (default) device settings.
AnalysisConfig::AnalysisConfig(const std::string &model_dir)
    : model_dir_(model_dir) {
  Update();
}
73 74
// Creates a config from an explicit program file and parameter file.
// SetModel() performs exactly the required work (store both paths, then
// Update()), so the constructor simply forwards to it.
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file) {
  SetModel(prog_file, params_file);
}
80 81
void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
82 83
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
84 85

  Update();
86
}
87 88
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                  int device_id) {
89
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
90 91
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
92
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
93
  gpu_device_id_ = device_id;
94
#else
Y
Yan Chunwei 已提交
95
  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
96 97
  use_gpu_ = false;
#endif
Y
Yan Chunwei 已提交
98 99 100

  Update();
}
101

102
void AnalysisConfig::SetExecStream(void *stream) {
W
Wilber 已提交
103 104 105
  PADDLE_ENFORCE_NOT_NULL(
      stream,
      platform::errors::InvalidArgument("`stream` should not be nullptr"));
106 107 108 109 110 111
  exec_stream_ = stream;
  use_external_stream_ = true;
  Update();
}

void *AnalysisConfig::GetExecStream() const {
W
Wilber 已提交
112 113 114
  PADDLE_ENFORCE_NOT_NULL(
      exec_stream_,
      platform::errors::InvalidArgument("`stream` should not be nullptr"));
115 116 117 118 119 120 121
  return exec_stream_;
}

// True once SetExecStream() has installed a caller-provided stream.
bool AnalysisConfig::external_stream_enabled() const {
  const bool enabled = use_external_stream_;
  return enabled;
}

122
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
123 124 125
  use_gpu_ = false;

  Update();
126 127
}

128 129 130 131 132 133
// Clears the FC-padding flag and re-synchronises the pass builder.
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;
  Update();
}

W
Wilber 已提交
134 135 136 137
void AnalysisConfig::EnableXpu(int l3_workspace_size,
                               bool locked,
                               bool autotune,
                               const std::string &autotune_file,
W
Wilber 已提交
138 139
                               const std::string &precision,
                               bool adaptive_seqlen) {
140 141
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
142 143 144 145 146
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
147 148 149
  Update();
}

150
void AnalysisConfig::SetXpuDeviceId(int device_id) {
W
Wilber 已提交
151 152
  PADDLE_ENFORCE_EQ(use_xpu_,
                    true,
153 154 155 156 157 158
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

W
Wilber 已提交
159 160 161 162 163 164 165 166 167 168 169
// Selects NPU execution on `device_id`. Builds without PADDLE_WITH_ASCEND_CL
// log an error and keep the NPU flag cleared. Update() runs either way.
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
  npu_device_id_ = device_id;
  use_npu_ = true;
#else
  use_npu_ = false;
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
#endif

  Update();
}
170

171 172 173 174 175 176 177 178 179 180 181 182 183
// Selects a plug-in ("custom") device identified by `device_type` and
// `device_id`. Builds without PADDLE_WITH_CUSTOM_DEVICE log an error and keep
// the flag cleared. Update() runs either way.
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                        int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  custom_device_type_ = device_type;
  custom_device_id_ = device_id;
  use_custom_device_ = true;
#else
  use_custom_device_ = false;
  LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
#endif
  Update();
}

W
Wilber 已提交
184 185
void AnalysisConfig::EnableIpu(int ipu_device_num,
                               int ipu_micro_batch_size,
186 187
                               bool ipu_enable_pipelining,
                               int ipu_batches_per_step) {
J
jianghaicheng 已提交
188 189 190
  enable_ir_optim_ = true;

  use_ipu_ = true;
191 192
  ipu_device_num_ = ipu_device_num;
  ipu_micro_batch_size_ = ipu_micro_batch_size;
J
jianghaicheng 已提交
193 194
  ipu_enable_pipelining_ = ipu_enable_pipelining;
  ipu_batches_per_step_ = ipu_batches_per_step;
195 196 197 198

  Update();
}

W
Wilber 已提交
199 200
void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16,
                                  int ipu_replica_num,
201 202 203 204 205 206
                                  float ipu_available_memory_proportion,
                                  bool ipu_enable_half_partial) {
  ipu_enable_fp16_ = ipu_enable_fp16;
  ipu_replica_num_ = ipu_replica_num;
  ipu_available_memory_proportion_ = ipu_available_memory_proportion;
  ipu_enable_half_partial_ = ipu_enable_half_partial;
J
jianghaicheng 已提交
207 208 209

  Update();
}
W
Wilber 已提交
210

211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
// Requests the ONNXRuntime backend. The flag can only become true in builds
// with PADDLE_WITH_ONNXRUNTIME; other builds log an error and keep it false.
void AnalysisConfig::EnableONNXRuntime() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  constexpr bool kOrtCompiledIn = true;
#else
  constexpr bool kOrtCompiledIn = false;
  LOG(ERROR) << "Please compile with onnxruntime to EnableONNXRuntime()";
#endif
  use_onnxruntime_ = kOrtCompiledIn;

  Update();
}

// Turns the ONNXRuntime backend off and re-synchronises the pass builder.
void AnalysisConfig::DisableONNXRuntime() {
  use_onnxruntime_ = false;

  Update();
}

// Requests ONNXRuntime graph optimisation. Only possible in builds with
// PADDLE_WITH_ONNXRUNTIME; other builds log an error and keep it false.
void AnalysisConfig::EnableORTOptimization() {
#ifdef PADDLE_WITH_ONNXRUNTIME
  constexpr bool kOrtCompiledIn = true;
#else
  constexpr bool kOrtCompiledIn = false;
  LOG(ERROR) << "Please compile with onnxruntime to EnableORTOptimization()";
#endif
  enable_ort_optimization_ = kOrtCompiledIn;

  Update();
}

238
// Copy constructor.
//
// Copies every configuration field of `other`, then clones `other`'s pass
// strategy using the concrete strategy type that matches the selected device.
// This preserves any passes the user inserted or deleted on `other`.
// Finally it runs Update() and re-applies `other`'s pass list (TensorRT) or
// pass deletions (DLNNE and generic deleted passes), because Update() may
// rebuild the pass list from scratch.
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.

  CP_MEMBER(opt_cache_dir_);
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);
  CP_MEMBER(calibration_file_path_);

  CP_MEMBER(use_fc_padding_);
  // GPU related.
  CP_MEMBER(use_gpu_);
  CP_MEMBER(use_external_stream_);
  CP_MEMBER(exec_stream_);
  CP_MEMBER(use_cudnn_);
  CP_MEMBER(gpu_device_id_);
  CP_MEMBER(memory_pool_init_size_mb_);

  // Mixed related.
  CP_MEMBER(mixed_black_list_);

  CP_MEMBER(enable_memory_optim_);
  // TensorRT related.
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
  CP_MEMBER(tensorrt_precision_mode_);
  CP_MEMBER(trt_disabled_ops_);
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
  CP_MEMBER(trt_use_static_engine_);
  CP_MEMBER(trt_use_calib_mode_);
  CP_MEMBER(trt_use_varseqlen_);
  CP_MEMBER(trt_with_interleaved_);
  CP_MEMBER(tensorrt_transformer_posid_);
  CP_MEMBER(tensorrt_transformer_maskid_);
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
  CP_MEMBER(trt_use_inspector_);
  CP_MEMBER(trt_engine_memory_sharing_);
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
  CP_MEMBER(dlnne_max_batchsize_);
  CP_MEMBER(dlnne_use_static_batch_);
  CP_MEMBER(dlnne_weight_share_mode_);
  CP_MEMBER(dlnne_use_calib_mode_);
  CP_MEMBER(dlnne_precision_mode_);
  CP_MEMBER(dlnne_disable_nodes_by_outputs_);
  CP_MEMBER(dlnne_input_shape_dict_);
  // MKLDNN related.
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
  CP_MEMBER(mkldnn_cache_capacity_);
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
  // Quantization related.
  CP_MEMBER(use_mkldnn_int8_);
  CP_MEMBER(quantize_enabled_op_types_);
  CP_MEMBER(quantize_excluded_op_ids_);
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
  CP_MEMBER(disable_trt_plugin_fp16_);

  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
  CP_MEMBER(lite_zero_copy_);

  // XPU related.
  CP_MEMBER(use_xpu_);
  CP_MEMBER(xpu_device_id_);
  CP_MEMBER(xpu_l3_workspace_size_);
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);

  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
  CP_MEMBER(nnadapter_config_);

  // profile related.
  CP_MEMBER(with_profile_);

  // glog related.
  CP_MEMBER(with_glog_info_);

  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

  CP_MEMBER(thread_local_stream_);

  // ipu related
  CP_MEMBER(use_ipu_);
  CP_MEMBER(ipu_device_num_);
  CP_MEMBER(ipu_micro_batch_size_);
  CP_MEMBER(ipu_enable_pipelining_);
  CP_MEMBER(ipu_batches_per_step_);
  CP_MEMBER(ipu_enable_fp16_);
  CP_MEMBER(ipu_replica_num_);
  CP_MEMBER(ipu_available_memory_proportion_);
  CP_MEMBER(ipu_enable_half_partial_);

  // fleet exe related
  CP_MEMBER(dist_config_);

  // custom device related.
  CP_MEMBER(use_custom_device_);
  CP_MEMBER(custom_device_type_);
  CP_MEMBER(custom_device_id_);

  // ONNXRuntime related. These flags are set by EnableONNXRuntime() and
  // EnableORTOptimization(); without copying them a copied config would
  // silently lose the ONNXRuntime backend / optimisation settings.
  CP_MEMBER(use_onnxruntime_);
  CP_MEMBER(enable_ort_optimization_);

  // Clone other's strategy with the concrete type matching the device. Each
  // static_cast is only valid if other's strategy really is of that type, so
  // the branch order mirrors the strategy selection in Update().
  if (use_gpu_) {
    PADDLE_ENFORCE_EQ(use_xpu_,
                      false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between GPU and XPU."));
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
  } else if (use_ipu_) {
    pass_builder_.reset(new IpuPassStrategy(
        *static_cast<IpuPassStrategy *>(other.pass_builder())));
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
  } else if (use_custom_device_) {
    // Previously this fell through to the CPU branch and downcast a
    // CustomDevicePassStrategy to CpuPassStrategy (undefined behavior).
    pass_builder_.reset(new CustomDevicePassStrategy(
        *static_cast<CustomDevicePassStrategy *>(other.pass_builder())));
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

#undef CP_MEMBER

  Update();
  if (use_tensorrt_) {
    // Update() will reset all the passes, when some tensorRT pass is deleted in
    // other.pass_builder(), it will set again, so we just remove the
    // deleted_pass.
    pass_builder_->ClearPasses();
    const auto other_passes = other.pass_builder()->AllPasses();
    for (const auto &pass : other_passes) {
      pass_builder_->AppendPass(pass);
    }
  }
  if (use_dlnne_) {
    auto all_passes = kDlnneSubgraphPasses;
    auto other_passes = other.pass_builder()->AllPasses();
    // We should sort them, because the user may call the SwitchIrDebug
    // interface, which will change the pass.
    std::sort(all_passes.begin(), all_passes.end());
    std::sort(other_passes.begin(), other_passes.end());
    std::vector<std::string> deleted_passes;
    std::set_difference(all_passes.begin(),
                        all_passes.end(),
                        other_passes.begin(),
                        other_passes.end(),
                        std::inserter(deleted_passes, deleted_passes.begin()));
    for (const auto &ps : deleted_passes) {
      pass_builder_->DeletePass(ps);
    }
  }

  // Re-apply every pass deletion the user made on `other`.
  for (auto &delete_pass : other.pass_builder()->GetAllDeletedPasses()) {
    pass_builder_->DeletePass(delete_pass);
  }
}

428
void AnalysisConfig::EnableCUDNN() {
429
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
430 431 432 433 434 435 436 437 438
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

439
void AnalysisConfig::EnableMKLDNN() {
440 441 442 443 444 445
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
446 447

  Update();
448 449
}

450 451 452 453 454 455 456 457 458
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

459 460 461 462 463 464 465 466 467 468 469 470 471
// Turns on the MKLDNN post-training quantizer. The quantizer configuration
// object is created only if none exists yet, so repeated calls keep the
// existing one. Builds without PADDLE_WITH_MKLDNN log an error instead.
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (mkldnn_quantizer_config_ == nullptr) {
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  }
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif

  Update();
}

472 473
// Requests MKLDNN bfloat16 execution. This is only enabled when the CPU
// reports avx512_core. The log then says whether native avx512_bf16 support
// is present or whether simulation will be used. Builds without
// PADDLE_WITH_MKLDNN log an error and keep the flag false.
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
  const bool has_avx512_core =
      platform::MayIUse(platform::cpu_isa_t::avx512_core);
  use_mkldnn_bfloat16_ = has_avx512_core;
  if (!has_avx512_core) {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
  } else {
    const bool has_native_bf16 =
        platform::MayIUse(platform::cpu_isa_t::avx512_bf16);
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (has_native_bf16 ? " is enabled"
                                  : " is disabled. Simulation will be used");
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif

  Update();
}

B
baoachun 已提交
492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520
// Requests MKLDNN int8 execution and disables FC padding.
//
// If `op_list` is non-empty, every entry must already be present in
// quantize_enabled_op_types_. The first unsupported entry is logged and int8
// is left disabled. When all entries are supported, the enabled set is
// narrowed to exactly `op_list`. An empty `op_list` keeps the current set.
// Builds without PADDLE_WITH_MKLDNN log an error and keep int8 disabled.
void AnalysisConfig::EnableMkldnnInt8(
    const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_MKLDNN
  use_fc_padding_ = false;

  bool every_op_supported = true;
  for (const auto &op_type : op_list) {
    if (quantize_enabled_op_types_.count(op_type) == 0) {
      LOG(ERROR) << "There are unsupported operators in the configured "
                    "quantization operator list. The unsupported operator "
                    "is: "
                 << op_type;
      every_op_supported = false;
      break;
    }
  }
  use_mkldnn_int8_ = every_op_supported;

  if (every_op_supported && !op_list.empty()) {
    quantize_enabled_op_types_.clear();
    quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
  use_mkldnn_int8_ = false;
#endif

  Update();
}

521 522 523 524 525 526 527 528
// Records the calibration file path for a quantized model, logs it at
// VLOG level 1, and re-synchronises the pass builder.
void AnalysisConfig::SetCalibrationFilePath(
    const std::string &calibration_file_path) {
  calibration_file_path_ = calibration_file_path;
  VLOG(1) << "Set calibration file path of quantize model: "
          << calibration_file_path_;
  Update();
}

529
// Returns the (non-owning) quantizer configuration. Raises PreconditionNotMet
// if EnableMkldnnQuantizer() has not created one yet.
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
  PADDLE_ENFORCE_NOT_NULL(
      mkldnn_quantizer_config_,
      platform::errors::PreconditionNotMet(
          "MkldnnQuantizer was not enabled yet."));
  MkldnnQuantizerConfig *const quantizer_config =
      mkldnn_quantizer_config_.get();
  return quantizer_config;
}

536
void AnalysisConfig::EnableTensorRtEngine(
537
    int64_t workspace_size,
W
Wilber 已提交
538 539 540 541
    int max_batch_size,
    int min_subgraph_size,
    AnalysisConfig::Precision precision_mode,
    bool use_static,
542
    bool use_calib_mode) {
543
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
544 545 546 547 548
  if (!use_gpu()) {
    LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
    return;
  }

549
  use_tensorrt_ = true;
550 551 552 553 554 555 556 557 558 559 560 561 562
#if PADDLE_WITH_TENSORRT
  // https://forums.developer.nvidia.com/t/nvinfer1-createexecutioncontextwithoutdevicememory-returns-nullptr/111878/2
  // when trt version less than 7.2,
  // createExecutionContextWithoutDeviceMemory() has bug.
  // so, we cannot enable engine context memory sharing.
#if IS_TRT_VERSION_GE(7200)
  trt_engine_memory_sharing_ = true;
#else
  LOG(WARNING)
      << "TensorRT engine context memory sharing needs version 7.2 and after.";
  trt_engine_memory_sharing_ = false;
#endif
#endif
563 564
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
565
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
566
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
567
  trt_use_static_engine_ = use_static;
568
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
569

570
  Update();
Y
Yan Chunwei 已提交
571 572 573 574
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
575 576
}

D
denglin-github 已提交
577 578 579 580 581 582 583 584 585
void AnalysisConfig::EnableDlnne(
    int min_subgraph_size,
    int max_batch_size,
    bool use_static_batch,
    std::string weight_share_mode,
    std::unordered_set<std::string> disable_nodes_by_ouputs,
    std::map<std::string, std::vector<int64_t>> dlnne_input_shape_dict,
    bool use_calib_mode,
    AnalysisConfig::Precision precision_mode) {
D
denglin-github 已提交
586 587
  use_dlnne_ = true;
  dlnne_min_subgraph_size_ = min_subgraph_size;
D
denglin-github 已提交
588 589 590 591 592 593 594
  dlnne_max_batchsize_ = max_batch_size;
  dlnne_use_static_batch_ = use_static_batch;
  dlnne_weight_share_mode_ = weight_share_mode;
  dlnne_disable_nodes_by_outputs_ = disable_nodes_by_ouputs;
  dlnne_input_shape_dict_ = dlnne_input_shape_dict;
  dlnne_use_calib_mode_ = use_calib_mode;
  dlnne_precision_mode_ = precision_mode;
D
denglin-github 已提交
595 596 597
  Update();
}

598 599 600 601 602 603 604 605 606 607 608
// Stores the TensorRT dynamic-shape ranges (min / max / optimal input shapes,
// keyed by name) and the fp16-plugin opt-out flag. Does not call Update().
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
  optim_input_shape_ = optim_input_shape;
  max_input_shape_ = max_input_shape;
  min_input_shape_ = min_input_shape;
}

609 610 611 612 613
// Requests TensorRT DLA usage on core `dla_core`. Does not call Update().
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_dla_core_ = dla_core;
  trt_use_dla_ = true;
}

614 615
// Sets the TensorRT inspector flag. Does not call Update().
void AnalysisConfig::EnableTensorRtInspector() {
  trt_use_inspector_ = true;
}

616 617 618 619 620
// Experimental: appends `ops` (in order) to the list of operator types that
// TensorRT should not take. Does not call Update().
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  for (const auto &op_type : ops) {
    trt_disabled_ops_.push_back(op_type);
  }
}

621
// Sets the TensorRT variable-sequence-length flag. Does not call Update().
void AnalysisConfig::EnableVarseqlen() {
  trt_use_varseqlen_ = true;
}
622

Y
Yan Chunwei 已提交
623
// TODO(Superjomn) refactor this, buggy.
624
void AnalysisConfig::Update() {
625
  auto &&info = SerializeInfoCache();
626 627
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
628
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
629 630
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
J
jianghaicheng 已提交
631
      ((use_npu() ^ pass_builder_->use_npu())) ||
632 633
      ((use_ipu() ^ pass_builder_->use_ipu())) ||
      ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
Y
Yan Chunwei 已提交
634 635 636 637 638 639 640
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);

      if (use_tensorrt_) {
        // Append after the Affine_channel_conv_fuse pass.
        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
      }
J
jianghaicheng 已提交
641 642 643
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used for new.";
      pass_builder_.reset(new IpuPassStrategy);
644 645
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
646 647
          use_gpu(),
          false,
648 649 650
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
651 652
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
653 654
          use_gpu(),
          false,
W
Wilber 已提交
655 656 657
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
658 659
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
660 661
          use_gpu(),
          false,
662 663 664
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy);
Y
Yan Chunwei 已提交
665 666 667
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
668

669
  } else {
Y
Yan Chunwei 已提交
670 671 672
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
J
jianghaicheng 已提交
673 674 675 676
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used.";
      pass_builder_.reset(new IpuPassStrategy(
          *static_cast<IpuPassStrategy *>(pass_builder_.get())));
677 678
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
679 680
          use_gpu(),
          false,
681 682 683 684
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
685 686
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
687 688
          use_gpu(),
          false,
W
Wilber 已提交
689 690 691 692
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
693 694
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
695 696
          use_gpu(),
          false,
697 698 699 700
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy(
          *static_cast<CustomDevicePassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
701 702 703 704
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
705 706 707
  }

  if (use_tensorrt_) {
708 709
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
710
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
711
          (pass == "conv_bn_fuse_pass")) {
712 713
        continue;
      }
714
      pass_builder()->AppendPass(pass);
715 716
    }
  }
717

D
denglin-github 已提交
718 719 720 721 722 723 724
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

725
  if (use_gpu() && use_cudnn_) {
726
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
727 728 729 730 731 732 733 734
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

735
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
736
#ifdef PADDLE_WITH_MKLDNN
737 738 739
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
740 741
    } else {
      pass_builder()->EnableMKLDNN();
742 743 744 745
    }
#endif
  }

746 747 748 749 750
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
751 752
    }
#ifdef PADDLE_WITH_MKLDNN
753
    pass_builder()->EnableMkldnnQuantizer();
754 755 756
#endif
  }

757 758 759 760 761 762
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

B
baoachun 已提交
763 764 765 766 767 768 769 770 771 772 773 774 775 776
  if (use_mkldnn_int8_) {
#ifdef PADDLE_WITH_MKLDNN
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
                    "is enabled.";
    } else if (!use_mkldnn_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
                    "is enabled.";
    } else {
      pass_builder()->EnableMkldnnInt8();
    }
#endif
  }

777
#ifdef PADDLE_WITH_MKLDNN
778 779
  // Do not optimize when mkldnn is on
  if (enable_memory_optim_ && !use_mkldnn_) {
780
#else
Y
Yan Chunwei 已提交
781
  if (enable_memory_optim_) {
782 783
#endif
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
Y
Yan Chunwei 已提交
784 785
  }

石晓伟 已提交
786 787 788 789 790 791 792
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
W
Wilber 已提交
793 794
      if (std::find(lite_passes_filter_.begin(),
                    lite_passes_filter_.end(),
石晓伟 已提交
795 796 797 798 799 800
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

801
  if (use_xpu_) {
802
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
W
Wilber 已提交
803 804
    PADDLE_ENFORCE_EQ(use_gpu_,
                      false,
805 806 807
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
808 809 810 811 812
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
813 814
  }

W
Wilber 已提交
815
  if (use_npu_) {
816
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
817 818
    PADDLE_ENFORCE_EQ(use_gpu_,
                      false,
W
Wilber 已提交
819 820 821 822 823 824 825 826 827
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }
J
jianghaicheng 已提交
828 829 830 831 832 833 834
  if (use_ipu_) {
#ifndef PADDLE_WITH_IPU
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the ipu "
        "but did not have the option -DWITH_IPU compiled."));
#endif
  }
835 836 837 838 839 840 841
  if (use_custom_device_) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the custom device "
        "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif
  }
842 843 844 845 846
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
}

847
std::string AnalysisConfig::SerializeInfoCache() {
848
  std::stringstream ss;
Y
Yan Chunwei 已提交
849 850 851 852
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

853 854
  ss << calibration_file_path_;

855
  ss << use_gpu_;
856 857
  ss << use_external_stream_;
  ss << exec_stream_;
858
  ss << use_fc_padding_;
859 860
  ss << gpu_device_id_;
  ss << xpu_device_id_;
861 862 863 864 865
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
866 867
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
868 869 870
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

871 872 873
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

874 875 876
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
877
  ss << enable_memory_optim_;
878
  ss << trt_engine_memory_sharing_;
879 880

  ss << use_mkldnn_;
881
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
882 883 884
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

885
  ss << use_mkldnn_quantizer_;
886
  ss << use_mkldnn_bfloat16_;
887
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
B
baoachun 已提交
888 889 890
  ss << use_mkldnn_int8_;
  for (auto &item : quantize_enabled_op_types_) ss << item;
  for (auto &item : quantize_excluded_op_ids_) ss << item;
891
  ss << ";";
Y
Yan Chunwei 已提交
892 893
  ss << model_from_memory_;

894 895
  ss << with_profile_;

896 897
  ss << with_glog_info_;

898 899 900 901
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
902 903
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
904 905

  ss << use_lite_;
906 907
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
908 909 910 911 912
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
913

W
Wilber 已提交
914 915 916
  ss << use_npu_;
  ss << npu_device_id_;

917 918
  ss << thread_local_stream_;

J
jianghaicheng 已提交
919 920
  ss << use_ipu_;
  ss << ipu_device_num_;
921
  ss << ipu_micro_batch_size_;
J
jianghaicheng 已提交
922 923
  ss << ipu_enable_pipelining_;
  ss << ipu_batches_per_step_;
924 925 926 927
  ss << ipu_enable_fp16_;
  ss << ipu_replica_num_;
  ss << ipu_available_memory_proportion_;
  ss << ipu_enable_half_partial_;
J
jianghaicheng 已提交
928

929
  for (auto &op : mixed_black_list_) ss << op.c_str();
930 931 932
  return ss.str();
}

933
void AnalysisConfig::SetCpuMathLibraryNumThreads(
934 935
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
936 937

  Update();
938 939
}

940
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Expresses memory_pool_init_size_mb() as a fraction of the total memory of
  // device gpu_device_id_.
  // NOTE: calls platform::SetDeviceId(gpu_device_id_) as a side effect.
  size_t gpu_total = 0;
  size_t gpu_available = 0;
  platform::SetDeviceId(gpu_device_id_);
  platform::GpuMemoryUsage(&gpu_available, &gpu_total);
  const double total_gpu_memory = gpu_total / 1024. / 1024.;
  VLOG(3) << "total_gpu_memory is " << total_gpu_memory
          << "M, gpu_available is " << gpu_available / 1024. / 1024.
          << "M, memory_pool_init_size is " << memory_pool_init_size_mb()
          << "M.";
  if (total_gpu_memory <= 0.) {
    // Avoid returning inf/nan if the device reports no memory; fall back to
    // the same value as a build without GPU support.
    return 0.f;
  }
  // The ratio is computed in double and narrowed explicitly to the float
  // return type.
  return static_cast<float>(
      static_cast<double>(memory_pool_init_size_mb()) / total_gpu_memory);
#else
  return 0.;
#endif
}

960 961
void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
962 963 964
  Update();
}

965
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
966 967 968
  return enable_memory_optim_;
}

969 970 971 972
bool AnalysisConfig::trt_engine_memory_sharing() const {
  return trt_engine_memory_sharing_;
}

973 974 975 976
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
977 978
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
979
  model_from_memory_ = true;
T
Tao Luo 已提交
980 981
}

982
// Copies the options that NativeConfig understands into a new NativeConfig.
// The GPU memory fraction is derived via fraction_of_gpu_memory_for_pool().
NativeConfig AnalysisConfig::ToNativeConfig() const {
  NativeConfig native;
  native.model_dir = model_dir_;
  native.prog_file = prog_file_;
  native.param_file = params_file_;
  native.use_gpu = use_gpu_;
  native.device = gpu_device_id_;
  native.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
  native.specify_input_name = specify_input_name_;
  return native;
}

Y
Yan Chunwei 已提交
994 995 996 997
void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}
998 999 1000 1001 1002 1003

void AnalysisConfig::EnableProfile() {
  with_profile_ = true;
  Update();
}

1004 1005 1006 1007 1008
void AnalysisConfig::DisableGlogInfo() {
  with_glog_info_ = false;
  Update();
}

石晓伟 已提交
1009
void AnalysisConfig::EnableLiteEngine(
    AnalysisConfig::Precision precision_mode,
    bool zero_copy,
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  // Record all Lite engine options, switch the engine on, then refresh.
  // `passes_filter` lists passes that are skipped when the Lite subgraph
  // passes are appended (see the lite_passes_filter_ check above).
  lite_precision_mode_ = precision_mode;
  lite_zero_copy_ = zero_copy;
  lite_passes_filter_ = passes_filter;
  lite_ops_filter_ = ops_filter;
  use_lite_ = true;
  Update();
}

1022 1023 1024 1025 1026 1027 1028
void AnalysisConfig::PartiallyRelease() {
  prog_file_.clear();
  prog_file_.shrink_to_fit();
  params_file_.clear();
  params_file_.shrink_to_fit();
}

1029 1030
void AnalysisConfig::EnableGpuMultiStream() { thread_local_stream_ = true; }

1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
1042 1043 1044 1045
  if (!(calibration_file_path_.empty())) {
    os.InsertRow({"calibration_file_path", calibration_file_path_});
  }

1046 1047 1048 1049 1050 1051 1052 1053
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
1054
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
1055 1056 1057 1058 1059 1060 1061 1062 1063 1064
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
1065 1066
    os.InsertRow(
        {"use_external_stream", use_external_stream_ ? "true" : "false"});
1067 1068 1069 1070 1071
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098
#ifdef PADDLE_WITH_TENSORRT
      auto Precision2String =
          [](paddle::AnalysisConfig::Precision prec) -> std::string {
        if (prec == Precision::kFloat32)
          return "fp32";
        else if (prec == Precision::kHalf)
          return "fp16";
        else if (prec == Precision::kInt8)
          return "int8";
        else
          return "None";
      };
      auto version2string =
          [](const std::tuple<int, int, int> &ver) -> std::string {
        std::ostringstream os;
        int major = std::get<0>(ver);
        int minor = std::get<1>(ver);
        int patch = std::get<2>(ver);
        os << major << "." << minor << "." << patch;
        return os.str();
      };
      os.InsertRow(
          {"trt_compile_version",
           version2string(inference::tensorrt::GetTrtCompileVersion())});
      os.InsertRow(
          {"trt_runtime_version",
           version2string(inference::tensorrt::GetTrtRuntimeVersion())});
1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
W
Wilber 已提交
1115 1116 1117
      os.InsertRow(
          {"tensorrt_tuned_dynamic_shape",
           trt_tuned_dynamic_shape_ ? shape_range_info_path_ : "false"});
1118

1119 1120
      os.InsertRow(
          {"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
1121 1122
      os.InsertRow({"tensorrt_with_interleaved",
                    trt_with_interleaved_ ? "true" : "false"});
1123 1124 1125
      os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
      os.InsertRow(
          {"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
1126 1127 1128 1129
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
1130 1131
      os.InsertRow({"trt_engine_memory_sharing",
                    trt_engine_memory_sharing_ ? "true" : "false"});
1132
#endif
1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
1156 1157
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
1158 1159 1160 1161

  return os.PrintTable();
}

1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182
// Replaces the NNAdapter device-name list; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  this->nnadapter_device_names = names;
  return *this;
}

// Stores the NNAdapter context-properties string; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  this->nnadapter_context_properties = properties;
  return *this;
}

// Stores the NNAdapter model-cache directory; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  this->nnadapter_model_cache_dir = dir;
  return *this;
}

// Registers one model-cache buffer under `model_cache_token`.
// Throws InvalidArgument if the token or buffer is empty, or if the token was
// already registered. Returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  PADDLE_ENFORCE_EQ(model_cache_token.empty(),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  const bool token_already_set =
      nnadapter_model_cache_buffers.find(model_cache_token) !=
      nnadapter_model_cache_buffers.end();
  PADDLE_ENFORCE_EQ(token_already_set,
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_token has already been set."));

  // The token is known to be absent here, so emplace always inserts.
  nnadapter_model_cache_buffers.emplace(model_cache_token, model_cache_buffer);
  return *this;
}

// Stores the path of the NNAdapter subgraph-partition config file; returns
// *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  this->nnadapter_subgraph_partition_config_path = path;
  return *this;
}

// Stores the in-memory NNAdapter subgraph-partition config; returns *this for
// chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  this->nnadapter_subgraph_partition_config_buffer = buffer;
  return *this;
}
// Sets use_nnadapter; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  this->use_nnadapter = true;
  return *this;
}
// Clears use_nnadapter; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  this->use_nnadapter = false;
  return *this;
}

1220 1221 1222 1223 1224 1225 1226
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
W
Wilber 已提交
1227 1228
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(),
                    false,
1229 1230 1231 1232 1233 1234
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  shape_range_info_path_ = shape_range_info_path;
}

1235
const std::string &AnalysisConfig::shape_range_info_path() const {
1236 1237 1238
  return shape_range_info_path_;
}

1239
bool AnalysisConfig::shape_range_info_collected() const {
1240 1241 1242 1243 1244 1245 1246 1247 1248 1249
  return collect_shape_range_info_;
}

void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  shape_range_info_path_ = shape_range_info_path;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  trt_tuned_dynamic_shape_ = true;
}

1250
bool AnalysisConfig::tuned_tensorrt_dynamic_shape() const {
1251 1252 1253
  return trt_tuned_dynamic_shape_;
}

1254
bool AnalysisConfig::trt_allow_build_at_runtime() const {
1255 1256
  return trt_allow_build_at_runtime_;
}
1257 1258 1259 1260 1261 1262

void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
    const std::unordered_set<std::string> &black_list) {
  mixed_black_list_ = black_list;
}

1263
}  // namespace paddle