// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /// /// \file paddle_analysis_config.h /// /// \brief Paddle Analysis Config API信息 /// /// \author paddle-infer@baidu.com /// \date 2020-03-20 /// \since 1.7 /// #pragma once #include #include #include #include #include #include #include #include "paddle_infer_declare.h" // NOLINT /*! \file */ // Here we include some header files with relative paths, for that in deploy, // the abstract path of this header file will be changed. 
#include "paddle_api.h"           // NOLINT
#include "paddle_pass_builder.h"  // NOLINT
#ifdef PADDLE_WITH_DNNL
#include "paddle_mkldnn_quantizer_config.h"  // NOLINT
#endif

namespace paddle {

class AnalysisPredictor;
struct MkldnnQuantizerConfig;

///
/// \brief Options for running a model through Paddle-Lite's NNAdapter.
///
/// The Set*/Enable/Disable members are only declared here (their definitions
/// live in another translation unit). They return LiteNNAdapterConfig&,
/// presumably *this so calls can be chained — confirm against the .cc file.
///
struct LiteNNAdapterConfig {
  // Whether NNAdapter is used (default: false).
  bool use_nnadapter{false};
  std::string nnadapter_model_cache_dir;
  // NOTE(review): presumably keyed by the model_cache_token passed to
  // SetModelCacheBuffers(); confirm.
  std::map<std::string, std::vector<char>> nnadapter_model_cache_buffers;
  std::vector<std::string> nnadapter_device_names;
  std::string nnadapter_context_properties;
  std::string nnadapter_subgraph_partition_config_path;
  std::string nnadapter_subgraph_partition_config_buffer;

  LiteNNAdapterConfig& SetDeviceNames(const std::vector<std::string>& names);

  LiteNNAdapterConfig& SetContextProperties(const std::string& properties);

  LiteNNAdapterConfig& SetModelCacheDir(const std::string& dir);

  LiteNNAdapterConfig& SetModelCacheBuffers(
      const std::string& model_cache_token,
      const std::vector<char>& model_cache_buffer);

  LiteNNAdapterConfig& SetSubgraphPartitionConfigPath(const std::string& path);

  LiteNNAdapterConfig& SetSubgraphPartitionConfigBuffer(
      const std::string& buffer);

  LiteNNAdapterConfig& Enable();
  LiteNNAdapterConfig& Disable();
};

///
/// \brief Configuration for running inference on Kunlun XPU devices.
///
/// A plain aggregate of public fields; each field's default is given by its
/// in-class initializer. Fields marked "Paddle-Lite only" / "PaddleInference
/// only" are, per their notes, honored by only one of the two backends.
///
struct PD_INFER_DECL XpuConfig {
  // Select which xpu device to run model.
  int device_id{0};

  // Available l3 size (Byte)
  // For kunlun1, max l3_size is 16773120 Byte
  // For kunlun2, max l3_size is 67104768 Byte
  size_t l3_size{0};
  // If l3_ptr is not nullptr, it is used as l3 buffer.
  // If l3_ptr is nullptr, new l3 buffer will be created.
  void* l3_ptr{nullptr};
  // Available l3 size for autotune.
  // If l3_autotune_size is 0, autotune is closed.
  // Note: The remaining l3 size (l3_size - l3_autotune_size) is for
  // kernels (both paddle/xdnn kernels)
  size_t l3_autotune_size{0};

  // Reserved xpu global memory size for xpu_context;
  // If not set(-1), default memory size for xpu_context is 128MB in XPU2 or
  // 64MB in XPU1. If set 1*1024*1024, memory size for xpu_context will be 1MB;
  int context_gm_size{-1};
  // xpu_context(from baidu::xpu::api::create_context) for execution.
  // If context is nullptr, new context will be created by default.
  void* context{nullptr};
  // Stream for execution.
  // If stream is nullptr, default stream will be used.
  void* stream{nullptr};

  // Conv autotune level. Default 0 means no autotune.
  // Note: Paddle-Lite only.
  int conv_autotune_level{0};
  // Base conv autotune info is read from conv_autotune_file.
  // Note: Paddle-Lite only.
  std::string conv_autotune_file;
  // Whether write new conv autotune info to conv_autotune_file.
  // Note: Paddle-Lite only.
  bool conv_autotune_file_writeback{false};

  // Fc autotune level. Optional values are 0-9. Default 0 means no
  // autotune.
  // Note: Paddle-Lite only.
  int fc_autotune_level{0};
  // Base fc autotune info is read from fc_autotune_file.
  // Note: Paddle-Lite only.
  std::string fc_autotune_file;
  // Whether write new fc autotune info to fc_autotune_file.
  // Note: Paddle-Lite only.
  bool fc_autotune_file_writeback{false};

  // Gemm compute precision. Optional values are 0(int8),1(int16),2(int31).
  // Note: "gemm_compute_precision" has no effect on quantized ops of a
  // quantized model.
  // Note: Paddle-Lite only.
  int gemm_compute_precision{1};
  // Which method to optimize softmax in transformer structure. Optional values
  // are 0,1,2.
  // Note: Paddle-Lite only.
  int transformer_softmax_optimize_level{0};
  // Whether enable adaptive_seqlen optimize on transformer encoder.
  // Note: Paddle-Lite only.
  bool transformer_encoder_adaptive_seqlen{true};

  // Gelu out max threshold is limited to quant_post_static_gelu_out_threshold
  // if use static post-quantization.
  // Note: Paddle-Lite only.
  float quant_post_static_gelu_out_threshold{10.f};
  // Activation method if use dynamic post-quantization.
  // For kunlun1, optional values are 0(per_tensor),1(per_batch),2(per_head).
  // For kunlun2, optional values are 0(per_tensor) or non-zero(every_16).
  // Note: Paddle-Lite only.
  int quant_post_dynamic_activation_method{0};
  // Preprocess weight to quant_post_dynamic_weight_precision if use dynamic
  // post-quantization. Optional values are 0,1,2.
  // * If 0, preprocess weight to int8.
  // * If 1, preprocess weight to int16.
  // * If 2, preprocess weight to float.
  // Note: PaddleInference only.
  int quant_post_dynamic_weight_precision{1};
  // NOTE(review): presumably the op types that dynamic post-quantization is
  // applied to; confirm against the consumer of this field.
  std::vector<std::string> quant_post_dynamic_op_types;
};

///
/// \brief Settings for distributed (multi-rank) model inference.
///
/// A plain value holder: every setter only stores its arguments and every
/// getter returns a copy of the stored value.
///
struct DistConfig {
  // Whether DistModel inference is used (default: false).
  bool use_dist_model() const { return use_dist_model_; }
  void EnableDistModel(bool use_dist_model) {
    use_dist_model_ = use_dist_model;
  }

  // Endpoints of all trainers and of the current trainer (default: empty).
  std::vector<std::string> trainer_endpoints() const {
    return trainer_endpoints_;
  }
  std::string current_endpoint() const { return current_endpoint_; }
  void SetEndpoints(const std::vector<std::string>& trainer_endpoints,
                    const std::string& current_endpoint) {
    trainer_endpoints_ = trainer_endpoints;
    current_endpoint_ = current_endpoint;
  }

  // Total number of ranks (default: 1) and the rank (default: 0).
  int64_t nranks() const { return nranks_; }
  int64_t rank() const { return rank_; }
  void SetRanks(int64_t nranks, int64_t rank) {
    nranks_ = nranks;
    rank_ = rank;
  }

  // Communicator init config string (default: empty).
  std::string comm_init_config() const { return comm_init_config_; }
  void SetCommInitConfig(const std::string& comm_init_config) {
    comm_init_config_ = comm_init_config;
  }

  // Carrier id (default: "inference").
  void SetCarrierId(const std::string& carrier_id) { carrier_id_ = carrier_id; }
  std::string carrier_id() const { return carrier_id_; }

 protected:
  // DistModel Inference related
  // whether use DistModel or not
  bool use_dist_model_{false};
  // all trainers' endpoints
  std::vector<std::string> trainer_endpoints_{};
  // current trainer's endpoint
  std::string current_endpoint_{};
  // total ranks (number of trainers)
  int64_t nranks_{1};
  // rank
  int64_t rank_{0};
  // NOTE(review): previously commented as "converter config path", which
  // does not match the field/setter name; confirm the intended meaning.
  std::string comm_init_config_{};
  std::string carrier_id_{"inference"};
};

///
/// \brief configuration manager for AnalysisPredictor.
/// \since 1.7.0
///
/// AnalysisConfig manages configurations of AnalysisPredictor.
/// During inference procedure, there are many parameters(model/params path,
/// place of inference, etc.)
/// to be specified, and various optimizations(subgraph fusion, memory
/// optimization, TensorRT engine, etc.)
/// to be done.
Users can manage these settings by creating and modifying an /// AnalysisConfig, /// and loading it into AnalysisPredictor. /// struct PD_INFER_DECL AnalysisConfig { AnalysisConfig(); /// /// \brief Construct a new AnalysisConfig from another /// AnalysisConfig. /// /// \param[in] other another AnalysisConfig /// AnalysisConfig(const AnalysisConfig& other); /// /// \brief Construct a new AnalysisConfig from a no-combined model. /// /// \param[in] model_dir model directory of the no-combined model. /// explicit AnalysisConfig(const std::string& model_dir); /// /// \brief Construct a new AnalysisConfig from a combined model. /// /// \param[in] prog_file model file path of the combined model. /// \param[in] params_file params file path of the combined model. /// explicit AnalysisConfig(const std::string& prog_file, const std::string& params_file); /// /// \brief Precision of inference. /// enum class Precision { kFloat32 = 0, ///< fp32 kInt8, ///< int8 kHalf, ///< fp16 kBf16, ///< bf16 }; /// /// \brief Set the no-combined model dir path. /// /// \param model_dir model dir path. /// void SetModel(const std::string& model_dir) { model_dir_ = model_dir; } /// /// \brief Set the combined model with two specific pathes for program and /// parameters. /// /// \param prog_file_path model file path of the combined model. /// \param params_file_path params file path of the combined model. /// void SetModel(const std::string& prog_file_path, const std::string& params_file_path); /// /// \brief Set the model file path of a combined model. /// /// \param x model file path. /// void SetProgFile(const std::string& x) { prog_file_ = x; } /// /// \brief Set the params file path of a combined model. /// /// \param x params file path. /// void SetParamsFile(const std::string& x) { params_file_ = x; } /// /// \brief Save optimized model. /// /// \param save_optimized_model whether to enable save optimized model. 
/// void EnableSaveOptimModel(bool save_optimized_model) { save_optimized_model_ = save_optimized_model; } /// /// \brief Set the path of optimization cache directory. /// /// \param opt_cache_dir the path of optimization cache directory. /// void SetOptimCacheDir(const std::string& opt_cache_dir) { opt_cache_dir_ = opt_cache_dir; } /// /// \brief Get the model directory path. /// /// \return const std::string& The model directory path. /// const std::string& model_dir() const { return model_dir_; } /// /// \brief Get the program file path. /// /// \return const std::string& The program file path. /// const std::string& prog_file() const { return prog_file_; } /// /// \brief Get the combined parameters file. /// /// \return const std::string& The combined parameters file. /// const std::string& params_file() const { return params_file_; } // Padding related. /// /// \brief Turn off FC Padding. /// /// void DisableFCPadding(); /// /// \brief A boolean state telling whether fc padding is used. /// /// \return bool Whether fc padding is used. /// bool use_fc_padding() const { return use_fc_padding_; } // GPU related. /// /// \brief Turn on GPU. /// /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB. /// \param device_id device_id the GPU card to use (default is 0). /// \param precision the precision used in Paddle-GPU inference. /// void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0, Precision precision_mode = Precision::kFloat32); /// /// \brief Turn off GPU. /// /// void DisableGpu(); /// /// \brief Turn on XPU. /// /// \param l3_workspace_size The size of the video memory allocated by the l3 /// cache, the maximum is 16M. /// \param l3_locked Whether the allocated L3 cache can be locked. If false, /// it means that the L3 cache is not locked, and the allocated L3 /// cache can be shared by multiple models, and multiple models /// sharing the L3 cache will be executed sequentially on the card. 
/// \param conv_autotune Whether to autotune the conv operator in the model. /// If true, when the conv operator of a certain dimension is executed /// for the first time, it will automatically search for a better /// algorithm to improve the performance of subsequent conv operators /// of the same dimension. /// \param conv_autotune_file Specify the path of the autotune file. If /// autotune_file is specified, the algorithm specified in the /// file will be used and autotune will not be performed again. /// \param transformer_encoder_precision Calculation accuracy of multi_encoder /// \param transformer_encoder_adaptive_seqlen Is the input of multi_encoder /// variable length /// \param enable_multi_stream Whether to enable the multi /// stream of xpu. /// void EnableXpu(int l3_size = 0xfffc00, bool l3_locked = false, bool conv_autotune = true, const std::string& conv_autotune_file = "", const std::string& transformer_encoder_precision = "int16", bool transformer_encoder_adaptive_seqlen = false, bool enable_multi_stream = false); /// /// \brief configs of XPU /// /// \param config Configs for xpu. See XpuConfig for more details. /// void SetXpuConfig(const XpuConfig& config); /// /// \brief Get configs of xpu /// /// \return XpuConfig The configs of xpu. /// XpuConfig xpu_config() { return xpu_config_; } /// /// \brief configs of IPU /// enum class ipu_config_code { ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining, ipu_batches_per_step, ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion, ipu_enable_half_partial, ipu_custom_ops_info, ipu_custom_patterns, ipu_enable_model_runtime_executor, }; /// /// \brief Turn on IPU. /// /// \param ipu_device_num the number of IPUs. /// \param ipu_micro_batch_size the batch size in the graph, only work with /// mutable input shapes. /// \param ipu_enable_pipelining enable pipelining. /// \param ipu_batches_per_step the number of batches per run in pipelining. 
/// void EnableIpu(int ipu_device_num = 1, int ipu_micro_batch_size = 1, bool ipu_enable_pipelining = false, int ipu_batches_per_step = 1); /// /// \brief Set IPU config. /// /// \param ipu_enable_fp16 enable fp16. /// \param ipu_replica_num the number of graph replication. /// \param ipu_available_memory_proportion the available memory proportion for /// matmul/conv. /// \param ipu_enable_half_partial enable fp16 partial for matmul, only work /// with fp16. /// \param ipu_enable_model_runtime_executor whether to use model_runtime /// executor. /// void SetIpuConfig(bool ipu_enable_fp16 = false, int ipu_replica_num = 1, float ipu_available_memory_proportion = 1.0, bool ipu_enable_half_partial = false, bool ipu_enable_model_runtime_executor = false); /// /// \brief Set IPU custom ops and patterns. /// /// \param custom_ops_info the mapper of paddle custom ops and popart ops. /// e.g. {{paddle_op_name, popart_op_name, op_domain, op_version}}. /// \param custom_patterns the names of popart patterns. e.g. {{pattern_name, /// enable_pattern}}} /// void SetIpuCustomInfo( const std::vector>& ipu_custom_ops_info = {}, const std::map& ipu_custom_patterns = {}); /// /// \brief Load IPU config from configuration file. /// /// \param config_path configure file path for ipu. /// void LoadIpuConfig(const std::string& config_path); /// /// \brief Set XPU device id. /// /// \param device_id the XPU card to use (default is 0). /// void SetXpuDeviceId(int device_id = 0); /// /// \brief Turn on CustomDevice. /// /// \param device_type device_type the custom device to use. /// /// \param device_id device_id the custom device to use (default is 0). /// void EnableCustomDevice(const std::string& device_type, int device_id = 0, Precision precision_mode = Precision::kFloat32); /// /// \brief Turn on ONNXRuntime. /// void EnableONNXRuntime(); /// /// \brief Turn off ONNXRuntime. /// void DisableONNXRuntime(); /// /// \brief Turn on ONNXRuntime Optimization. 
/// void EnableORTOptimization(); /// /// \brief A boolean state telling whether the GPU is turned on. /// /// \return bool Whether the GPU is turned on. /// bool use_gpu() const { return use_gpu_; } /// /// \brief When running the fp16 model on Nvidia GPU, you can also try running /// your model on cutlass. /// void Exp_EnableUseCutlass(); /// /// /// \brief A boolean state telling whether the XPU is turned on. /// /// \return bool Whether the XPU is turned on. /// bool use_xpu() const { return use_xpu_; } /// \brief A boolean state telling whether the IPU is turned on. /// /// \return bool Whether the IPU is turned on. /// bool use_ipu() const { return use_ipu_; } /// \brief A boolean state telling whether the CustomDevice is turned on. /// /// \return bool Whether the CustomDevice is turned on. /// bool use_custom_device() const { return use_custom_device_; } /// /// \brief A boolean state telling whether the ONNXRuntime is turned on. /// /// \return bool Whether the ONNXRuntime is turned on. /// bool use_onnxruntime() const { return use_onnxruntime_; } /// /// \brief A boolean state telling whether the Lite OpenCL is turned on. /// /// \return bool Whether the Lite OpenCL is turned on. /// bool use_opencl() const { return use_opencl_; } /// /// \brief A boolean state telling whether the ONNXRuntime Optimization is /// turned on. /// /// \return bool Whether the ONNXRuntime Optimization is turned on. /// bool ort_optimization_enabled() const { return enable_ort_optimization_; } /// /// \brief Get the GPU device id. /// /// \return int The GPU device id. /// int gpu_device_id() const { return gpu_device_id_; } /// /// \brief Get the XPU device id. /// /// \return int The XPU device id. /// int xpu_device_id() const { return xpu_config_.device_id; } /// \brief Get the number of IPU device . /// /// \return int The number of IPU device. /// int ipu_device_num() const { return ipu_device_num_; } /// /// \brief Get the custom device id. 
/// /// \return int The custom device id. /// int custom_device_id() const { return custom_device_id_; } /// \brief Get the custom device type. /// /// \return string The custom device type. /// std::string custom_device_type() const { return custom_device_type_; } /// \brief Get whether the custom device mixed preicsion is enabled. /// /// \return bool custom device mixed is enabled. /// bool enable_custom_device_mixed() const { return enable_custom_device_mixed_; } /// /// \brief Get the initial size in MB of the GPU memory pool. /// /// \return int The initial size in MB of the GPU memory pool. /// int memory_pool_init_size_mb() const { return memory_pool_init_size_mb_; } /// /// \brief Get the proportion of the initial memory pool size compared to the /// device. /// /// \return float The proportion of the initial memory pool size. /// float fraction_of_gpu_memory_for_pool() const; // CUDNN related. /// /// \brief Turn on CUDNN. /// /// void EnableCUDNN(); /// /// \brief A boolean state telling whether to use CUDNN. /// /// \return bool Whether to use CUDNN. /// bool cudnn_enabled() const { return use_cudnn_; } /// /// \brief Control whether to perform IR graph optimization. /// If turned off, the AnalysisConfig will act just like a NativeConfig. /// /// \param x Whether the ir graph optimization is actived. /// void SwitchIrOptim(int x = true) { enable_ir_optim_ = x; } /// /// \brief A boolean state telling whether the ir graph optimization is /// actived. /// /// \return bool Whether to use ir graph optimization. /// bool ir_optim() const { return enable_ir_optim_; } /// /// \brief INTERNAL Determine whether to use the feed and fetch operators. /// Just for internal development, not stable yet. /// When ZeroCopyTensor is used, this should be turned off. /// /// \param x Whether to use the feed and fetch operators. 
/// void SwitchUseFeedFetchOps(int x = true) { use_feed_fetch_ops_ = x; } /// /// \brief A boolean state telling whether to use the feed and fetch /// operators. /// /// \return bool Whether to use the feed and fetch operators. /// bool use_feed_fetch_ops_enabled() const { return use_feed_fetch_ops_; } /// /// \brief Turn on the feed and fetch data with low precision. /// /// \param x Whether to enable feed and fetch data with low precision. /// void EnableLowPrecisionIO(bool x = true); /// /// \brief Control whether to specify the inputs' names. /// The ZeroCopyTensor type has a name member, assign it with the /// corresponding /// variable name. This is used only when the input ZeroCopyTensors passed to /// the /// AnalysisPredictor.ZeroCopyRun() cannot follow the order in the training /// phase. /// /// \param x Whether to specify the inputs' names. /// void SwitchSpecifyInputNames(bool x = true) { specify_input_name_ = x; } /// /// \brief A boolean state tell whether the input ZeroCopyTensor names /// specified should /// be used to reorder the inputs in AnalysisPredictor.ZeroCopyRun(). /// /// \return bool Whether to specify the inputs' names. /// bool specify_input_name() const { return specify_input_name_; } /// /// \brief Turn on the TensorRT engine. /// The TensorRT engine will accelerate some subgraphes in the original Fluid /// computation graph. In some models such as resnet50, GoogleNet and so on, /// it gains significant performance acceleration. /// /// \param workspace_size The memory size(in byte) used for TensorRT /// workspace. /// \param max_batch_size The maximum batch size of this prediction task, /// better set as small as possible for less performance loss. /// \param min_subgraph_size The minimum TensorRT subgraph size needed, if a /// subgraph is smaller than this, it will not be transferred to TensorRT /// engine. /// \param precision The precision used in TensorRT. 
/// \param use_static Serialize optimization information to disk for reusing. /// \param use_calib_mode Use TRT int8 calibration(post training /// quantization). /// \param use_cuda_graph Use CudaGraph to reduce the time consumption of /// enqueue. Note that this option can only be enabled when your input is /// constant (including the batch dimension). /// /// void EnableTensorRtEngine(int64_t workspace_size = 1 << 30, int max_batch_size = 1, int min_subgraph_size = 3, Precision precision = Precision::kFloat32, bool use_static = false, bool use_calib_mode = true, bool use_cuda_graph = false); /// /// \brief A boolean state telling whether the TensorRT engine is used. /// /// \return bool Whether the TensorRT engine is used. /// bool tensorrt_engine_enabled() const { return use_tensorrt_; } /// /// \brief Whether to get the intermediate output of TensorRT Engine. /// /// \param output_tensor_names The name of the Tensor that needs to be marked /// void MarkTrtEngineOutputs( const std::vector& output_tensor_names = {}, const bool trt_mark_output_with_id = false); /// /// \brief Turn on the TensorRT memory optimization. /// /// \param engine_memory_sharing Whether to enable TensorRT memory /// optimization. /// \param sharing_identifier This parameter can be set if TensorRT memory /// optimization is enabled, and the value must be greater than 0. If you have /// multiple predictors that want to share memory, you can specify a /// same value for these predictors. NOTE: The predictors specified with the /// same value must be guaranteed to be executed serially, otherwise undefined /// behavior will occur. /// void EnableTensorRTMemoryOptim(bool engine_memory_sharing = true, int sharing_identifier = 0); /// /// \brief A boolean state telling whether the tensorrt engine memory sharing /// is activated. /// /// \return bool Whether the tensorrt engine memory sharing is activated. /// bool trt_engine_memory_sharing() const; /// /// \brief Get the TensorRT engine precision. 
/// /// \return Precision Get the TensorRT engine precision. /// Precision tensorrt_precision_mode() const { return tensorrt_precision_mode_; } /// /// \brief Set min, max, opt shape for TensorRT Dynamic shape mode. /// \param min_input_shape The min input shape of the subgraph input. /// \param max_input_shape The max input shape of the subgraph input. /// \param opt_input_shape The opt input shape of the subgraph input. /// \param disable_trt_plugin_fp16 Setting this parameter to true means that /// TRT plugin will not run fp16. /// void SetTRTDynamicShapeInfo( std::map> min_input_shape, std::map> max_input_shape, std::map> optim_input_shape, bool disable_trt_plugin_fp16 = false); /// /// \brief A boolean state telling whether the trt dynamic_shape is used. /// /// \return bool Whether the trt dynamic_shape is used. /// bool tensorrt_dynamic_shape_enabled() const { return !min_input_shape_.empty(); } /// /// \brief Enable tuned tensorrt dynamic shape. /// /// \param shape_range_info_path the path to shape_info file got in /// CollectShapeInfo /// mode. /// \param allow_build_at_runtime allow build trt engine at runtime. /// void EnableTunedTensorRtDynamicShape( const std::string& shape_range_info_path = "", bool allow_build_at_runtime = true); /// /// \brief A boolean state telling whether to use tuned tensorrt dynamic /// shape. /// bool tuned_tensorrt_dynamic_shape() const; /// /// \brief A boolean state telling whether to allow building trt engine at /// runtime. /// bool trt_allow_build_at_runtime() const; /// /// \brief Set execution stream. If not set a stream will be created /// internally. /// void SetExecStream(void* stream); /// /// \brief Get execution stream. The user needs to explicitly cast into a /// stream type such as cudaStream_t, hipStream_t, etc. 
/// void* GetExecStream() const; /// /// \brief Whether the external stream is used, if True, the predictor clone /// operation must use the external stream, otherwise the framework manages /// the stream internally. /// bool external_stream_enabled() const; /// /// \brief Collect shape info of all tensors in compute graph. /// /// \param shape_range_info_path the path to save shape info. /// void CollectShapeRangeInfo(const std::string& shape_range_info_path); /// /// \brief the shape info path in CollectShapeInfo mode. /// /// \return the shape info path. /// const std::string& shape_range_info_path() const; /// /// \brief A boolean state telling whether to collect shape info. /// /// \return bool Whether to collect shape info. /// bool shape_range_info_collected() const; /// /// \brief Prevent ops running in Paddle-TRT /// NOTE: just experimental, not an official stable API, easy to be broken. /// void Exp_DisableTensorRtOPs(const std::vector& ops); /// /// \brief Replace some TensorRT plugins to TensorRT OSS( /// https://github.com/NVIDIA/TensorRT), with which some models's inference /// may be more high-performance. Libnvinfer_plugin.so greater than /// V7.2.1 is needed. /// void EnableVarseqlen(); /// /// \brief A boolean state telling whether to use the TensorRT OSS. /// /// \return bool Whether to use the TensorRT OSS. /// bool tensorrt_varseqlen_enabled() { return trt_use_varseqlen_; } /// /// \brief Enable TensorRT DLA /// \param dla_core ID of DLACore, which should be 0, 1, /// ..., IBuilder.getNbDLACores() - 1 /// void EnableTensorRtDLA(int dla_core = 0); /// /// \brief A boolean state telling whether to use the TensorRT DLA. /// /// \return bool Whether to use the TensorRT DLA. /// bool tensorrt_dla_enabled() { return trt_use_dla_; } /// /// \brief A boolean state telling whether to show TensorRT inspector /// information. /// /// \return bool Whether to show TensorRT inspector information. 
/// void EnableTensorRtInspector(); bool tensorrt_inspector_enabled() { return trt_use_inspector_; } /// /// \brief A boolean state telling whether to use TensorRT explicit /// quantization. /// /// \return bool Whether to use TensorRT explicit quantization. /// void EnableTensorRtExplicitQuantization(); bool tensorrt_explicit_quantization_enabled() { return trt_use_explicit_quantization_; } void EnableDlnne( int min_subgraph_size = 3, int max_batch_size = 1, bool use_static_batch = false, std::string weight_share_mode = "0", std::unordered_set disable_nodes_by_outputs = {}, std::map> input_dict = {}, bool use_calib_mode = false, Precision precision_mode = Precision::kFloat32); bool dlnne_enabled() const { return use_dlnne_; } /// /// \brief Turn on the usage of Lite sub-graph engine. /// /// \param precision_mode Precion used in Lite sub-graph engine. /// \param passes_filter Set the passes used in Lite sub-graph engine. /// \param ops_filter Operators not supported by Lite. /// void EnableLiteEngine(Precision precision_mode = Precision::kFloat32, bool zero_copy = false, const std::vector& passes_filter = {}, const std::vector& ops_filter = {}); /// /// \brief Turn on the usage of Lite sub-graph engine with opencl. /// void EnableOpenCL(); /// /// \brief A boolean state indicating whether the Lite sub-graph engine is /// used. /// /// \return bool whether the Lite sub-graph engine is used. /// bool lite_engine_enabled() const { return use_lite_; } /// /// \brief Control whether to debug IR graph analysis phase. /// This will generate DOT files for visualizing the computation graph after /// each analysis pass applied. /// /// \param x whether to debug IR graph analysis phase. /// void SwitchIrDebug(int x = true); /// /// \brief Turn on MKLDNN. /// /// void EnableMKLDNN(); /// /// \brief Set the cache capacity of different input shapes for MKLDNN. /// Default value 0 means not caching any shape. 
/// Please see MKL-DNN Data Caching Design Document: /// https://github.com/PaddlePaddle/FluidDoc/blob/develop/doc/fluid/design/mkldnn/caching/caching.md /// /// \param capacity The cache capacity. /// void SetMkldnnCacheCapacity(int capacity); /// /// \brief A boolean state telling whether to use the MKLDNN. /// /// \return bool Whether to use the MKLDNN. /// bool mkldnn_enabled() const { return use_mkldnn_; } /// /// \brief Set the number of cpu math library threads. /// /// \param cpu_math_library_num_threads The number of cpu math library /// threads. /// void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads); /// /// \brief An int state telling how many threads are used in the CPU math /// library. /// /// \return int The number of threads used in the CPU math library. /// int cpu_math_library_num_threads() const { return cpu_math_library_num_threads_; } /// /// \brief Transform the AnalysisConfig to NativeConfig. /// /// \return NativeConfig The NativeConfig transformed. /// NativeConfig ToNativeConfig() const; /// /// \brief Specify the operator type list to use MKLDNN acceleration. /// /// \param op_list The operator type list. /// void SetMKLDNNOp(std::unordered_set op_list) { mkldnn_enabled_op_types_ = op_list; } /// /// \brief Turn on MKLDNN quantization. /// /// void EnableMkldnnQuantizer(); /// /// \brief Turn on MKLDNN int8. /// /// \param op_list The operator type list. /// void EnableMkldnnInt8(const std::unordered_set& op_list = {}); /// /// \brief A boolean state telling whether to use the MKLDNN Int8. /// /// \return bool Whether to use the MKLDNN Int8. /// bool mkldnn_int8_enabled() const { return use_mkldnn_int8_; } /// /// \brief Turn on MKLDNN bfloat16. /// /// void EnableMkldnnBfloat16(); /// /// \brief Turn off MKLDNN fc passes. /// void DisableMkldnnFcPasses(); /// /// \brief A boolean state telling whether to disable the MKLDNN Fc passes. /// /// \return bool Whether to disable the MKLDNN Fc passes. 
/// bool mkldnn_fc_passes_disabled() const { return disable_mkldnn_fc_passes_; } /// /// \brief A boolean state telling whether to use the MKLDNN Bfloat16. /// /// \return bool Whether to use the MKLDNN Bfloat16. /// bool mkldnn_bfloat16_enabled() const { return use_mkldnn_bfloat16_; } /// \brief Specify the operator type list to use Bfloat16 acceleration. /// /// \param op_list The operator type list. /// void SetBfloat16Op(std::unordered_set op_list) { bfloat16_enabled_op_types_ = op_list; } /// /// \brief A boolean state telling whether the thread local CUDA stream is /// enabled. /// /// \return bool Whether the thread local CUDA stream is enabled. /// bool thread_local_stream_enabled() const { return thread_local_stream_; } /// /// \brief A boolean state telling whether the MKLDNN quantization is enabled. /// /// \return bool Whether the MKLDNN quantization is enabled. /// bool mkldnn_quantizer_enabled() const { return use_mkldnn_quantizer_; } /// /// \brief Get MKLDNN quantizer config. /// /// \return MkldnnQuantizerConfig* MKLDNN quantizer config. /// MkldnnQuantizerConfig* mkldnn_quantizer_config() const; /// /// \brief Specify the memory buffer of program and parameter. /// Used when model and params are loaded directly from memory. /// /// \param prog_buffer The memory buffer of program. /// \param prog_buffer_size The size of the model data. /// \param params_buffer The memory buffer of the combined parameters file. /// \param params_buffer_size The size of the combined parameters data. /// void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size, const char* params_buffer, size_t params_buffer_size); /// /// \brief A boolean state telling whether the model is set from the CPU /// memory. /// /// \return bool Whether model and params are loaded directly from memory. /// bool model_from_memory() const { return model_from_memory_; } /// /// \brief Turn on memory optimize /// NOTE still in development. 
/// /// \param x Whether to enable memory optimize. /// void EnableMemoryOptim(bool x = true); /// /// \brief A boolean state telling whether the memory optimization is /// activated. /// /// \return bool Whether the memory optimization is activated. /// bool enable_memory_optim() const; /// /// \brief Turn on profiling report. /// If not turned on, no profiling report will be generated. /// void EnableProfile(); /// /// \brief A boolean state telling whether the profiler is activated. /// /// \return bool Whether the profiler is activated. /// bool profile_enabled() const { return with_profile_; } /// /// \brief Mute all logs in Paddle inference. /// void DisableGlogInfo(); /// /// \brief A boolean state telling whether logs in Paddle inference are muted. /// /// \return bool Whether logs in Paddle inference are muted. /// bool glog_info_disabled() const { return !with_glog_info_; } /// /// \brief Set the AnalysisConfig to be invalid. /// This is to ensure that an AnalysisConfig can only be used in one /// AnalysisPredictor. /// void SetInValid() const { is_valid_ = false; } /// /// \brief A boolean state telling whether the AnalysisConfig is valid. /// /// \return bool Whether the AnalysisConfig is valid. /// bool is_valid() const { return is_valid_; } friend class ::paddle::AnalysisPredictor; /// /// \brief Get a pass builder for customize the passes in IR analysis phase. /// NOTE: Just for developer, not an official API, easy to be broken. /// /// PassStrategy* pass_builder() const; /// /// \brief Enable the GPU multi-computing stream feature. /// NOTE: The current behavior of this interface is to bind the computation /// stream to the thread, and this behavior may be changed in the future. /// void EnableGpuMultiStream(); void PartiallyRelease(); /// /// \brief Print the summary of config. 
/// std::string Summary(); LiteNNAdapterConfig& NNAdapter() { return nnadapter_config_; } void SetDistConfig(const DistConfig& dist_config) { dist_config_ = dist_config; } const DistConfig& dist_config() const { return dist_config_; } /// /// \brief Set a list of operators that do not support mixed precision. This /// interface is in the experimental stage and may change in the future. Note /// that the blacklist must be the same as the model conversion blacklist. /// void Exp_DisableMixedPrecisionOps( const std::unordered_set& black_list); /// /// \brief Set a list of operators that do support mixed precision. This /// interface is in the experimental stage and may change in the future. Note /// that the whitelist must be the same as the model conversion whitelist. /// void Exp_EnableMixedPrecisionOps( const std::unordered_set& white_list); void SetApplyOptim(bool value) { apply_optim_ = value; } void SetSkipLoadParams(bool value) { skip_load_params_ = value; } /// /// \brief Enable use cinn compiler optimization. /// void Exp_EnableCINNCompiler(); /// /// \brief A boolean state telling whether the CINN compiler optimization is /// turned on. /// /// \return bool Whether the CINN compiler optimization is turned on. /// bool cinn_compiler_enabled() const; protected: // Update the config. void Update(); std::string SerializeInfoCache(); protected: // Model pathes. std::string model_dir_; mutable std::string prog_file_; mutable std::string params_file_; // Mixed precision related. Precision mixed_precision_mode_{Precision::kFloat32}; std::unordered_set mixed_black_list_; std::unordered_set mixed_white_list_; bool enable_low_precision_io_{false}; // GPU related. bool use_gpu_{false}; bool use_cutlass_{false}; int gpu_device_id_{0}; uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB. 
bool enable_gpu_mixed_{false}; bool thread_local_stream_{false}; bool use_cudnn_{false}; bool use_external_stream_{false}; void* exec_stream_{nullptr}; // CustomDevice related bool use_custom_device_{false}; int custom_device_id_{0}; std::string custom_device_type_; bool enable_custom_device_mixed_{false}; // ONNXRuntime related bool use_onnxruntime_{false}; bool enable_ort_optimization_{false}; // Padding related bool use_fc_padding_{true}; // TensorRT related. bool use_tensorrt_{false}; // For workspace_size, refer it from here: // https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#troubleshooting int64_t tensorrt_workspace_size_{1 << 30}; // While TensorRT allows an engine optimized for a given max batch size // to run at any smaller size, the performance for those smaller // sizes may not be as well-optimized. Therefore, Max batch is best // equivalent to the runtime batch size. int tensorrt_max_batchsize_{1}; // We transform the Ops that can be converted into TRT layer in the model, // and aggregate these Ops into subgraphs for TRT execution. // We set this variable to control the minimum number of nodes in the // subgraph, 3 as default value. int tensorrt_min_subgraph_size_{3}; Precision tensorrt_precision_mode_{Precision::kFloat32}; bool trt_use_static_engine_{false}; bool trt_use_calib_mode_{true}; bool trt_use_cuda_graph_{false}; bool trt_use_varseqlen_{false}; bool trt_with_interleaved_{false}; bool trt_mark_output_{false}; bool trt_mark_output_with_id_{false}; std::vector trt_output_tensor_names_{}; std::string tensorrt_transformer_posid_{""}; std::string tensorrt_transformer_maskid_{""}; bool trt_use_dla_{false}; int trt_dla_core_{0}; std::map> min_input_shape_{}; std::map> max_input_shape_{}; std::map> optim_input_shape_{}; std::vector trt_disabled_ops_{}; bool disable_trt_plugin_fp16_{false}; bool trt_allow_build_at_runtime_{false}; // tune to get dynamic_shape info. 
bool trt_tuned_dynamic_shape_{false}; bool trt_use_inspector_{false}; bool trt_use_explicit_quantization_{false}; // In CollectShapeInfo mode, we will collect the shape information of // all intermediate tensors in the compute graph and calculate the // min_shape, max_shape and opt_shape and save in shape_range_info_path_; bool collect_shape_range_info_{false}; std::string shape_range_info_path_; // dlnne related. bool use_dlnne_{false}; int dlnne_min_subgraph_size_{3}; int dlnne_max_batchsize_{1}; std::unordered_set dlnne_disable_nodes_by_outputs_; bool dlnne_use_static_batch_{true}; std::string dlnne_weight_share_mode_; std::map> dlnne_input_shape_dict_{}; bool dlnne_use_calib_mode_{false}; Precision dlnne_precision_mode_{Precision::kFloat32}; // memory reuse related. bool enable_memory_optim_{false}; bool trt_engine_memory_sharing_{false}; int trt_engine_memory_sharing_identifier_{0}; bool use_mkldnn_{false}; std::unordered_set mkldnn_enabled_op_types_; bool model_from_memory_{false}; bool enable_ir_optim_{true}; bool use_feed_fetch_ops_{true}; bool ir_debug_{false}; bool specify_input_name_{false}; int cpu_math_library_num_threads_{1}; bool with_profile_{false}; bool with_glog_info_{true}; // A runtime cache, shouldn't be transferred to others. std::string serialized_info_cache_; mutable std::unique_ptr pass_builder_; bool use_lite_{false}; std::vector lite_passes_filter_; std::vector lite_ops_filter_; Precision lite_precision_mode_; bool lite_zero_copy_; // CINN compiler related. bool use_cinn_compiler_{false}; // XPU related. bool use_xpu_{false}; XpuConfig xpu_config_; bool xpu_lite_l3_locked_{false}; bool xpu_lite_enable_multi_stream_{false}; // LITE OPENCL SETTINGS bool use_opencl_{false}; // NNAdapter related LiteNNAdapterConfig nnadapter_config_; // mkldnn related. 
int mkldnn_cache_capacity_{10}; bool use_mkldnn_quantizer_{false}; std::shared_ptr mkldnn_quantizer_config_; bool use_mkldnn_bfloat16_{false}; std::unordered_set bfloat16_enabled_op_types_; bool use_mkldnn_int8_{false}; std::unordered_set quantize_excluded_op_ids_{}; std::unordered_set quantize_enabled_op_types_{}; bool disable_mkldnn_fc_passes_{false}; // ipu related. bool use_ipu_{false}; int ipu_device_num_{1}; int ipu_micro_batch_size_{1}; bool ipu_enable_pipelining_{false}; int ipu_batches_per_step_{1}; bool ipu_enable_fp16_{false}; int ipu_replica_num_{1}; float ipu_available_memory_proportion_{1.0}; bool ipu_enable_half_partial_{false}; bool ipu_enable_model_runtime_executor_{false}; std::vector> ipu_custom_ops_info_; std::vector> ipu_custom_patterns_; const std::unordered_map ipu_config_mapper_ = { {"ipu_device_num", ipu_config_code::ipu_device_num}, {"ipu_micro_batch_size", ipu_config_code::ipu_micro_batch_size}, {"ipu_enable_pipelining", ipu_config_code::ipu_enable_pipelining}, {"ipu_batches_per_step", ipu_config_code::ipu_batches_per_step}, {"ipu_enable_fp16", ipu_config_code::ipu_enable_fp16}, {"ipu_replica_num", ipu_config_code::ipu_replica_num}, {"ipu_available_memory_proportion", ipu_config_code::ipu_available_memory_proportion}, {"ipu_enable_half_partial", ipu_config_code::ipu_enable_half_partial}, {"ipu_enable_model_runtime_executor", ipu_config_code::ipu_enable_model_runtime_executor}, {"ipu_custom_ops_info", ipu_config_code::ipu_custom_ops_info}, {"ipu_custom_patterns", ipu_config_code::ipu_custom_patterns}}; // If the config is already used on a predictor, it becomes invalid. // Any config can only be used with one predictor. // Variables held by config can take up a lot of memory in some cases. // So we release the memory when the predictor is set up. 
// Mutable so that the const member SetInValid() can clear it.
mutable bool is_valid_{true};
// NOTE(review): presumably control saving the optimized model and where it
// is cached -- confirm against the implementation.
bool save_optimized_model_{false};
std::string opt_cache_dir_;
friend class paddle_infer::experimental::InternalUtils;

// fleet executor (distributed) related; set via SetDistConfig().
DistConfig dist_config_{};

// jit engine related
// NOTE(Aureliue84): In case of Predictor in JITLayer, program is from outer
// which means Predictor should apply optimization by calling
// PrepareProgram(). So we add this flag to control the process.
// Set via SetApplyOptim().
bool apply_optim_{false};
// Set via SetSkipLoadParams().
bool skip_load_params_{false};
};

}  // namespace paddle