paddle_analysis_config.h 31.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 15 16 17 18 19 20 21 22 23 24

///
/// \file paddle_analysis_config.h
///
/// \brief Paddle Analysis Config API information
///
/// \author paddle-infer@baidu.com
/// \date 2020-03-20
/// \since 1.7
///

25 26 27
#pragma once

#include <cassert>
28
#include <map>
29 30
#include <memory>
#include <string>
31
#include <unordered_set>
32
#include <utility>
33
#include <vector>
34

35
#include "paddle_infer_declare.h"  // NOLINT
36

37
/*! \file */
38 39 40 41
// Here we include some header files with relative paths, for that in deploy,
// the abstract path of this header file will be changed.
#include "paddle_api.h"           // NOLINT
#include "paddle_pass_builder.h"  // NOLINT
42 43 44
#ifdef PADDLE_WITH_MKLDNN
#include "paddle_mkldnn_quantizer_config.h"  // NOLINT
#endif
45 46 47 48

namespace paddle {

class AnalysisPredictor;
49
struct MkldnnQuantizerConfig;
50

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
///
/// \brief Options for running Paddle-Lite subgraphs through NNAdapter.
///
/// A plain aggregate holding:
/// - an on/off switch;
/// - the target device names and context properties;
/// - a model-cache directory and in-memory model-cache buffers;
/// - a subgraph-partition config, given as a file path or an in-memory buffer.
///
/// The setters are only declared here. Each returns a LiteNNAdapterConfig&.
/// NOTE(review): presumably they return *this to allow call chaining; confirm
/// in the implementation.
///
struct LiteNNAdapterConfig {
  bool use_nnadapter{false};  // NNAdapter switch; off by default.
  std::string nnadapter_model_cache_dir;
  // NOTE(review): presumably keyed by the model-cache token passed to
  // SetModelCacheBuffers(); confirm.
  std::map<std::string, std::vector<char>> nnadapter_model_cache_buffers;
  std::vector<std::string> nnadapter_device_names;
  std::string nnadapter_context_properties;
  std::string nnadapter_subgraph_partition_config_path;
  std::string nnadapter_subgraph_partition_config_buffer;

  /// Setter; presumably stores `names` in nnadapter_device_names.
  LiteNNAdapterConfig& SetDeviceNames(const std::vector<std::string>& names);

  /// Setter; presumably stores `properties` in nnadapter_context_properties.
  LiteNNAdapterConfig& SetContextProperties(const std::string& properties);

  /// Setter; presumably stores `dir` in nnadapter_model_cache_dir.
  LiteNNAdapterConfig& SetModelCacheDir(const std::string& dir);

  /// Setter; presumably records `model_cache_buffer` under
  /// `model_cache_token` in nnadapter_model_cache_buffers.
  LiteNNAdapterConfig& SetModelCacheBuffers(
      const std::string& model_cache_token,
      const std::vector<char>& model_cache_buffer);

  /// Setter; presumably stores `path` in
  /// nnadapter_subgraph_partition_config_path.
  LiteNNAdapterConfig& SetSubgraphPartitionConfigPath(const std::string& path);

  /// Setter; presumably stores `buffer` in
  /// nnadapter_subgraph_partition_config_buffer.
  LiteNNAdapterConfig& SetSubgraphPartitionConfigBuffer(
      const std::string& buffer);

  /// Presumably set use_nnadapter to true / false respectively; confirm.
  LiteNNAdapterConfig& Enable();
  LiteNNAdapterConfig& Disable();
};

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
///
/// \brief Settings for distributed inference through DistModel.
///
/// Every field has an in-class default:
/// - DistModel disabled;
/// - no trainer endpoints;
/// - a single rank (nranks = 1, rank = 0);
/// - an empty comm-init config;
/// - the carrier id "inference".
///
/// Getters return copies of the stored values.
///
struct DistConfig {
  // ---- DistModel switch ------------------------------------------------
  void EnableDistModel(bool use_dist_model) {
    this->use_dist_model_ = use_dist_model;
  }
  bool use_dist_model() const { return this->use_dist_model_; }

  // ---- Endpoints -------------------------------------------------------
  void SetEndpoints(const std::vector<std::string>& trainer_endpoints,
                    const std::string& current_endpoint) {
    this->trainer_endpoints_ = trainer_endpoints;
    this->current_endpoint_ = current_endpoint;
  }
  std::vector<std::string> trainer_endpoints() const {
    return this->trainer_endpoints_;
  }
  std::string current_endpoint() const { return this->current_endpoint_; }

  // ---- Ranks -----------------------------------------------------------
  void SetRanks(int64_t nranks, int64_t rank) {
    this->nranks_ = nranks;
    this->rank_ = rank;
  }
  int64_t nranks() const { return this->nranks_; }
  int64_t rank() const { return this->rank_; }

  // ---- Communication init ----------------------------------------------
  void SetCommInitConfig(const std::string& comm_init_config) {
    this->comm_init_config_ = comm_init_config;
  }
  std::string comm_init_config() const { return this->comm_init_config_; }

  // ---- Carrier ---------------------------------------------------------
  void SetCarrierId(const std::string& carrier_id) {
    this->carrier_id_ = carrier_id;
  }
  std::string carrier_id() const { return this->carrier_id_; }

 protected:
  // DistModel inference state. Member order is kept stable on purpose.
  bool use_dist_model_{false};                    // DistModel on/off
  std::vector<std::string> trainer_endpoints_{};  // endpoints of all trainers
  std::string current_endpoint_{};                // this trainer's endpoint
  int64_t nranks_{1};                             // total number of ranks
  int64_t rank_{0};                               // rank of this trainer
  // NOTE(review): previously documented as "converter config path"; the
  // setter name suggests a communication-init config path. Confirm.
  std::string comm_init_config_{};
  std::string carrier_id_{"inference"};
};

127
///
128
/// \brief configuration manager for AnalysisPredictor.
129 130
/// \since 1.7.0
///
131
/// AnalysisConfig manages configurations of AnalysisPredictor.
132 133 134 135 136
/// During inference, many parameters (model/params path, place of
/// inference, etc.) must be specified, and various optimizations (subgraph
/// fusion, memory optimization, TensorRT engine, etc.) may be applied.
/// Users can manage these settings by creating and modifying an
137 138
/// AnalysisConfig,
/// and loading it into AnalysisPredictor.
139
///
140
struct PD_INFER_DECL AnalysisConfig {
141
  AnalysisConfig() = default;
142
  ///
143 144
  /// \brief Construct a new AnalysisConfig from another
  /// AnalysisConfig.
145
  ///
146
  /// \param[in] other another AnalysisConfig
147
  ///
148
  explicit AnalysisConfig(const AnalysisConfig& other);
149
  ///
150
  /// \brief Construct a new AnalysisConfig from a no-combined model.
151 152 153
  ///
  /// \param[in] model_dir model directory of the no-combined model.
  ///
154
  explicit AnalysisConfig(const std::string& model_dir);
155
  ///
156
  /// \brief Construct a new AnalysisConfig from a combined model.
157 158 159 160
  ///
  /// \param[in] prog_file model file path of the combined model.
  /// \param[in] params_file params file path of the combined model.
  ///
161 162
  explicit AnalysisConfig(const std::string& prog_file,
                          const std::string& params_file);
163 164 165
  ///
  /// \brief Precision of inference in TensorRT.
  ///
N
nhzlx 已提交
166
  enum class Precision {
167 168 169
    kFloat32 = 0,  ///< fp32
    kInt8,         ///< int8
    kHalf,         ///< fp16
N
nhzlx 已提交
170
  };
171

172 173 174 175 176
  ///
  /// \brief Set the no-combined model dir path.
  ///
  /// \param model_dir model dir path.
  ///
177
  void SetModel(const std::string& model_dir) { model_dir_ = model_dir; }
178 179 180 181 182 183 184 185

  ///
  /// \brief Set the combined model with two specific paths for program and
  /// parameters.
  ///
  /// \param prog_file_path model file path of the combined model.
  /// \param params_file_path params file path of the combined model.
  ///
186 187
  void SetModel(const std::string& prog_file_path,
                const std::string& params_file_path);
188 189 190 191 192
  ///
  /// \brief Set the model file path of a combined model.
  ///
  /// \param x model file path.
  ///
193
  void SetProgFile(const std::string& x) { prog_file_ = x; }
194 195 196 197 198
  ///
  /// \brief Set the params file path of a combined model.
  ///
  /// \param x params file path.
  ///
199
  void SetParamsFile(const std::string& x) { params_file_ = x; }
200 201 202 203 204 205

  ///
  /// \brief Set the path of optimization cache directory.
  ///
  /// \param opt_cache_dir the path of optimization cache directory.
  ///
206 207 208
  void SetOptimCacheDir(const std::string& opt_cache_dir) {
    opt_cache_dir_ = opt_cache_dir;
  }
209 210 211 212 213
  ///
  /// \brief Get the model directory path.
  ///
  /// \return const std::string& The model directory path.
  ///
214
  const std::string& model_dir() const { return model_dir_; }
215 216 217 218 219
  ///
  /// \brief Get the program file path.
  ///
  /// \return const std::string& The program file path.
  ///
220
  const std::string& prog_file() const { return prog_file_; }
221 222 223 224 225
  ///
  /// \brief Get the combined parameters file.
  ///
  /// \return const std::string& The combined parameters file.
  ///
226 227
  const std::string& params_file() const { return params_file_; }

228
  // Padding related.
229 230 231 232 233

  ///
  /// \brief Turn off FC Padding.
  ///
  ///
234
  void DisableFCPadding();
235 236 237 238 239
  ///
  /// \brief A boolean state telling whether fc padding is used.
  ///
  /// \return bool Whether fc padding is used.
  ///
240 241
  bool use_fc_padding() const { return use_fc_padding_; }

242
  // GPU related.
243

244 245 246 247 248 249
  ///
  /// \brief Turn on GPU.
  ///
  /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
  /// \param device_id the GPU card to use (default is 0).
  ///
250
  void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
251 252 253 254
  ///
  /// \brief Turn off GPU.
  ///
  ///
255
  void DisableGpu();
256

257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
  ///
  /// \brief Turn on XPU.
  ///
  /// \param l3_workspace_size The size of the video memory allocated by the l3
  ///         cache, the maximum is 16M.
  /// \param locked Whether the allocated L3 cache can be locked. If false,
  ///       it means that the L3 cache is not locked, and the allocated L3
  ///       cache can be shared by multiple models, and multiple models
  ///       sharing the L3 cache will be executed sequentially on the card.
  /// \param autotune Whether to autotune the conv operator in the model. If
  ///       true, when the conv operator of a certain dimension is executed
  ///       for the first time, it will automatically search for a better
  ///       algorithm to improve the performance of subsequent conv operators
  ///       of the same dimension.
  /// \param autotune_file Specify the path of the autotune file. If
  ///       autotune_file is specified, the algorithm specified in the
  ///       file will be used and autotune will not be performed again.
  /// \param precision Calculation accuracy of multi_encoder
  /// \param adaptive_seqlen Is the input of multi_encoder variable length
  ///
W
Wilber 已提交
277 278 279 280
  void EnableXpu(int l3_workspace_size = 0xfffc00, bool locked = false,
                 bool autotune = true, const std::string& autotune_file = "",
                 const std::string& precision = "int16",
                 bool adaptive_seqlen = false);
J
jianghaicheng 已提交
281 282 283 284

  ///
  /// \brief Turn on IPU.
  ///
285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
  /// \param ipu_device_num the number of IPUs.
  /// \param ipu_micro_batch_size the batch size in the graph, only work with
  /// mutable input shapes.
  /// \param ipu_enable_pipelining enable pipelining.
  /// \param ipu_batches_per_step the number of batches per run in pipelining.
  ///
  void EnableIpu(int ipu_device_num = 1, int ipu_micro_batch_size = 1,
                 bool ipu_enable_pipelining = false,
                 int ipu_batches_per_step = 1);

  ///
  /// \brief Set IPU config.
  ///
  /// \param ipu_enable_fp16 enable fp16.
  /// \param ipu_replica_num the number of graph replication.
  /// \param ipu_available_memory_proportion the available memory proportion for
  /// matmul/conv.
  /// \param ipu_enable_half_partial enable fp16 partial for matmul, only work
  /// with fp16.
  ///
  void SetIpuConfig(bool ipu_enable_fp16 = false, int ipu_replica_num = 1,
                    float ipu_available_memory_proportion = 1.0,
                    bool ipu_enable_half_partial = false);

309
  ///
310 311 312 313 314 315
  /// \brief Set XPU device id.
  ///
  /// \param device_id the XPU card to use (default is 0).
  ///
  void SetXpuDeviceId(int device_id = 0);
  ///
W
Wilber 已提交
316 317 318 319 320 321
  /// \brief Turn on NPU.
  ///
  /// \param device_id the NPU card to use (default is 0).
  ///
  void EnableNpu(int device_id = 0);
  ///
322 323 324 325 326 327 328 329 330 331 332 333
  /// \brief Turn on ONNXRuntime.
  ///
  void EnableONNXRuntime();
  ///
  /// \brief Turn off ONNXRuntime.
  ///
  void DisableONNXRuntime();
  ///
  /// \brief Turn on ONNXRuntime Optimization.
  ///
  void EnableORTOptimization();
  ///
334 335 336 337
  /// \brief A boolean state telling whether the GPU is turned on.
  ///
  /// \return bool Whether the GPU is turned on.
  ///
338
  bool use_gpu() const { return use_gpu_; }
339
  ///
340 341 342 343 344 345
  /// \brief A boolean state telling whether the XPU is turned on.
  ///
  /// \return bool Whether the XPU is turned on.
  ///
  bool use_xpu() const { return use_xpu_; }
  ///
W
Wilber 已提交
346 347 348 349 350
  /// \brief A boolean state telling whether the NPU is turned on.
  ///
  /// \return bool Whether the NPU is turned on.
  ///
  bool use_npu() const { return use_npu_; }
J
jianghaicheng 已提交
351 352 353 354 355
  /// \brief A boolean state telling whether the IPU is turned on.
  ///
  /// \return bool Whether the IPU is turned on.
  ///
  bool use_ipu() const { return use_ipu_; }
W
Wilber 已提交
356
  ///
357 358 359 360 361 362 363 364 365 366 367 368 369
  /// \brief A boolean state telling whether the ONNXRuntime is turned on.
  ///
  /// \return bool Whether the ONNXRuntime is turned on.
  ///
  bool use_onnxruntime() const { return use_onnxruntime_; }
  ///
  /// \brief A boolean state telling whether the ONNXRuntime Optimization is
  /// turned on.
  ///
  /// \return bool Whether the ONNXRuntime Optimization is turned on.
  ///
  bool ort_optimization_enabled() const { return enable_ort_optimization_; }
  ///
370 371 372 373 374 375
  /// \brief Get the GPU device id.
  ///
  /// \return int The GPU device id.
  ///
  int gpu_device_id() const { return gpu_device_id_; }
  ///
376
  /// \brief Get the XPU device id.
377
  ///
378
  /// \return int The XPU device id.
379
  ///
380
  int xpu_device_id() const { return xpu_device_id_; }
381
  ///
W
Wilber 已提交
382 383 384 385 386
  /// \brief Get the NPU device id.
  ///
  /// \return int The NPU device id.
  ///
  int npu_device_id() const { return npu_device_id_; }
J
jianghaicheng 已提交
387 388 389 390 391
  /// \brief Get the number of IPU devices.
  ///
  /// \return int The number of IPU device.
  ///
  int ipu_device_num() const { return ipu_device_num_; }
W
Wilber 已提交
392
  ///
393 394 395 396
  /// \brief Get the initial size in MB of the GPU memory pool.
  ///
  /// \return int The initial size in MB of the GPU memory pool.
  ///
397
  int memory_pool_init_size_mb() const { return memory_pool_init_size_mb_; }
398 399 400 401 402 403
  ///
  /// \brief Get the proportion of the initial memory pool size compared to the
  /// device.
  ///
  /// \return float The proportion of the initial memory pool size.
  ///
404
  float fraction_of_gpu_memory_for_pool() const;
405

406 407 408 409 410
  // CUDNN related.
  ///
  /// \brief Turn on CUDNN.
  ///
  ///
411
  void EnableCUDNN();
412 413 414 415 416
  ///
  /// \brief A boolean state telling whether to use CUDNN.
  ///
  /// \return bool Whether to use CUDNN.
  ///
417 418
  bool cudnn_enabled() const { return use_cudnn_; }

419 420 421 422 423 424
  ///
  /// \brief Control whether to perform IR graph optimization.
  /// If turned off, the AnalysisConfig will act just like a NativeConfig.
  ///
  /// \param x Whether the ir graph optimization is activated.
  ///
425
  void SwitchIrOptim(int x = true) { enable_ir_optim_ = x; }
426 427 428 429 430 431
  ///
  /// \brief A boolean state telling whether the ir graph optimization is
  /// activated.
  ///
  /// \return bool Whether to use ir graph optimization.
  ///
432
  bool ir_optim() const { return enable_ir_optim_; }
433

434 435 436 437 438 439 440
  ///
  /// \brief INTERNAL Determine whether to use the feed and fetch operators.
  /// Just for internal development, not stable yet.
  /// When ZeroCopyTensor is used, this should be turned off.
  ///
  /// \param x Whether to use the feed and fetch operators.
  ///
441
  void SwitchUseFeedFetchOps(int x = true) { use_feed_fetch_ops_ = x; }
442 443 444 445 446 447
  ///
  /// \brief A boolean state telling whether to use the feed and fetch
  /// operators.
  ///
  /// \return bool Whether to use the feed and fetch operators.
  ///
448
  bool use_feed_fetch_ops_enabled() const { return use_feed_fetch_ops_; }
449

450 451 452 453 454 455 456 457 458 459 460
  ///
  /// \brief Control whether to specify the inputs' names.
  /// The ZeroCopyTensor type has a name member, assign it with the
  /// corresponding
  /// variable name. This is used only when the input ZeroCopyTensors passed to
  /// the
  /// AnalysisPredictor.ZeroCopyRun() cannot follow the order in the training
  /// phase.
  ///
  /// \param x Whether to specify the inputs' names.
  ///
461
  void SwitchSpecifyInputNames(bool x = true) { specify_input_name_ = x; }
462 463 464 465 466 467 468
  ///
  /// \brief A boolean state telling whether the specified input
  /// ZeroCopyTensor names should be used to reorder the inputs in
  /// AnalysisPredictor.ZeroCopyRun().
  ///
  /// \return bool Whether to specify the inputs' names.
  ///
469
  bool specify_input_name() const { return specify_input_name_; }
470

471 472 473 474 475 476 477 478 479 480
  ///
  /// \brief Turn on the TensorRT engine.
  /// The TensorRT engine will accelerate some subgraphs in the original Fluid
  /// computation graph. In some models such as resnet50, GoogleNet and so on,
  /// it gains significant performance acceleration.
  ///
  /// \param workspace_size The memory size(in byte) used for TensorRT
  /// workspace.
  /// \param max_batch_size The maximum batch size of this prediction task,
  /// better set as small as possible for less performance loss.
481
  /// \param min_subgraph_size The minimum TensorRT subgraph size needed, if a
482 483 484 485 486 487 488 489
  /// subgraph is smaller than this, it will not be transferred to TensorRT
  /// engine.
  /// \param precision The precision used in TensorRT.
  /// \param use_static Serialize optimization information to disk for reusing.
  /// \param use_calib_mode Use TRT int8 calibration(post training
  /// quantization).
  ///
  ///
490 491 492 493 494
  void EnableTensorRtEngine(int workspace_size = 1 << 20,
                            int max_batch_size = 1, int min_subgraph_size = 3,
                            Precision precision = Precision::kFloat32,
                            bool use_static = false,
                            bool use_calib_mode = true);
495 496 497 498 499
  ///
  /// \brief A boolean state telling whether the TensorRT engine is used.
  ///
  /// \return bool Whether the TensorRT engine is used.
  ///
500
  bool tensorrt_engine_enabled() const { return use_tensorrt_; }
501
  ///
502 503 504 505 506 507
  /// \brief  Get the TensorRT engine precision.
  ///
  /// \return Precision Get the TensorRT engine precision.
  ///
  Precision tensorrt_precision_mode() const { return tensorrt_precision_mode_; }
  ///
508 509 510 511 512 513 514
  /// \brief Set min, max, opt shape for TensorRT Dynamic shape mode.
  /// \param min_input_shape The min input shape of the subgraph input.
  /// \param max_input_shape The max input shape of the subgraph input.
  /// \param opt_input_shape The opt input shape of the subgraph input.
  /// \param disable_trt_plugin_fp16 Setting this parameter to true means that
  /// TRT plugin will not run fp16.
  ///
515 516 517 518 519
  void SetTRTDynamicShapeInfo(
      std::map<std::string, std::vector<int>> min_input_shape,
      std::map<std::string, std::vector<int>> max_input_shape,
      std::map<std::string, std::vector<int>> optim_input_shape,
      bool disable_trt_plugin_fp16 = false);
520 521 522 523 524 525
  ///
  /// \brief A boolean state telling whether the trt dynamic_shape is used.
  ///
  /// \return bool Whether the trt dynamic_shape is used.
  ///
  bool tensorrt_dynamic_shape_enabled() const {
W
Wilber 已提交
526
    return !min_input_shape_.empty();
527
  }
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
  ///
  /// \brief Enable tuned tensorrt dynamic shape.
  ///
  /// \param shape_range_info_path the path to shape_info file got in
  /// CollectShapeInfo
  /// mode.
  /// \param allow_build_at_runtime allow build trt engine at runtime.
  ///
  void EnableTunedTensorRtDynamicShape(const std::string& shape_range_info_path,
                                       bool allow_build_at_runtime = true);

  ///
  /// \brief A boolean state telling whether to use tuned tensorrt dynamic
  /// shape.
  ///
  bool tuned_tensorrt_dynamic_shape();

  ///
  /// \brief A boolean state telling whether to allow building trt engine at
  /// runtime.
  ///
  bool trt_allow_build_at_runtime();

  ///
  /// \brief Collect shape info of all tensors in compute graph.
  ///
  /// \param shape_range_info_path the path to save shape info.
  ///
  void CollectShapeRangeInfo(const std::string& shape_range_info_path);

  ///
  /// \brief the shape info path in CollectShapeInfo mode.
  ///
  /// \return the shape info path.
  ///
  const std::string& shape_range_info_path();

  ///
  /// \brief A boolean state telling whether to collect shape info.
  ///
  /// \return bool Whether to collect shape info.
  ///
  bool shape_range_info_collected();

572 573 574 575 576 577
  ///
  /// \brief Prevent ops running in Paddle-TRT
  /// NOTE: just experimental, not an official stable API, easy to be broken.
  ///
  void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);

578 579
  ///
  /// \brief Replace some TensorRT plugins to TensorRT OSS(
580 581 582
  /// https://github.com/NVIDIA/TensorRT), with which some models' inference
  /// may achieve higher performance. Libnvinfer_plugin.so greater than
  /// V7.2.1 is needed.
583 584
  ///
  void EnableTensorRtOSS();
585

586 587 588 589 590 591 592
  ///
  /// \brief A boolean state telling whether to use the TensorRT OSS.
  ///
  /// \return bool Whether to use the TensorRT OSS.
  ///
  bool tensorrt_oss_enabled() { return trt_use_oss_; }

593 594 595 596 597 598 599 600 601 602 603 604 605 606
  ///
  /// \brief Enable TensorRT DLA
  /// \param dla_core ID of DLACore, which should be 0, 1,
  ///        ..., IBuilder.getNbDLACores() - 1
  ///
  void EnableTensorRtDLA(int dla_core = 0);

  ///
  /// \brief A boolean state telling whether to use the TensorRT DLA.
  ///
  /// \return bool Whether to use the TensorRT DLA.
  ///
  bool tensorrt_dla_enabled() { return trt_use_dla_; }

607 608 609
  void EnableTensorRtInspector();
  bool tensorrt_inspector_enabled() { return trt_use_inspector_; }

D
denglin-github 已提交
610 611 612
  void EnableDlnne(int min_subgraph_size = 3);
  bool dlnne_enabled() const { return use_dlnne_; }

613 614 615 616 617 618 619
  ///
  /// \brief Turn on the usage of Lite sub-graph engine.
  ///
  /// \param precision_mode Precision used in Lite sub-graph engine.
  /// \param passes_filter Set the passes used in Lite sub-graph engine.
  /// \param ops_filter Operators not supported by Lite.
  ///
石晓伟 已提交
620 621
  void EnableLiteEngine(
      AnalysisConfig::Precision precision_mode = Precision::kFloat32,
622
      bool zero_copy = false,
石晓伟 已提交
623 624 625
      const std::vector<std::string>& passes_filter = {},
      const std::vector<std::string>& ops_filter = {});

626 627 628 629 630 631
  ///
  /// \brief A boolean state indicating whether the Lite sub-graph engine is
  /// used.
  ///
  /// \return bool whether the Lite sub-graph engine is used.
  ///
石晓伟 已提交
632 633
  bool lite_engine_enabled() const { return use_lite_; }

634 635 636 637 638 639 640
  ///
  /// \brief Control whether to debug IR graph analysis phase.
  /// This will generate DOT files for visualizing the computation graph after
  /// each analysis pass applied.
  ///
  /// \param x whether to debug IR graph analysis phase.
  ///
Y
Yan Chunwei 已提交
641
  void SwitchIrDebug(int x = true);
642

643 644 645 646
  ///
  /// \brief Turn on MKLDNN.
  ///
  ///
L
luotao1 已提交
647
  void EnableMKLDNN();
648 649 650
  ///
  /// \brief Set the cache capacity of different input shapes for MKLDNN.
  /// Default value 0 means not caching any shape.
651 652
  /// Please see MKL-DNN Data Caching Design Document:
  /// https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/mkldnn/caching/caching.md
653 654 655
  ///
  /// \param capacity The cache capacity.
  ///
656
  void SetMkldnnCacheCapacity(int capacity);
657 658 659 660 661
  ///
  /// \brief A boolean state telling whether to use the MKLDNN.
  ///
  /// \return bool Whether to use the MKLDNN.
  ///
662 663
  bool mkldnn_enabled() const { return use_mkldnn_; }

664 665 666 667 668 669
  ///
  /// \brief Set the number of cpu math library threads.
  ///
  /// \param cpu_math_library_num_threads The number of cpu math library
  /// threads.
  ///
670
  void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads);
671 672 673 674 675 676
  ///
  /// \brief An int state telling how many threads are used in the CPU math
  /// library.
  ///
  /// \return int The number of threads used in the CPU math library.
  ///
677 678 679 680
  int cpu_math_library_num_threads() const {
    return cpu_math_library_num_threads_;
  }

681 682 683 684 685
  ///
  /// \brief Transform the AnalysisConfig to NativeConfig.
  ///
  /// \return NativeConfig The NativeConfig transformed.
  ///
Y
Yan Chunwei 已提交
686
  NativeConfig ToNativeConfig() const;
687 688 689 690 691
  ///
  /// \brief Specify the operator type list to use MKLDNN acceleration.
  ///
  /// \param op_list The operator type list.
  ///
692 693 694
  void SetMKLDNNOp(std::unordered_set<std::string> op_list) {
    mkldnn_enabled_op_types_ = op_list;
  }
695

696 697 698 699
  ///
  /// \brief Turn on MKLDNN quantization.
  ///
  ///
700 701
  void EnableMkldnnQuantizer();

702 703 704 705 706 707 708 709 710 711 712 713 714
  ///
  /// \brief Turn on MKLDNN bfloat16.
  ///
  ///
  void EnableMkldnnBfloat16();

  ///
  /// \brief A boolean state telling whether to use the MKLDNN Bfloat16.
  ///
  /// \return bool Whether to use the MKLDNN Bfloat16.
  ///
  bool mkldnn_bfloat16_enabled() const { return use_mkldnn_bfloat16_; }

715 716 717 718 719 720 721 722
  ///
  /// \brief Specify the operator type list to use Bfloat16 acceleration.
  ///
  /// \param op_list The operator type list. It is taken by value and moved
  ///        into the config, so callers passing an rvalue pay no extra copy.
  ///
  void SetBfloat16Op(std::unordered_set<std::string> op_list) {
    bfloat16_enabled_op_types_ = std::move(op_list);
  }

723 724 725 726 727 728 729 730
  ///
  /// \brief A boolean state telling whether the thread local CUDA stream is
  /// enabled.
  ///
  /// \return bool Whether the thread local CUDA stream is enabled.
  ///
  bool thread_local_stream_enabled() const { return thread_local_stream_; }

731 732 733 734 735
  ///
  /// \brief A boolean state telling whether the MKLDNN quantization is enabled.
  ///
  /// \return bool Whether the MKLDNN quantization is enabled.
  ///
736 737
  bool mkldnn_quantizer_enabled() const { return use_mkldnn_quantizer_; }

738 739 740 741 742
  ///
  /// \brief Get MKLDNN quantizer config.
  ///
  /// \return MkldnnQuantizerConfig* MKLDNN quantizer config.
  ///
743
  MkldnnQuantizerConfig* mkldnn_quantizer_config() const;
744

745 746 747 748 749 750 751 752 753
  ///
  /// \brief Specify the memory buffer of program and parameter.
  /// Used when model and params are loaded directly from memory.
  ///
  /// \param prog_buffer The memory buffer of program.
  /// \param prog_buffer_size The size of the model data.
  /// \param params_buffer The memory buffer of the combined parameters file.
  /// \param params_buffer_size The size of the combined parameters data.
  ///
T
Tao Luo 已提交
754
  void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
755
                      const char* params_buffer, size_t params_buffer_size);
756 757 758 759 760 761
  ///
  /// \brief A boolean state telling whether the model is set from the CPU
  /// memory.
  ///
  /// \return bool Whether model and params are loaded directly from memory.
  ///
T
Tao Luo 已提交
762
  bool model_from_memory() const { return model_from_memory_; }
T
Tao Luo 已提交
763

764 765 766 767
  ///
  /// \brief Turn on memory optimize
  /// NOTE still in development.
  ///
768 769 770
  /// \param x Whether to enable memory optimize.
  ///
  void EnableMemoryOptim(bool x = true);
771 772 773 774 775 776
  ///
  /// \brief A boolean state telling whether the memory optimization is
  /// activated.
  ///
  /// \return bool Whether the memory optimization is activated.
  ///
Y
Yan Chunwei 已提交
777
  bool enable_memory_optim() const;
778

779 780 781 782
  ///
  /// \brief Turn on profiling report.
  /// If not turned on, no profiling report will be generated.
  ///
783
  void EnableProfile();
784 785 786 787 788
  ///
  /// \brief A boolean state telling whether the profiler is activated.
  ///
  /// \return bool Whether the profiler is activated.
  ///
789 790
  bool profile_enabled() const { return with_profile_; }

791 792 793
  ///
  /// \brief Mute all logs in Paddle inference.
  ///
794
  void DisableGlogInfo();
795 796 797 798 799
  ///
  /// \brief A boolean state telling whether logs in Paddle inference are muted.
  ///
  /// \return bool Whether logs in Paddle inference are muted.
  ///
800 801
  bool glog_info_disabled() const { return !with_glog_info_; }

802 803 804 805 806
  ///
  /// \brief Set the AnalysisConfig to be invalid.
  /// This is to ensure that an AnalysisConfig can only be used in one
  /// AnalysisPredictor.
  ///
807
  void SetInValid() const { is_valid_ = false; }
808 809 810 811 812
  ///
  /// \brief A boolean state telling whether the AnalysisConfig is valid.
  ///
  /// \return bool Whether the AnalysisConfig is valid.
  ///
813
  bool is_valid() const { return is_valid_; }
Y
Yan Chunwei 已提交
814

815 816
  friend class ::paddle::AnalysisPredictor;

817 818 819 820 821
  ///
  /// \brief Get a pass builder for customize the passes in IR analysis phase.
  /// NOTE: Just for developer, not an official API, easy to be broken.
  ///
  ///
822
  PassStrategy* pass_builder() const;
823 824 825 826 827 828 829

  ///
  /// \brief Enable the GPU multi-computing stream feature.
  /// NOTE: The current behavior of this interface is to bind the computation
  /// stream to the thread, and this behavior may be changed in the future.
  ///
  void EnableGpuMultiStream();
830
  void PartiallyRelease();
831

832 833 834 835 836
  ///
  /// \brief Print the summary of config.
  ///
  std::string Summary();

837 838
  LiteNNAdapterConfig& NNAdapter() { return nnadapter_config_; }

839 840 841 842 843 844
  void SetDistConfig(const DistConfig& dist_config) {
    dist_config_ = dist_config;
  }

  const DistConfig& dist_config() const { return dist_config_; }

845 846 847 848 849 850
 protected:
  // Update the config.
  void Update();

  std::string SerializeInfoCache();

851
 protected:
852 853
  // Model paths.
  std::string model_dir_;
854 855
  mutable std::string prog_file_;
  mutable std::string params_file_;
856

S
Sylwester Fraczek 已提交
857
  // GPU related.
858
  bool use_gpu_{false};
859
  int gpu_device_id_{0};
860
  uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
W
Wilber 已提交
861
  bool thread_local_stream_{false};
862

863 864
  bool use_cudnn_{false};

W
Wilber 已提交
865 866 867 868
  // NPU related
  bool use_npu_{false};
  int npu_device_id_{0};

869 870 871 872
  // ONNXRuntime related
  bool use_onnxruntime_{false};
  bool enable_ort_optimization_{false};

873 874 875
  // Padding related
  bool use_fc_padding_{true};

S
Sylwester Fraczek 已提交
876
  // TensorRT related.
877
  bool use_tensorrt_{false};
878 879
  // Workspace size, 1 << 30 by default (presumably bytes, i.e. 1 GiB —
  // confirm). For guidance on choosing the workspace size, see:
  // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting
880
  int tensorrt_workspace_size_{1 << 30};
881 882 883 884
  // While TensorRT allows an engine optimized for a given max batch size
  // to run at any smaller size, the performance for those smaller
  // sizes may not be as well-optimized. Therefore, the max batch size should
  // ideally equal the runtime batch size.
885
  int tensorrt_max_batchsize_{1};
886 887 888 889 890
  // Ops in the model that can be converted into TRT layers are transformed
  // and aggregated into subgraphs for TRT execution. This variable sets the
  // minimum number of nodes in such a subgraph (3 by default).
  int tensorrt_min_subgraph_size_{3};
891 892 893
  Precision tensorrt_precision_mode_{Precision::kFloat32};
  bool trt_use_static_engine_{false};
  bool trt_use_calib_mode_{true};
894
  bool trt_use_oss_{false};
895
  bool trt_with_interleaved_{false};
896 897
  // DLA switch and DLA core index.
  bool trt_use_dla_{false};
  int trt_dla_core_{0};
898 899 900
  // Dynamic-shape ranges (min / max / optimal), keyed by a string —
  // presumably the input tensor name; confirm.
  std::map<std::string, std::vector<int>> min_input_shape_{};
  std::map<std::string, std::vector<int>> max_input_shape_{};
  std::map<std::string, std::vector<int>> optim_input_shape_{};
901
  // NOTE(review): presumably op types excluded from TRT subgraphs — confirm.
  std::vector<std::string> trt_disabled_ops_{};
902
  bool disable_trt_plugin_fp16_{false};
903 904 905
  bool trt_allow_build_at_runtime_{false};
  // Tune to get dynamic_shape info.
  bool trt_tuned_dynamic_shape_{false};
906
  bool trt_use_inspector_{false};
907 908 909 910 911 912

  // In CollectShapeInfo mode, the shape information of all intermediate
  // tensors in the compute graph is collected; min_shape, max_shape and
  // opt_shape are calculated from it and saved to shape_range_info_path_.
  bool collect_shape_range_info_{false};
  std::string shape_range_info_path_;
913

D
denglin-github 已提交
914 915 916 917
  // dlnne related.
  bool use_dlnne_{false};
  // Presumably the minimum subgraph size for dlnne, analogous to
  // tensorrt_min_subgraph_size_ (same default of 3) — confirm.
  int dlnne_min_subgraph_size_{3};

Y
Yan Chunwei 已提交
918 919 920
  // Memory reuse related.
  bool enable_memory_optim_{false};

921 922 923
  // MKLDNN switch and the op types (as strings) it is enabled for. The other
  // mkldnn settings are declared further below.
  bool use_mkldnn_{false};
  std::unordered_set<std::string> mkldnn_enabled_op_types_;

T
Tao Luo 已提交
924
  // NOTE(review): presumably true when the model is supplied from memory
  // buffers instead of files — confirm.
  bool model_from_memory_{false};
925

926 927 928 929 930 931 932 933
  bool enable_ir_optim_{true};
  bool use_feed_fetch_ops_{true};
  bool ir_debug_{false};

  bool specify_input_name_{false};

  // Number of CPU math library threads (1 by default).
  int cpu_math_library_num_threads_{1};

934 935
  bool with_profile_{false};

936 937
  // glog info output is enabled by default.
  bool with_glog_info_{true};

938 939 940 941
  // A runtime cache; it should not be transferred to others.
  std::string serialized_info_cache_;

  // The config owns its pass strategy. `mutable` allows const members (such
  // as pass_builder()) to modify it.
  mutable std::unique_ptr<PassStrategy> pass_builder_;
942

石晓伟 已提交
943 944 945 946
  // Lite engine related.
  bool use_lite_{false};
  std::vector<std::string> lite_passes_filter_;
  std::vector<std::string> lite_ops_filter_;
  // Given a default initializer like the other members, so a config that
  // never enables Lite does not hold an indeterminate value (reading one, e.g.
  // when the config is copied, is undefined behavior). kFloat32 mirrors the
  // default of tensorrt_precision_mode_.
  // NOTE(review): presumably overwritten when the Lite engine is enabled.
  Precision lite_precision_mode_{Precision::kFloat32};
947
  // Brace-initialized (like the other switches) so a default-constructed
  // config never holds an indeterminate bool, which would be undefined
  // behavior to read. NOTE(review): presumably set when Lite is enabled.
  bool lite_zero_copy_{false};
石晓伟 已提交
948

W
Wilber 已提交
949
  // XPU related.
950
  bool use_xpu_{false};
W
Wilber 已提交
951
  int xpu_device_id_{0};
952
  // L3 workspace size, 0 by default (unit not visible here — confirm).
  int xpu_l3_workspace_size_{0};
W
Wilber 已提交
953 954 955 956 957
  // XPU tuning options. The bools previously had no initializer, so a
  // default-constructed config held indeterminate values (undefined behavior
  // if read, e.g. when copying the config). They now start as false, matching
  // the brace-init convention used by the rest of this struct.
  // NOTE(review): presumably all of these are assigned when XPU is enabled,
  // which makes the defaults placeholders — confirm.
  bool xpu_locked_{false};
  bool xpu_autotune_{false};
  std::string xpu_autotune_file_;
  std::string xpu_precision_;
  bool xpu_adaptive_seqlen_{false};
958

959 960 961
  // NNAdapter related.
  LiteNNAdapterConfig nnadapter_config_;

962
  // mkldnn related.
W
Wilber 已提交
963
  int mkldnn_cache_capacity_{10};
964 965
  bool use_mkldnn_quantizer_{false};
  // MkldnnQuantizerConfig is only forward-declared in this header unless
  // PADDLE_WITH_MKLDNN is defined; std::shared_ptr tolerates the incomplete
  // type here.
  std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
966
  bool use_mkldnn_bfloat16_{false};
967
  std::unordered_set<std::string> bfloat16_enabled_op_types_;
968

J
jianghaicheng 已提交
969 970 971
  // IPU related.
  bool use_ipu_{false};
  int ipu_device_num_{1};
972
  int ipu_micro_batch_size_{1};
J
jianghaicheng 已提交
973 974
  bool ipu_enable_pipelining_{false};
  int ipu_batches_per_step_{1};
975 976 977 978 979

  bool ipu_enable_fp16_{false};
  int ipu_replica_num_{1};
  // Brace-initialized from a double literal; 1.0 is exactly representable as
  // a float, so this is not a narrowing conversion.
  float ipu_available_memory_proportion_{1.0};
  bool ipu_enable_half_partial_{false};
J
jianghaicheng 已提交
980

981 982 983 984
  // If the config has already been used by a predictor, it becomes invalid:
  // a config can only be used with one predictor.
  // Variables held by the config can take up a lot of memory in some cases,
  // so that memory is released when the predictor is set up.
985 986
  mutable bool is_valid_{true};
  std::string opt_cache_dir_;
987
  friend class paddle_infer::experimental::InternalUtils;
988 989 990

  // Fleet executor ("fleet exe") related: the distributed configuration.
  DistConfig dist_config_{};
991 992 993
};

}  // namespace paddle