// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

///
/// \file paddle_analysis_config.h
///
/// \brief Paddle Analysis Config API information
///
/// \author paddle-infer@baidu.com
/// \date 2020-03-20
/// \since 1.7
///

#pragma once

#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle_infer_declare.h"  // NOLINT

/*! \file */
// The headers below are included by relative path because the absolute path
// of this header file changes when it is deployed.
#include "paddle_api.h"           // NOLINT
#include "paddle_pass_builder.h"  // NOLINT

#ifdef PADDLE_WITH_MKLDNN
#include "paddle_mkldnn_quantizer_config.h"  // NOLINT
#endif

namespace paddle {

class AnalysisPredictor;
struct MkldnnQuantizerConfig;

///
/// \brief Settings for the Paddle-Lite NNAdapter backend.
///
/// AnalysisConfig holds one instance of this struct and exposes it through
/// AnalysisConfig::NNAdapter(). Every setter returns a LiteNNAdapterConfig&,
/// so calls can be chained, e.g.
/// `config.NNAdapter().Enable().SetDeviceNames(names);`
/// NOTE(review): the setters are defined out of line; that they update the
/// matching field below and return `*this` is inferred from their names and
/// return type — confirm against the implementation.
///
struct LiteNNAdapterConfig {
  /// Whether the NNAdapter backend is requested; off by default
  /// (presumably toggled by Enable()/Disable()).
  bool use_nnadapter{false};
  /// Directory used for the NNAdapter model cache
  /// (presumably set by SetModelCacheDir()).
  std::string nnadapter_model_cache_dir;
  /// In-memory model caches keyed by cache token
  /// (presumably filled by SetModelCacheBuffers()).
  std::map<std::string, std::vector<char>> nnadapter_model_cache_buffers;
  /// Names of the NNAdapter devices to run on
  /// (presumably set by SetDeviceNames()).
  std::vector<std::string> nnadapter_device_names;
  /// Backend-specific context properties string
  /// (presumably set by SetContextProperties()).
  std::string nnadapter_context_properties;
  /// Path of a subgraph-partition configuration file.
  std::string nnadapter_subgraph_partition_config_path;
  /// Contents of a subgraph-partition configuration held in memory.
  std::string nnadapter_subgraph_partition_config_buffer;

  /// \brief Set the NNAdapter device names to use.
  LiteNNAdapterConfig& SetDeviceNames(const std::vector<std::string>& names);

  /// \brief Set backend-specific context properties.
  LiteNNAdapterConfig& SetContextProperties(const std::string& properties);

  /// \brief Set the directory of the model cache.
  LiteNNAdapterConfig& SetModelCacheDir(const std::string& dir);

  /// \brief Register an in-memory model cache under `model_cache_token`.
  LiteNNAdapterConfig& SetModelCacheBuffers(
      const std::string& model_cache_token,
      const std::vector<char>& model_cache_buffer);

  /// \brief Set the path of the subgraph-partition configuration file.
  LiteNNAdapterConfig& SetSubgraphPartitionConfigPath(const std::string& path);

  /// \brief Set the subgraph-partition configuration from a memory buffer.
  LiteNNAdapterConfig& SetSubgraphPartitionConfigBuffer(
      const std::string& buffer);

  /// \brief Request the NNAdapter backend.
  LiteNNAdapterConfig& Enable();
  /// \brief Stop requesting the NNAdapter backend.
  LiteNNAdapterConfig& Disable();
};

79
///
80
/// \brief configuration manager for AnalysisPredictor.
81 82
/// \since 1.7.0
///
83
/// AnalysisConfig manages configurations of AnalysisPredictor.
84 85 86 87 88
/// During inference procedure, there are many parameters(model/params path,
/// place of inference, etc.)
/// to be specified, and various optimizations(subgraph fusion, memory
/// optimazation, TensorRT engine, etc.)
/// to be done. Users can manage these settings by creating and modifying an
89 90
/// AnalysisConfig,
/// and loading it into AnalysisPredictor.
91
///
92
struct PD_INFER_DECL AnalysisConfig {
93
  AnalysisConfig() = default;
94
  ///
95 96
  /// \brief Construct a new AnalysisConfig from another
  /// AnalysisConfig.
97
  ///
98
  /// \param[in] other another AnalysisConfig
99
  ///
100
  explicit AnalysisConfig(const AnalysisConfig& other);
101
  ///
102
  /// \brief Construct a new AnalysisConfig from a no-combined model.
103 104 105
  ///
  /// \param[in] model_dir model directory of the no-combined model.
  ///
106
  explicit AnalysisConfig(const std::string& model_dir);
107
  ///
108
  /// \brief Construct a new AnalysisConfig from a combined model.
109 110 111 112
  ///
  /// \param[in] prog_file model file path of the combined model.
  /// \param[in] params_file params file path of the combined model.
  ///
113 114
  explicit AnalysisConfig(const std::string& prog_file,
                          const std::string& params_file);
115 116 117
  ///
  /// \brief Precision of inference in TensorRT.
  ///
N
nhzlx 已提交
118
  enum class Precision {
119 120 121
    kFloat32 = 0,  ///< fp32
    kInt8,         ///< int8
    kHalf,         ///< fp16
N
nhzlx 已提交
122
  };
123

124 125 126 127 128
  ///
  /// \brief Set the no-combined model dir path.
  ///
  /// \param model_dir model dir path.
  ///
129
  void SetModel(const std::string& model_dir) { model_dir_ = model_dir; }
130 131 132 133 134 135 136 137

  ///
  /// \brief Set the combined model with two specific pathes for program and
  /// parameters.
  ///
  /// \param prog_file_path model file path of the combined model.
  /// \param params_file_path params file path of the combined model.
  ///
138 139
  void SetModel(const std::string& prog_file_path,
                const std::string& params_file_path);
140 141 142 143 144
  ///
  /// \brief Set the model file path of a combined model.
  ///
  /// \param x model file path.
  ///
145
  void SetProgFile(const std::string& x) { prog_file_ = x; }
146 147 148 149 150
  ///
  /// \brief Set the params file path of a combined model.
  ///
  /// \param x params file path.
  ///
151
  void SetParamsFile(const std::string& x) { params_file_ = x; }
152 153 154 155 156 157

  ///
  /// \brief Set the path of optimization cache directory.
  ///
  /// \param opt_cache_dir the path of optimization cache directory.
  ///
158 159 160
  void SetOptimCacheDir(const std::string& opt_cache_dir) {
    opt_cache_dir_ = opt_cache_dir;
  }
161 162 163 164 165
  ///
  /// \brief Get the model directory path.
  ///
  /// \return const std::string& The model directory path.
  ///
166
  const std::string& model_dir() const { return model_dir_; }
167 168 169 170 171
  ///
  /// \brief Get the program file path.
  ///
  /// \return const std::string& The program file path.
  ///
172
  const std::string& prog_file() const { return prog_file_; }
173 174 175 176 177
  ///
  /// \brief Get the combined parameters file.
  ///
  /// \return const std::string& The combined parameters file.
  ///
178 179
  const std::string& params_file() const { return params_file_; }

180
  // Padding related.
181 182 183 184 185

  ///
  /// \brief Turn off FC Padding.
  ///
  ///
186
  void DisableFCPadding();
187 188 189 190 191
  ///
  /// \brief A boolean state telling whether fc padding is used.
  ///
  /// \return bool Whether fc padding is used.
  ///
192 193
  bool use_fc_padding() const { return use_fc_padding_; }

194
  // GPU related.
195

196 197 198 199 200 201
  ///
  /// \brief Turn on GPU.
  ///
  /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
  /// \param device_id device_id the GPU card to use (default is 0).
  ///
202
  void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
203 204 205 206
  ///
  /// \brief Turn off GPU.
  ///
  ///
207
  void DisableGpu();
208

209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228
  ///
  /// \brief Turn on XPU.
  ///
  /// \param l3_workspace_size The size of the video memory allocated by the l3
  ///         cache, the maximum is 16M.
  /// \param locked Whether the allocated L3 cache can be locked. If false,
  ///       it means that the L3 cache is not locked, and the allocated L3
  ///       cache can be shared by multiple models, and multiple models
  ///       sharing the L3 cache will be executed sequentially on the card.
  /// \param autotune Whether to autotune the conv operator in the model. If
  ///       true, when the conv operator of a certain dimension is executed
  ///       for the first time, it will automatically search for a better
  ///       algorithm to improve the performance of subsequent conv operators
  ///       of the same dimension.
  /// \param autotune_file Specify the path of the autotune file. If
  ///       autotune_file is specified, the algorithm specified in the
  ///       file will be used and autotune will not be performed again.
  /// \param precision Calculation accuracy of multi_encoder
  /// \param adaptive_seqlen Is the input of multi_encoder variable length
  ///
W
Wilber 已提交
229 230 231 232
  void EnableXpu(int l3_workspace_size = 0xfffc00, bool locked = false,
                 bool autotune = true, const std::string& autotune_file = "",
                 const std::string& precision = "int16",
                 bool adaptive_seqlen = false);
J
jianghaicheng 已提交
233 234 235 236

  ///
  /// \brief Turn on IPU.
  ///
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260
  /// \param ipu_device_num the number of IPUs.
  /// \param ipu_micro_batch_size the batch size in the graph, only work with
  /// mutable input shapes.
  /// \param ipu_enable_pipelining enable pipelining.
  /// \param ipu_batches_per_step the number of batches per run in pipelining.
  ///
  void EnableIpu(int ipu_device_num = 1, int ipu_micro_batch_size = 1,
                 bool ipu_enable_pipelining = false,
                 int ipu_batches_per_step = 1);

  ///
  /// \brief Set IPU config.
  ///
  /// \param ipu_enable_fp16 enable fp16.
  /// \param ipu_replica_num the number of graph replication.
  /// \param ipu_available_memory_proportion the available memory proportion for
  /// matmul/conv.
  /// \param ipu_enable_half_partial enable fp16 partial for matmul, only work
  /// with fp16.
  ///
  void SetIpuConfig(bool ipu_enable_fp16 = false, int ipu_replica_num = 1,
                    float ipu_available_memory_proportion = 1.0,
                    bool ipu_enable_half_partial = false);

261
  ///
262 263 264 265 266 267
  /// \brief Set XPU device id.
  ///
  /// \param device_id the XPU card to use (default is 0).
  ///
  void SetXpuDeviceId(int device_id = 0);
  ///
W
Wilber 已提交
268 269 270 271 272 273
  /// \brief Turn on NPU.
  ///
  /// \param device_id device_id the NPU card to use (default is 0).
  ///
  void EnableNpu(int device_id = 0);
  ///
274 275 276 277
  /// \brief A boolean state telling whether the GPU is turned on.
  ///
  /// \return bool Whether the GPU is turned on.
  ///
278
  bool use_gpu() const { return use_gpu_; }
279
  ///
280 281 282 283 284 285
  /// \brief A boolean state telling whether the XPU is turned on.
  ///
  /// \return bool Whether the XPU is turned on.
  ///
  bool use_xpu() const { return use_xpu_; }
  ///
W
Wilber 已提交
286 287 288 289 290
  /// \brief A boolean state telling whether the NPU is turned on.
  ///
  /// \return bool Whether the NPU is turned on.
  ///
  bool use_npu() const { return use_npu_; }
J
jianghaicheng 已提交
291 292 293 294 295
  /// \brief A boolean state telling whether the IPU is turned on.
  ///
  /// \return bool Whether the IPU is turned on.
  ///
  bool use_ipu() const { return use_ipu_; }
W
Wilber 已提交
296
  ///
297 298 299 300 301 302
  /// \brief Get the GPU device id.
  ///
  /// \return int The GPU device id.
  ///
  int gpu_device_id() const { return gpu_device_id_; }
  ///
303
  /// \brief Get the XPU device id.
304
  ///
305
  /// \return int The XPU device id.
306
  ///
307
  int xpu_device_id() const { return xpu_device_id_; }
308
  ///
W
Wilber 已提交
309 310 311 312 313
  /// \brief Get the NPU device id.
  ///
  /// \return int The NPU device id.
  ///
  int npu_device_id() const { return npu_device_id_; }
J
jianghaicheng 已提交
314 315 316 317 318
  /// \brief Get the the number of IPU device .
  ///
  /// \return int The number of IPU device.
  ///
  int ipu_device_num() const { return ipu_device_num_; }
W
Wilber 已提交
319
  ///
320 321 322 323
  /// \brief Get the initial size in MB of the GPU memory pool.
  ///
  /// \return int The initial size in MB of the GPU memory pool.
  ///
324
  int memory_pool_init_size_mb() const { return memory_pool_init_size_mb_; }
325 326 327 328 329 330
  ///
  /// \brief Get the proportion of the initial memory pool size compared to the
  /// device.
  ///
  /// \return float The proportion of the initial memory pool size.
  ///
331
  float fraction_of_gpu_memory_for_pool() const;
332

333 334 335 336 337
  // CUDNN related.
  ///
  /// \brief Turn on CUDNN.
  ///
  ///
338
  void EnableCUDNN();
339 340 341 342 343
  ///
  /// \brief A boolean state telling whether to use CUDNN.
  ///
  /// \return bool Whether to use CUDNN.
  ///
344 345
  bool cudnn_enabled() const { return use_cudnn_; }

346 347 348 349 350 351
  ///
  /// \brief Control whether to perform IR graph optimization.
  /// If turned off, the AnalysisConfig will act just like a NativeConfig.
  ///
  /// \param x Whether the ir graph optimization is actived.
  ///
352
  void SwitchIrOptim(int x = true) { enable_ir_optim_ = x; }
353 354 355 356 357 358
  ///
  /// \brief A boolean state telling whether the ir graph optimization is
  /// actived.
  ///
  /// \return bool Whether to use ir graph optimization.
  ///
359
  bool ir_optim() const { return enable_ir_optim_; }
360

361 362 363 364 365 366 367
  ///
  /// \brief INTERNAL Determine whether to use the feed and fetch operators.
  /// Just for internal development, not stable yet.
  /// When ZeroCopyTensor is used, this should be turned off.
  ///
  /// \param x Whether to use the feed and fetch operators.
  ///
368
  void SwitchUseFeedFetchOps(int x = true) { use_feed_fetch_ops_ = x; }
369 370 371 372 373 374
  ///
  /// \brief A boolean state telling whether to use the feed and fetch
  /// operators.
  ///
  /// \return bool Whether to use the feed and fetch operators.
  ///
375
  bool use_feed_fetch_ops_enabled() const { return use_feed_fetch_ops_; }
376

377 378 379 380 381 382 383 384 385 386 387
  ///
  /// \brief Control whether to specify the inputs' names.
  /// The ZeroCopyTensor type has a name member, assign it with the
  /// corresponding
  /// variable name. This is used only when the input ZeroCopyTensors passed to
  /// the
  /// AnalysisPredictor.ZeroCopyRun() cannot follow the order in the training
  /// phase.
  ///
  /// \param x Whether to specify the inputs' names.
  ///
388
  void SwitchSpecifyInputNames(bool x = true) { specify_input_name_ = x; }
389 390 391 392 393 394 395
  ///
  /// \brief A boolean state tell whether the input ZeroCopyTensor names
  /// specified should
  /// be used to reorder the inputs in AnalysisPredictor.ZeroCopyRun().
  ///
  /// \return bool Whether to specify the inputs' names.
  ///
396
  bool specify_input_name() const { return specify_input_name_; }
397

398 399 400 401 402 403 404 405 406 407
  ///
  /// \brief Turn on the TensorRT engine.
  /// The TensorRT engine will accelerate some subgraphes in the original Fluid
  /// computation graph. In some models such as resnet50, GoogleNet and so on,
  /// it gains significant performance acceleration.
  ///
  /// \param workspace_size The memory size(in byte) used for TensorRT
  /// workspace.
  /// \param max_batch_size The maximum batch size of this prediction task,
  /// better set as small as possible for less performance loss.
408
  /// \param min_subgraph_size The minimum TensorRT subgraph size needed, if a
409 410 411 412 413 414 415 416
  /// subgraph is smaller than this, it will not be transferred to TensorRT
  /// engine.
  /// \param precision The precision used in TensorRT.
  /// \param use_static Serialize optimization information to disk for reusing.
  /// \param use_calib_mode Use TRT int8 calibration(post training
  /// quantization).
  ///
  ///
417 418 419 420 421
  void EnableTensorRtEngine(int workspace_size = 1 << 20,
                            int max_batch_size = 1, int min_subgraph_size = 3,
                            Precision precision = Precision::kFloat32,
                            bool use_static = false,
                            bool use_calib_mode = true);
422 423 424 425 426
  ///
  /// \brief A boolean state telling whether the TensorRT engine is used.
  ///
  /// \return bool Whether the TensorRT engine is used.
  ///
427
  bool tensorrt_engine_enabled() const { return use_tensorrt_; }
428
  ///
429 430 431 432 433 434
  /// \brief  Get the TensorRT engine precision.
  ///
  /// \return Precision Get the TensorRT engine precision.
  ///
  Precision tensorrt_precision_mode() const { return tensorrt_precision_mode_; }
  ///
435 436 437 438 439 440 441
  /// \brief Set min, max, opt shape for TensorRT Dynamic shape mode.
  /// \param min_input_shape The min input shape of the subgraph input.
  /// \param max_input_shape The max input shape of the subgraph input.
  /// \param opt_input_shape The opt input shape of the subgraph input.
  /// \param disable_trt_plugin_fp16 Setting this parameter to true means that
  /// TRT plugin will not run fp16.
  ///
442 443 444 445 446
  void SetTRTDynamicShapeInfo(
      std::map<std::string, std::vector<int>> min_input_shape,
      std::map<std::string, std::vector<int>> max_input_shape,
      std::map<std::string, std::vector<int>> optim_input_shape,
      bool disable_trt_plugin_fp16 = false);
447 448 449 450 451 452
  ///
  /// \brief A boolean state telling whether the trt dynamic_shape is used.
  ///
  /// \return bool Whether the trt dynamic_shape is used.
  ///
  bool tensorrt_dynamic_shape_enabled() const {
W
Wilber 已提交
453
    return !min_input_shape_.empty();
454
  }
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
  ///
  /// \brief Enable tuned tensorrt dynamic shape.
  ///
  /// \param shape_range_info_path the path to shape_info file got in
  /// CollectShapeInfo
  /// mode.
  /// \param allow_build_at_runtime allow build trt engine at runtime.
  ///
  void EnableTunedTensorRtDynamicShape(const std::string& shape_range_info_path,
                                       bool allow_build_at_runtime = true);

  ///
  /// \brief A boolean state telling whether to use tuned tensorrt dynamic
  /// shape.
  ///
  bool tuned_tensorrt_dynamic_shape();

  ///
  /// \brief A boolean state telling whether to allow building trt engine at
  /// runtime.
  ///
  bool trt_allow_build_at_runtime();

  ///
  /// \brief Collect shape info of all tensors in compute graph.
  ///
  /// \param shape_range_info_path the path to save shape info.
  ///
  void CollectShapeRangeInfo(const std::string& shape_range_info_path);

  ///
  /// \brief the shape info path in CollectShapeInfo mode.
  ///
  /// \return the shape info path.
  ///
  const std::string& shape_range_info_path();

  ///
  /// \brief A boolean state telling whether to collect shape info.
  ///
  /// \return bool Whether to collect shape info.
  ///
  bool shape_range_info_collected();

499 500 501 502 503 504
  ///
  /// \brief Prevent ops running in Paddle-TRT
  /// NOTE: just experimental, not an official stable API, easy to be broken.
  ///
  void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);

505 506
  ///
  /// \brief Replace some TensorRT plugins to TensorRT OSS(
507 508 509
  /// https://github.com/NVIDIA/TensorRT), with which some models's inference
  /// may be more high-performance. Libnvinfer_plugin.so greater than
  /// V7.2.1 is needed.
510 511
  ///
  void EnableTensorRtOSS();
512

513 514 515 516 517 518 519
  ///
  /// \brief A boolean state telling whether to use the TensorRT OSS.
  ///
  /// \return bool Whether to use the TensorRT OSS.
  ///
  bool tensorrt_oss_enabled() { return trt_use_oss_; }

520 521 522 523 524 525 526 527 528 529 530 531 532 533
  ///
  /// \brief Enable TensorRT DLA
  /// \param dla_core ID of DLACore, which should be 0, 1,
  ///        ..., IBuilder.getNbDLACores() - 1
  ///
  void EnableTensorRtDLA(int dla_core = 0);

  ///
  /// \brief A boolean state telling whether to use the TensorRT DLA.
  ///
  /// \return bool Whether to use the TensorRT DLA.
  ///
  bool tensorrt_dla_enabled() { return trt_use_dla_; }

534 535 536
  void EnableTensorRtInspector();
  bool tensorrt_inspector_enabled() { return trt_use_inspector_; }

D
denglin-github 已提交
537 538 539
  void EnableDlnne(int min_subgraph_size = 3);
  bool dlnne_enabled() const { return use_dlnne_; }

540 541 542 543 544 545 546
  ///
  /// \brief Turn on the usage of Lite sub-graph engine.
  ///
  /// \param precision_mode Precion used in Lite sub-graph engine.
  /// \param passes_filter Set the passes used in Lite sub-graph engine.
  /// \param ops_filter Operators not supported by Lite.
  ///
石晓伟 已提交
547 548
  void EnableLiteEngine(
      AnalysisConfig::Precision precision_mode = Precision::kFloat32,
549
      bool zero_copy = false,
石晓伟 已提交
550 551 552
      const std::vector<std::string>& passes_filter = {},
      const std::vector<std::string>& ops_filter = {});

553 554 555 556 557 558
  ///
  /// \brief A boolean state indicating whether the Lite sub-graph engine is
  /// used.
  ///
  /// \return bool whether the Lite sub-graph engine is used.
  ///
石晓伟 已提交
559 560
  bool lite_engine_enabled() const { return use_lite_; }

561 562 563 564 565 566 567
  ///
  /// \brief Control whether to debug IR graph analysis phase.
  /// This will generate DOT files for visualizing the computation graph after
  /// each analysis pass applied.
  ///
  /// \param x whether to debug IR graph analysis phase.
  ///
Y
Yan Chunwei 已提交
568
  void SwitchIrDebug(int x = true);
569

570 571 572 573
  ///
  /// \brief Turn on MKLDNN.
  ///
  ///
L
luotao1 已提交
574
  void EnableMKLDNN();
575 576 577
  ///
  /// \brief Set the cache capacity of different input shapes for MKLDNN.
  /// Default value 0 means not caching any shape.
578 579
  /// Please see MKL-DNN Data Caching Design Document:
  /// https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/mkldnn/caching/caching.md
580 581 582
  ///
  /// \param capacity The cache capacity.
  ///
583
  void SetMkldnnCacheCapacity(int capacity);
584 585 586 587 588
  ///
  /// \brief A boolean state telling whether to use the MKLDNN.
  ///
  /// \return bool Whether to use the MKLDNN.
  ///
589 590
  bool mkldnn_enabled() const { return use_mkldnn_; }

591 592 593 594 595 596
  ///
  /// \brief Set the number of cpu math library threads.
  ///
  /// \param cpu_math_library_num_threads The number of cpu math library
  /// threads.
  ///
597
  void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads);
598 599 600 601 602 603
  ///
  /// \brief An int state telling how many threads are used in the CPU math
  /// library.
  ///
  /// \return int The number of threads used in the CPU math library.
  ///
604 605 606 607
  int cpu_math_library_num_threads() const {
    return cpu_math_library_num_threads_;
  }

608 609 610 611 612
  ///
  /// \brief Transform the AnalysisConfig to NativeConfig.
  ///
  /// \return NativeConfig The NativeConfig transformed.
  ///
Y
Yan Chunwei 已提交
613
  NativeConfig ToNativeConfig() const;
614 615 616 617 618
  ///
  /// \brief Specify the operator type list to use MKLDNN acceleration.
  ///
  /// \param op_list The operator type list.
  ///
619 620 621
  void SetMKLDNNOp(std::unordered_set<std::string> op_list) {
    mkldnn_enabled_op_types_ = op_list;
  }
622

623 624 625 626
  ///
  /// \brief Turn on MKLDNN quantization.
  ///
  ///
627 628
  void EnableMkldnnQuantizer();

629 630 631 632 633 634 635 636 637 638 639 640 641
  ///
  /// \brief Turn on MKLDNN bfloat16.
  ///
  ///
  void EnableMkldnnBfloat16();

  ///
  /// \brief A boolean state telling whether to use the MKLDNN Bfloat16.
  ///
  /// \return bool Whether to use the MKLDNN Bfloat16.
  ///
  bool mkldnn_bfloat16_enabled() const { return use_mkldnn_bfloat16_; }

642 643 644 645 646 647 648 649
  /// \brief Specify the operator type list to use Bfloat16 acceleration.
  ///
  /// \param op_list The operator type list.
  ///
  void SetBfloat16Op(std::unordered_set<std::string> op_list) {
    bfloat16_enabled_op_types_ = op_list;
  }

650 651 652 653 654 655 656 657
  ///
  /// \brief A boolean state telling whether the thread local CUDA stream is
  /// enabled.
  ///
  /// \return bool Whether the thread local CUDA stream is enabled.
  ///
  bool thread_local_stream_enabled() const { return thread_local_stream_; }

658 659 660 661 662
  ///
  /// \brief A boolean state telling whether the MKLDNN quantization is enabled.
  ///
  /// \return bool Whether the MKLDNN quantization is enabled.
  ///
663 664
  bool mkldnn_quantizer_enabled() const { return use_mkldnn_quantizer_; }

665 666 667 668 669
  ///
  /// \brief Get MKLDNN quantizer config.
  ///
  /// \return MkldnnQuantizerConfig* MKLDNN quantizer config.
  ///
670
  MkldnnQuantizerConfig* mkldnn_quantizer_config() const;
671

672 673 674 675 676 677 678 679 680
  ///
  /// \brief Specify the memory buffer of program and parameter.
  /// Used when model and params are loaded directly from memory.
  ///
  /// \param prog_buffer The memory buffer of program.
  /// \param prog_buffer_size The size of the model data.
  /// \param params_buffer The memory buffer of the combined parameters file.
  /// \param params_buffer_size The size of the combined parameters data.
  ///
T
Tao Luo 已提交
681
  void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
682
                      const char* params_buffer, size_t params_buffer_size);
683 684 685 686 687 688
  ///
  /// \brief A boolean state telling whether the model is set from the CPU
  /// memory.
  ///
  /// \return bool Whether model and params are loaded directly from memory.
  ///
T
Tao Luo 已提交
689
  bool model_from_memory() const { return model_from_memory_; }
T
Tao Luo 已提交
690

691 692 693 694
  ///
  /// \brief Turn on memory optimize
  /// NOTE still in development.
  ///
695 696 697
  /// \param x Whether to enable memory optimize.
  ///
  void EnableMemoryOptim(bool x = true);
698 699 700 701 702 703
  ///
  /// \brief A boolean state telling whether the memory optimization is
  /// activated.
  ///
  /// \return bool Whether the memory optimization is activated.
  ///
Y
Yan Chunwei 已提交
704
  bool enable_memory_optim() const;
705

706 707 708 709
  ///
  /// \brief Turn on profiling report.
  /// If not turned on, no profiling report will be generated.
  ///
710
  void EnableProfile();
711 712 713 714 715
  ///
  /// \brief A boolean state telling whether the profiler is activated.
  ///
  /// \return bool Whether the profiler is activated.
  ///
716 717
  bool profile_enabled() const { return with_profile_; }

718 719 720
  ///
  /// \brief Mute all logs in Paddle inference.
  ///
721
  void DisableGlogInfo();
722 723 724 725 726
  ///
  /// \brief A boolean state telling whether logs in Paddle inference are muted.
  ///
  /// \return bool Whether logs in Paddle inference are muted.
  ///
727 728
  bool glog_info_disabled() const { return !with_glog_info_; }

729 730 731 732 733
  ///
  /// \brief Set the AnalysisConfig to be invalid.
  /// This is to ensure that an AnalysisConfig can only be used in one
  /// AnalysisPredictor.
  ///
734
  void SetInValid() const { is_valid_ = false; }
735 736 737 738 739
  ///
  /// \brief A boolean state telling whether the AnalysisConfig is valid.
  ///
  /// \return bool Whether the AnalysisConfig is valid.
  ///
740
  bool is_valid() const { return is_valid_; }
Y
Yan Chunwei 已提交
741

742 743
  friend class ::paddle::AnalysisPredictor;

744 745 746 747 748
  ///
  /// \brief Get a pass builder for customize the passes in IR analysis phase.
  /// NOTE: Just for developer, not an official API, easy to be broken.
  ///
  ///
749
  PassStrategy* pass_builder() const;
750 751 752 753 754 755 756

  ///
  /// \brief Enable the GPU multi-computing stream feature.
  /// NOTE: The current behavior of this interface is to bind the computation
  /// stream to the thread, and this behavior may be changed in the future.
  ///
  void EnableGpuMultiStream();
757
  void PartiallyRelease();
758

759 760 761 762 763
  ///
  /// \brief Print the summary of config.
  ///
  std::string Summary();

764 765
  LiteNNAdapterConfig& NNAdapter() { return nnadapter_config_; }

766 767 768 769 770 771
 protected:
  // Update the config.
  void Update();

  std::string SerializeInfoCache();

772
 protected:
773 774
  // Model pathes.
  std::string model_dir_;
775 776
  mutable std::string prog_file_;
  mutable std::string params_file_;
777

S
Sylwester Fraczek 已提交
778
  // GPU related.
779
  bool use_gpu_{false};
780
  int gpu_device_id_{0};
781
  uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
W
Wilber 已提交
782
  bool thread_local_stream_{false};
783

784 785
  bool use_cudnn_{false};

W
Wilber 已提交
786 787 788 789
  // NPU related
  bool use_npu_{false};
  int npu_device_id_{0};

790 791 792
  // Padding related
  bool use_fc_padding_{true};

S
Sylwester Fraczek 已提交
793
  // TensorRT related.
794
  bool use_tensorrt_{false};
795 796
  // For workspace_size, refer it from here:
  // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
797
  int tensorrt_workspace_size_{1 << 30};
798 799 800 801
  // While TensorRT allows an engine optimized for a given max batch size
  // to run at any smaller size, the performance for those smaller
  // sizes may not be as well-optimized. Therefore, Max batch is best
  // equivalent to the runtime batch size.
802
  int tensorrt_max_batchsize_{1};
803 804 805 806 807
  //  We transform the Ops that can be converted into TRT layer in the model,
  //  and aggregate these Ops into subgraphs for TRT execution.
  //  We set this variable to control the minimum number of nodes in the
  //  subgraph, 3 as default value.
  int tensorrt_min_subgraph_size_{3};
808 809 810
  Precision tensorrt_precision_mode_{Precision::kFloat32};
  bool trt_use_static_engine_{false};
  bool trt_use_calib_mode_{true};
811
  bool trt_use_oss_{false};
812
  bool trt_with_interleaved_{false};
813 814
  bool trt_use_dla_{false};
  int trt_dla_core_{0};
815 816 817
  std::map<std::string, std::vector<int>> min_input_shape_{};
  std::map<std::string, std::vector<int>> max_input_shape_{};
  std::map<std::string, std::vector<int>> optim_input_shape_{};
818
  std::vector<std::string> trt_disabled_ops_{};
819
  bool disable_trt_plugin_fp16_{false};
820 821 822
  bool trt_allow_build_at_runtime_{false};
  // tune to get dynamic_shape info.
  bool trt_tuned_dynamic_shape_{false};
823
  bool trt_use_inspector_{false};
824 825 826 827 828 829

  // In CollectShapeInfo mode, we will collect the shape information of
  // all intermediate tensors in the compute graph and calculate the
  // min_shape, max_shape and opt_shape and save in shape_range_info_path_;
  bool collect_shape_range_info_{false};
  std::string shape_range_info_path_;
830

D
denglin-github 已提交
831 832 833 834
  // dlnne related.
  bool use_dlnne_{false};
  int dlnne_min_subgraph_size_{3};

Y
Yan Chunwei 已提交
835 836 837
  // memory reuse related.
  bool enable_memory_optim_{false};

838 839 840
  bool use_mkldnn_{false};
  std::unordered_set<std::string> mkldnn_enabled_op_types_;

T
Tao Luo 已提交
841
  bool model_from_memory_{false};
842

843 844 845 846 847 848 849 850
  bool enable_ir_optim_{true};
  bool use_feed_fetch_ops_{true};
  bool ir_debug_{false};

  bool specify_input_name_{false};

  int cpu_math_library_num_threads_{1};

851 852
  bool with_profile_{false};

853 854
  bool with_glog_info_{true};

855 856 857 858
  // A runtime cache, shouldn't be transferred to others.
  std::string serialized_info_cache_;

  mutable std::unique_ptr<PassStrategy> pass_builder_;
859

石晓伟 已提交
860 861 862 863
  bool use_lite_{false};
  std::vector<std::string> lite_passes_filter_;
  std::vector<std::string> lite_ops_filter_;
  Precision lite_precision_mode_;
864
  bool lite_zero_copy_;
石晓伟 已提交
865

W
Wilber 已提交
866
  // XPU related.
867
  bool use_xpu_{false};
W
Wilber 已提交
868
  int xpu_device_id_{0};
869
  int xpu_l3_workspace_size_{0};
W
Wilber 已提交
870 871 872 873 874
  bool xpu_locked_;
  bool xpu_autotune_;
  std::string xpu_autotune_file_;
  std::string xpu_precision_;
  bool xpu_adaptive_seqlen_;
875

876 877 878
  // NNAdapter related
  LiteNNAdapterConfig nnadapter_config_;

879
  // mkldnn related.
W
Wilber 已提交
880
  int mkldnn_cache_capacity_{10};
881 882
  bool use_mkldnn_quantizer_{false};
  std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
883
  bool use_mkldnn_bfloat16_{false};
884
  std::unordered_set<std::string> bfloat16_enabled_op_types_;
885

J
jianghaicheng 已提交
886 887 888
  // ipu related.
  bool use_ipu_{false};
  int ipu_device_num_{1};
889
  int ipu_micro_batch_size_{1};
J
jianghaicheng 已提交
890 891
  bool ipu_enable_pipelining_{false};
  int ipu_batches_per_step_{1};
892 893 894 895 896

  bool ipu_enable_fp16_{false};
  int ipu_replica_num_{1};
  float ipu_available_memory_proportion_{1.0};
  bool ipu_enable_half_partial_{false};
J
jianghaicheng 已提交
897

898 899 900 901
  // If the config is already used on a predictor, it becomes invalid.
  // Any config can only be used with one predictor.
  // Variables held by config can take up a lot of memory in some cases.
  // So we release the memory when the predictor is set up.
902 903
  mutable bool is_valid_{true};
  std::string opt_cache_dir_;
904
  friend class paddle_infer::experimental::InternalUtils;
905 906 907
};

}  // namespace paddle