analysis_config.cc 28.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <string>
16 17
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
18
#include "paddle/fluid/inference/utils/table_printer.h"
19
#include "paddle/fluid/platform/cpu_info.h"
20
#include "paddle/fluid/platform/enforce.h"
21
#include "paddle/fluid/platform/gpu_info.h"
22

23
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
24 25 26
// Flag that AnalysisConfig::EnableUseGpu() overwrites with the requested
// initial GPU memory pool size in MB.
// NOTE(review): presumably defined elsewhere with the matching DEFINE_ macro —
// confirm.
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

27
namespace paddle {
W
wanghuancoder 已提交
28 29
// Forward declaration of the quantizer settings object.
// NOTE(review): EnableMkldnnQuantizer() constructs one in MKLDNN builds, so
// the full definition must come from an included header there — confirm.
struct MkldnnQuantizerConfig;

30
// Pass names that Update() appends (after clearing the builder) when TensorRT
// is enabled. Declared here, defined in another translation unit.
extern const std::vector<std::string> kTRTSubgraphPasses;
D
denglin-github 已提交
31
// Pass names that Update() appends (after clearing the builder) when Dlnne is
// enabled. Declared here, defined in another translation unit.
extern const std::vector<std::string> kDlnneSubgraphPasses;
石晓伟 已提交
32
// Pass names that Update() appends (minus lite_passes_filter_) when the Lite
// engine is enabled. Declared here, defined in another translation unit.
extern const std::vector<std::string> kLiteSubgraphPasses;
33

34
// Returns the pass strategy owned by this config, creating it on first use.
//
// A new strategy is chosen from the device flags in the order GPU, XPU, NPU,
// falling back to CPU. When a strategy already exists and its use_gpu() flag
// disagrees with the config, a warning is logged and the existing strategy is
// returned unchanged.
PassStrategy *AnalysisConfig::pass_builder() const {
  if (pass_builder_) {
    // Existing builder: only report a GPU-flag mismatch, never replace it here.
    if (pass_builder_->use_gpu() != use_gpu()) {
      LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                      "PassBuilder, the flags are "
                   << use_gpu() << " " << pass_builder_->use_gpu();
      LOG(WARNING) << "Please make them compatible, still use the existing "
                      "PassBuilder.";
    }
    return pass_builder_.get();
  }

  if (use_gpu_) {
    LOG(INFO) << "Create GPU IR passes";
    pass_builder_.reset(new GpuPassStrategy);
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy);
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy);
  } else {
    LOG(INFO) << "Create CPU IR passes";
    pass_builder_.reset(new CpuPassStrategy);
  }

  return pass_builder_.get();
}

58
// Builds a config for a model stored in the directory `model_dir`, then runs
// Update() to set up the pass strategy.
AnalysisConfig::AnalysisConfig(const std::string &model_dir)
    : model_dir_(model_dir) {
  Update();
}
63 64
// Builds a config from a separate program file and parameters file, then runs
// Update() to set up the pass strategy.
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file)
    : prog_file_(prog_file), params_file_(params_file) {
  Update();
}
70 71
void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
72 73
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
74 75

  Update();
76
}
77 78
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
                                  int device_id) {
79
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
80 81
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
82
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
83
  gpu_device_id_ = device_id;
84
#else
Y
Yan Chunwei 已提交
85
  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
86 87
  use_gpu_ = false;
#endif
Y
Yan Chunwei 已提交
88 89 90

  Update();
}
91
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
92 93 94
  use_gpu_ = false;

  Update();
95 96
}

97 98 99 100 101 102
// Clears the use_fc_padding_ flag and refreshes the configuration.
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;

  Update();
}

W
Wilber 已提交
103 104 105 106
void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
                               bool autotune, const std::string &autotune_file,
                               const std::string &precision,
                               bool adaptive_seqlen) {
107 108
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
109 110 111 112 113
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
114 115 116
  Update();
}

117 118 119 120 121 122 123 124
// Selects the XPU device index. EnableXpu() must have been called first;
// otherwise the enforce below fails with PreconditionNotMet.
void AnalysisConfig::SetXpuDeviceId(int device_id) {
  PADDLE_ENFORCE_EQ(use_xpu_, true,
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

W
Wilber 已提交
125 126 127 128 129 130 131 132 133 134 135 136
// Turns on NPU execution on device `device_id`. In builds without
// PADDLE_WITH_ASCEND_CL an error is logged and use_npu_ is forced to false.
// Update() runs in both cases.
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
  use_npu_ = true;
  npu_device_id_ = device_id;
#else
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
  use_npu_ = false;
#endif

  Update();
}

137
// Copy constructor.
//
// Copies every configuration member from `other`, clones other's pass
// strategy into a strategy of the type matching the copied device flags, runs
// Update(), and finally removes the TensorRT / Dlnne subgraph passes that are
// missing from other's pass list.
//
// Fails with InvalidArgument when both use_gpu_ and use_xpu_ are set.
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.

  CP_MEMBER(opt_cache_dir_);
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);

  CP_MEMBER(use_fc_padding_);
  // GPU related.
  CP_MEMBER(use_gpu_);
  CP_MEMBER(use_cudnn_);
  CP_MEMBER(gpu_device_id_);
  CP_MEMBER(memory_pool_init_size_mb_);

  CP_MEMBER(enable_memory_optim_);
  // TensorRT related.
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
  CP_MEMBER(tensorrt_precision_mode_);
  CP_MEMBER(trt_disabled_ops_);
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
  CP_MEMBER(trt_use_static_engine_);
  CP_MEMBER(trt_use_calib_mode_);
  CP_MEMBER(trt_use_oss_);
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
  // MKLDNN related.
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
  CP_MEMBER(mkldnn_cache_capacity_);
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
  // Quantization related.
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
  CP_MEMBER(disable_trt_plugin_fp16_);

  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
  CP_MEMBER(lite_zero_copy_);

  // XPU related.
  CP_MEMBER(use_xpu_);
  CP_MEMBER(xpu_device_id_);
  CP_MEMBER(xpu_l3_workspace_size_);
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);

  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
  CP_MEMBER(nnadapter_config_);

  // profile related.
  CP_MEMBER(with_profile_);

  // glog related.
  CP_MEMBER(with_glog_info_);

  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

  CP_MEMBER(thread_local_stream_);

#undef CP_MEMBER

  // Clone other's pass strategy as the type matching the copied device flags.
  if (use_gpu_) {
    // The check is GPU vs. XPU; the message now says so (it used to read
    // "CPU and XPU").
    PADDLE_ENFORCE_EQ(use_xpu_, false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between GPU and XPU."));
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

  Update();

  // Update() will reset all the passes, so a subgraph pass that was deleted
  // in other.pass_builder() would be set again. This helper removes from our
  // builder every pass of `all_passes` that other's pass list does not have.
  const auto drop_passes_deleted_in_other =
      [this, &other](std::vector<std::string> all_passes) {
        auto other_passes = other.pass_builder()->AllPasses();
        // We should sort them, because the user may call the SwitchIrDebug
        // interface, which will change the pass. std::set_difference also
        // requires sorted input ranges.
        std::sort(all_passes.begin(), all_passes.end());
        std::sort(other_passes.begin(), other_passes.end());
        std::vector<std::string> deleted_passes;
        std::set_difference(
            all_passes.begin(), all_passes.end(), other_passes.begin(),
            other_passes.end(),
            std::inserter(deleted_passes, deleted_passes.begin()));
        for (const auto &ps : deleted_passes) {
          pass_builder_->DeletePass(ps);
        }
      };

  if (use_tensorrt_) {
    drop_passes_deleted_in_other(kTRTSubgraphPasses);
  }
  if (use_dlnne_) {
    drop_passes_deleted_in_other(kDlnneSubgraphPasses);
  }
}

285
void AnalysisConfig::EnableCUDNN() {
286
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
287 288 289 290 291 292 293 294 295
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

296
void AnalysisConfig::EnableMKLDNN() {
297 298 299 300 301 302
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
303 304

  Update();
305 306
}

307 308 309 310 311 312 313 314 315
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

316 317 318 319 320 321 322 323 324 325 326 327 328
// Requests the MKLDNN quantizer. In MKLDNN builds a MkldnnQuantizerConfig is
// created on first use and the flag is set; otherwise an error is logged and
// the flag is cleared. Update() runs in both cases.
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  // Keep an existing quantizer config so repeated calls do not discard it.
  if (!mkldnn_quantizer_config_) {
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  }
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif
  Update();
}

329 330
// Requests MKLDNN bfloat16. In MKLDNN builds the flag is set only when
// platform::MayIUse reports avx512_core; the log line then states whether
// avx512_bf16 is available or simulation will be used. In other builds an
// error is logged and the flag is cleared. Update() runs in every case.
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
  const bool has_avx512_core =
      platform::MayIUse(platform::cpu_isa_t::avx512_core);
  use_mkldnn_bfloat16_ = has_avx512_core;
  if (has_avx512_core) {
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (platform::MayIUse(platform::cpu_isa_t::avx512_bf16)
                      ? " is enabled"
                      : " is disabled. Simulation will be used");
  } else {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif
  Update();
}

349
// Returns the quantizer configuration object. Fails with PreconditionNotMet
// when no quantizer config has been created yet.
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
  PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
                          platform::errors::PreconditionNotMet(
                              "MkldnnQuantizer was not enabled yet."));
  MkldnnQuantizerConfig *quantizer_config = mkldnn_quantizer_config_.get();
  return quantizer_config;
}

356
void AnalysisConfig::EnableTensorRtEngine(
N
nhzlx 已提交
357
    int workspace_size, int max_batch_size, int min_subgraph_size,
358
    AnalysisConfig::Precision precision_mode, bool use_static,
359
    bool use_calib_mode) {
360
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
361 362 363 364 365
  if (!use_gpu()) {
    LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
    return;
  }

366 367 368
  use_tensorrt_ = true;
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
369
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
370
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
371
  trt_use_static_engine_ = use_static;
372
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
373

374
  Update();
Y
Yan Chunwei 已提交
375 376 377 378
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
379 380
}

D
denglin-github 已提交
381 382 383 384 385 386
// Turns on the Dlnne subgraph engine with the given minimum subgraph size and
// refreshes the configuration.
void AnalysisConfig::EnableDlnne(int min_subgraph_size) {
  use_dlnne_ = true;
  dlnne_min_subgraph_size_ = min_subgraph_size;
  Update();
}

387 388 389 390 391 392 393 394 395 396 397
// Stores the min / max / optimal input-shape maps and the
// disable_trt_plugin_fp16 flag as given. Update() is not called here.
// NOTE(review): the map keys are presumably input tensor names — confirm
// against callers.
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  min_input_shape_ = min_input_shape;
  max_input_shape_ = max_input_shape;
  optim_input_shape_ = optim_input_shape;
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

398 399 400 401 402
// Sets the TensorRT DLA flag and records which DLA core to use. Update() is
// not called here.
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_use_dla_ = true;
  trt_dla_core_ = dla_core;
}

403 404 405 406 407
// Appends `ops` to the list of ops disabled for TensorRT. Existing entries are
// kept and duplicates are not removed.
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  trt_disabled_ops_.insert(trt_disabled_ops_.end(), ops.begin(), ops.end());
}

408
// Sets the trt_use_oss_ flag. Update() is not called here.
void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; }
409

Y
Yan Chunwei 已提交
410
// TODO(Superjomn) refactor this, buggy.
411
void AnalysisConfig::Update() {
412 413 414
  auto info = SerializeInfoCache();
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
415
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
416 417 418
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
      ((use_npu() ^ pass_builder_->use_npu()))) {
Y
Yan Chunwei 已提交
419 420 421 422 423 424 425
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);

      if (use_tensorrt_) {
        // Append after the Affine_channel_conv_fuse pass.
        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
      }
426 427 428 429 430 431
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
432 433 434 435 436 437
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
Y
Yan Chunwei 已提交
438 439 440
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
441

442
  } else {
Y
Yan Chunwei 已提交
443 444 445
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
446 447 448 449 450 451 452
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
453 454 455 456 457 458 459
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
          use_gpu(), false,
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
460 461 462 463
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
464 465 466
  }

  if (use_tensorrt_) {
467 468
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
469
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
470
          (pass == "conv_bn_fuse_pass")) {
471 472
        continue;
      }
473
      pass_builder()->AppendPass(pass);
474 475
    }
  }
D
denglin-github 已提交
476 477 478 479 480 481 482
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

483
  if (use_gpu() && use_cudnn_) {
484
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
485 486 487 488 489 490 491 492
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

493
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
494
#ifdef PADDLE_WITH_MKLDNN
495 496 497
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
498 499
    } else {
      pass_builder()->EnableMKLDNN();
500 501 502 503
    }
#endif
  }

504 505 506 507 508
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
509 510
    }
#ifdef PADDLE_WITH_MKLDNN
511
    pass_builder()->EnableMkldnnQuantizer();
512 513 514
#endif
  }

515 516 517 518 519 520
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

521
#ifdef PADDLE_WITH_MKLDNN
522 523
  // Do not optimize when mkldnn is on
  if (enable_memory_optim_ && !use_mkldnn_) {
524
#else
Y
Yan Chunwei 已提交
525
  if (enable_memory_optim_) {
526 527
#endif
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
Y
Yan Chunwei 已提交
528 529
  }

石晓伟 已提交
530 531 532 533 534 535 536 537 538 539 540 541 542 543
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
      if (std::find(lite_passes_filter_.begin(), lite_passes_filter_.end(),
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

544
  if (use_xpu_) {
545
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
546 547 548 549
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
550 551 552 553 554
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
555 556
  }

W
Wilber 已提交
557
  if (use_npu_) {
558
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
559 560 561 562 563 564 565 566 567 568 569
    PADDLE_ENFORCE_EQ(use_gpu_, false,
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }

570 571 572 573 574
  if (ir_debug_) {
    pass_builder()->TurnOnDebug();
  }
}

575
std::string AnalysisConfig::SerializeInfoCache() {
576
  std::stringstream ss;
Y
Yan Chunwei 已提交
577 578 579 580
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

581
  ss << use_gpu_;
582
  ss << use_fc_padding_;
583 584
  ss << gpu_device_id_;
  ss << xpu_device_id_;
585 586 587 588 589
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
590 591
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
592 593 594
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

595 596 597
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

598 599 600
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
601
  ss << enable_memory_optim_;
602 603

  ss << use_mkldnn_;
604
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
605 606 607
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

608
  ss << use_mkldnn_quantizer_;
609
  ss << use_mkldnn_bfloat16_;
610 611
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
  ss << ";";
Y
Yan Chunwei 已提交
612 613
  ss << model_from_memory_;

614 615
  ss << with_profile_;

616 617
  ss << with_glog_info_;

618 619 620 621
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
622 623
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
624 625

  ss << use_lite_;
626 627
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
628 629 630 631 632
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
633

W
Wilber 已提交
634 635 636
  ss << use_npu_;
  ss << npu_device_id_;

637 638
  ss << thread_local_stream_;

639 640 641
  return ss.str();
}

642
void AnalysisConfig::SetCpuMathLibraryNumThreads(
643 644
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
645 646

  Update();
647 648
}

649
// Returns memory_pool_init_size_mb() divided by the total memory (in MB) of
// the configured GPU device, as reported by platform::GpuMemoryUsage.
// Selects gpu_device_id_ via platform::SetDeviceId as a side effect.
// Returns 0 in builds without CUDA/HIP.
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Query the device the config points at.
  size_t total_bytes, available_bytes;
  platform::SetDeviceId(gpu_device_id_);
  platform::GpuMemoryUsage(&available_bytes, &total_bytes);

  // Bytes -> MB, then the pool's share of the total.
  double total_mb = total_bytes / 1024. / 1024.;
  float pool_fraction =
      static_cast<double>(memory_pool_init_size_mb()) / total_mb;

  VLOG(3) << "total_gpu_memory is " << total_mb << "M, gpu_available is "
          << available_bytes / 1024. / 1024. << "M, memory_pool_init_size is "
          << memory_pool_init_size_mb() << "M.";
  return pool_fraction;
#else
  return 0.;
#endif
}

669 670
void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
671 672 673
  Update();
}

674
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
675 676 677
  return enable_memory_optim_;
}

678 679 680 681
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
682 683
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
684
  model_from_memory_ = true;
Y
Yan Chunwei 已提交
685 686

  Update();
T
Tao Luo 已提交
687 688
}

689
// Builds a NativeConfig carrying the model location, GPU flag, GPU device id,
// GPU memory fraction and specify_input_name setting of this config.
NativeConfig AnalysisConfig::ToNativeConfig() const {
  NativeConfig native;

  // Model location.
  native.model_dir = model_dir_;
  native.prog_file = prog_file_;
  native.param_file = params_file_;

  // Device settings.
  native.use_gpu = use_gpu_;
  native.device = gpu_device_id_;
  native.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();

  native.specify_input_name = specify_input_name_;
  return native;
}

Y
Yan Chunwei 已提交
701 702 703 704
// Stores `x` in ir_debug_ and refreshes the configuration; Update() turns on
// pass debugging on the builder when the stored flag is set.
void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}
705 706 707 708 709 710

// Sets the with_profile_ flag and refreshes the configuration.
void AnalysisConfig::EnableProfile() {
  with_profile_ = true;
  Update();
}

711 712 713 714 715
// Clears the with_glog_info_ flag and refreshes the configuration.
void AnalysisConfig::DisableGlogInfo() {
  with_glog_info_ = false;
  Update();
}

石晓伟 已提交
716
// Turns on the Lite subgraph engine and records its precision mode,
// zero-copy flag, and the pass and op filter lists, then runs Update().
// Update() skips every Lite pass named in `passes_filter`.
void AnalysisConfig::EnableLiteEngine(
    AnalysisConfig::Precision precision_mode, bool zero_copy,
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  use_lite_ = true;

  lite_zero_copy_ = zero_copy;
  lite_precision_mode_ = precision_mode;
  lite_ops_filter_ = ops_filter;
  lite_passes_filter_ = passes_filter;

  Update();
}

728 729 730 731 732 733 734
// Empties prog_file_ and params_file_ and asks the strings to give back
// their capacity. Update() is not called here.
void AnalysisConfig::PartiallyRelease() {
  prog_file_.clear();
  prog_file_.shrink_to_fit();
  params_file_.clear();
  params_file_.shrink_to_fit();
}

735 736
// Sets the thread_local_stream_ flag. Update() is not called here.
void AnalysisConfig::EnableGpuMultiStream() { thread_local_stream_ = true; }

737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
756
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  auto Precision2String =
      [](paddle::AnalysisConfig::Precision prec) -> std::string {
    if (prec == Precision::kFloat32)
      return "fp32";
    else if (prec == Precision::kHalf)
      return "fp16";
    else if (prec == Precision::kInt8)
      return "int8";
    else
      return "None";
  };
  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
799 800 801
      os.InsertRow({"tensorrt_tuned_dynamic_shape", trt_tuned_dynamic_shape_
                                                        ? shape_range_info_path_
                                                        : "false"});
802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830

      os.InsertRow({"tensorrt_use_oss", trt_use_oss_ ? "true" : "false"});
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
831 832
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
833 834 835 836

  return os.PrintTable();
}

837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891
// Replaces the stored NNAdapter device-name list with `names`.
// Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  nnadapter_device_names = names;
  return *this;
}

// Stores the NNAdapter context-properties string as given.
// Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  nnadapter_context_properties = properties;
  return *this;
}

// Stores the NNAdapter model-cache directory as given.
// Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  nnadapter_model_cache_dir = dir;
  return *this;
}

// Registers an in-memory model cache buffer under `model_cache_token`.
//
// Fails with InvalidArgument when the token is empty, the buffer is empty, or
// the token has already been registered. Returns *this so calls can be
// chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  PADDLE_ENFORCE_EQ(model_cache_token.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(), false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  PADDLE_ENFORCE_EQ(nnadapter_model_cache_buffers.count(model_cache_token),
                    false, platform::errors::InvalidArgument(
                               "model_cache_token has already been set."));

  // The token is known to be absent here, so emplace always inserts.
  nnadapter_model_cache_buffers.emplace(model_cache_token, model_cache_buffer);
  return *this;
}

// Stores the path of the subgraph-partition configuration file.
// Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  nnadapter_subgraph_partition_config_path = path;
  return *this;
}

// Stores the subgraph-partition configuration supplied as an in-memory
// string. Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  nnadapter_subgraph_partition_config_buffer = buffer;
  return *this;
}
// Sets use_nnadapter to true. Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  use_nnadapter = true;
  return *this;
}
// Sets use_nnadapter to false. Returns *this so calls can be chained.
LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  use_nnadapter = false;
  return *this;
}

892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927
// Switches the config into shape-collection mode and records
// `shape_range_info_path` as the shape-range info path.
//
// Fails with InvalidArgument when the path is empty. The argument is checked
// before anything is logged or modified, so a rejected call leaves the config
// untouched. Update() is not called here.
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(), false,
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
  shape_range_info_path_ = shape_range_info_path;
}

// Returns a reference to the stored shape-range info path, which is set by
// CollectShapeRangeInfo() or EnableTunedTensorRtDynamicShape().
const std::string &AnalysisConfig::shape_range_info_path() {
  return shape_range_info_path_;
}

// Returns the collect_shape_range_info_ flag, which CollectShapeRangeInfo()
// sets to true.
bool AnalysisConfig::shape_range_info_collected() {
  return collect_shape_range_info_;
}

// Turns on tuned TensorRT dynamic shape: records the shape-range info path
// and the allow_build_at_runtime flag. The path is not validated and Update()
// is not called here.
void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  shape_range_info_path_ = shape_range_info_path;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  trt_tuned_dynamic_shape_ = true;
}

// Returns the trt_tuned_dynamic_shape_ flag, which
// EnableTunedTensorRtDynamicShape() sets to true.
bool AnalysisConfig::tuned_tensorrt_dynamic_shape() {
  return trt_tuned_dynamic_shape_;
}

// Returns the trt_allow_build_at_runtime_ flag, as stored by
// EnableTunedTensorRtDynamicShape().
bool AnalysisConfig::trt_allow_build_at_runtime() {
  return trt_allow_build_at_runtime_;
}
928
}  // namespace paddle