// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields needed by either Passes or PassManagers are
 * contained in Argument.
 *
 * TODO(Superjomn) Find a better way to organize the fields as the struct
 * grows too big.
 */

#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace inference {
namespace analysis {

#if defined(PADDLE_WITH_MKLDNN)
// Per-variable quantization data keyed by variable name; each entry pairs a
// bool flag with a tensor (used for Argument::quant_var_scales). The exact
// meaning of the bool flag is not documented here -- TODO confirm.
using VarQuantScale =
    std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
#endif  // defined(PADDLE_WITH_MKLDNN)
Y
Yan Chunwei 已提交
48 49 50 51 52 53 54

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clearness.
 */
struct Argument {
Y
Yan Chunwei 已提交
55
  Argument() = default;
56 57 58 59
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  using fusion_statis_t = std::unordered_map<std::string, int>;
60
  using input_shape_t = std::map<std::string, std::vector<int>>;
61 62

  bool Has(const std::string& key) const { return valid_fields_.count(key); }
63 64 65
  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
66 67 68 69 70 71 72 73 74 75
  void PartiallyRelease() {
    if (Has("model_program_path")) {
      if (Has("model_from_memory") && model_from_memory()) {
        model_program_path().clear();
        model_program_path().shrink_to_fit();
        model_params_path().clear();
        model_params_path().shrink_to_fit();
      }
    }
  }
76

77 78 79 80
#define DECL_ARGUMENT_FIELD(field__, Field, type__)                      \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
W
Wilber 已提交
81 82
        Has(#field__),                                                   \
        true,                                                            \
83 84 85 86 87 88 89 90 91 92 93
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(const type__& x) {                                     \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
94 95
  type__ field__##_;

Z
zhupengyang 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
#define DECL_POINTER_ARGUMENT_FIELD(field__, Field, type__)              \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
        Has(#field__),                                                   \
        true,                                                            \
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(type__ x) {                                            \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
  type__ field__##_;

115 116 117
#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() { return Has(#field__); }

W
Wilber 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)                  \
 public:                                                                    \
  type__& field__() {                                                       \
    PADDLE_ENFORCE_NOT_NULL(                                                \
        field__##_,                                                         \
        platform::errors::PreconditionNotMet("filed should not be null.")); \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return *static_cast<type__*>(field__##_.get());                         \
  }                                                                         \
  void Set##Field(type__* x) {                                              \
    field__##_ =                                                            \
        unique_ptr_t(x, [](void* x) { delete static_cast<type__*>(x); });   \
    valid_fields_.insert(#field__);                                         \
  }                                                                         \
  void Set##Field##NotOwned(type__* x) {                                    \
    valid_fields_.insert(#field__);                                         \
137
    field__##_ = unique_ptr_t(x, [](void* x UNUSED) {});                    \
W
Wilber 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
  }                                                                         \
  DECL_ARGUMENT_FIELD_VALID(field__);                                       \
  type__* field__##_ptr() {                                                 \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return static_cast<type__*>(field__##_.get());                          \
  }                                                                         \
  type__* Release##Field() {                                                \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    valid_fields_.erase(#field__);                                          \
    return static_cast<type__*>(field__##_.release());                      \
  }                                                                         \
                                                                            \
 private:                                                                   \
157 158
  unique_ptr_t field__##_;

159
  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
160
  DECL_ARGUMENT_FIELD(root_predictor_id, RootPredictorID, int);
161 162 163 164 165
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
T
Tao Luo 已提交
166
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
167
  DECL_ARGUMENT_FIELD(save_optimized_model, SaveOptimizedModel, bool);
168
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
169
  DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
170

171 172 173
  // For JITLayer
  DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);

174 175 176 177 178
  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

Y
Yan Chunwei 已提交
179
  // The default program, loaded from disk.
180 181 182
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
W
Wilber 已提交
183 184
  DECL_ARGUMENT_FIELD(ir_analysis_passes,
                      IrAnalysisPasses,
185
                      std::vector<std::string>);
W
Wilber 已提交
186 187
  DECL_ARGUMENT_FIELD(analysis_passes,
                      AnalysisPasses,
Y
Yan Chunwei 已提交
188
                      std::vector<std::string>);
189

190 191 192
  // whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

193
  // Pass a set of op types to enable its mkldnn kernel
W
Wilber 已提交
194 195
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
                      MKLDNNEnabledOpTypes,
196
                      std::unordered_set<std::string>);
197 198
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
199

200
#ifdef PADDLE_WITH_MKLDNN
201
  // A set of op types to enable their quantized kernels
W
Wilber 已提交
202 203
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
                      QuantizeEnabledOpTypes,
204 205 206
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
W
Wilber 已提交
207 208
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids,
                      QuantizeExcludedOpIds,
209 210
                      std::unordered_set<int>);

211 212
  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
213 214

  // A set of op types to enable their bfloat16 kernels
W
Wilber 已提交
215 216
  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types,
                      Bfloat16EnabledOpTypes,
217
                      std::unordered_set<std::string>);
B
baoachun 已提交
218 219

  DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
220
#endif
221

Y
Yan Chunwei 已提交
222
  // Passed from config.
223
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
224
  DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool);
225
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
S
superjomn 已提交
226
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
227

228 229 230 231
  // Usually use for trt dynamic shape.
  // TRT will select the best kernel according to opt shape
  // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
  // run fp16.
232 233 234
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
235
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
236

237
  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
238 239
  DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
  DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
240
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
241
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
242
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
W
Wilber 已提交
243 244
  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
                      TensorRtDisabledOPs,
245
                      std::vector<std::string>);
246
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int);
W
Wilber 已提交
247 248
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
                      TensorRtUseStaticEngine,
N
nhzlx 已提交
249
                      bool);
250
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
W
Wilber 已提交
251
  DECL_ARGUMENT_FIELD(tensorrt_use_cuda_graph, TensorRtUseCudaGraph, bool);
252
  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
253
  DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
W
Wilber 已提交
254 255
  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
                      TensorRtTransformerPosid,
256
                      std::string);
W
Wilber 已提交
257 258
  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid,
                      TensorRtTransformerMaskid,
259
                      std::string);
260
  DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
W
Wilber 已提交
261 262 263 264
                      TensorRtShapeRangeInfoPath,
                      std::string);
  DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape,
                      TensorRtTunedDynamicShape,
265 266
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
W
Wilber 已提交
267 268
                      TensorRtAllowBuildAtRuntime,
                      bool);
269
  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
270

D
denglin-github 已提交
271 272 273
  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
D
denglin-github 已提交
274 275 276 277 278 279 280 281
  DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool);
  DECL_ARGUMENT_FIELD(dlnne_weight_share_mode,
                      DlnneWeightShareMode,
                      std::string);
  DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs,
                      DlnneDisableNodesByOutputs,
                      std::unordered_set<std::string>);
  DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
282
  DECL_ARGUMENT_FIELD(dlnne_precision_mode, DlnnePrecisionMode, int);
D
denglin-github 已提交
283 284 285 286 287

  using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
  DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
                      DlnneInputShapeDict,
                      dlnne_input_shape_type);
D
denglin-github 已提交
288 289
  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);

W
Wilber 已提交
290 291
  DECL_ARGUMENT_FIELD(lite_passes_filter,
                      LitePassesFilter,
石晓伟 已提交
292 293
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
294
  DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode, int);
295 296 297
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
W
Wilber 已提交
298 299
  DECL_ARGUMENT_FIELD(xpu_locked, XpuLocked, bool);
  DECL_ARGUMENT_FIELD(xpu_precision, XpuPrecision, std::string);
300
  DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
Z
zhupengyang 已提交
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
  // XpuConfig
  DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
  DECL_ARGUMENT_FIELD(xpu_l3_size, XpuL3Size, size_t);
  DECL_POINTER_ARGUMENT_FIELD(xpu_l3_ptr, XpuL3Ptr, void*);
  DECL_ARGUMENT_FIELD(xpu_l3_autotune_size, XpuL3AutotuneSize, size_t);
  DECL_POINTER_ARGUMENT_FIELD(xpu_stream, XpuStream, void*);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_level, XpuConvAutotuneLevel, int);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_file, XpuConvAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_file_writeback,
                      XpuConvAutotuneFileWriteback,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_level, XpuFcAutotuneLevel, int);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_file, XpuFcAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_file_writeback,
                      XpuFcAutotuneFileWriteback,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_gemm_compute_precision, XpuGemmComputePrecision, int);
  DECL_ARGUMENT_FIELD(xpu_transformer_softmax_optimize_level,
                      XpuTransformerSoftmaxOptimizeLevel,
                      int);
  DECL_ARGUMENT_FIELD(xpu_transformer_encoder_adaptive_seqlen,
                      XpuTransformerEncoderAdaptiveSeqlen,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_quant_post_static_gelu_out_threshold,
                      XpuQuantPostStaticGeluOutThreshold,
                      float);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_activation_method,
                      XpuQuantPostDynamicActivationMethod,
                      int);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_precision,
                      XpuQuantPostDynamicWeightPrecision,
Z
zhupengyang 已提交
332 333
                      int);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
334
                      XpuQuantPostDynamicOpTypes,
Z
zhupengyang 已提交
335
                      std::vector<std::string>);
Z
zhupengyang 已提交
336 337 338 339
  DECL_ARGUMENT_FIELD(xpu_lite_l3_locked, XpuLiteL3Locked, bool);
  DECL_ARGUMENT_FIELD(xpu_lite_enable_multi_stream,
                      XpuLiteEnableMultiStream,
                      bool);
石晓伟 已提交
340

341 342
  DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);

343
  DECL_ARGUMENT_FIELD(use_nnadapter, UseNNAdapter, bool);
W
Wilber 已提交
344 345
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_dir,
                      NNAdapterModelCacheDir,
346
                      std::string);
W
Wilber 已提交
347 348
  DECL_ARGUMENT_FIELD(nnadapter_device_names,
                      NNAdapterDeviceNames,
349
                      std::vector<std::string>);
W
Wilber 已提交
350 351
  DECL_ARGUMENT_FIELD(nnadapter_context_properties,
                      NNAdapterContextProperties,
352 353
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_buffer,
W
Wilber 已提交
354 355
                      NNAdapterSubgraphPartitionConfigBuffer,
                      std::string);
356
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_path,
W
Wilber 已提交
357 358 359 360
                      NNAdapterSubgraphPartitionConfigPath,
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_token,
                      NNAdapterModelCacheToken,
361
                      std::vector<std::string>);
W
Wilber 已提交
362 363
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_buffer,
                      NNAdapterModelCacheBuffer,
364 365
                      std::vector<std::vector<char>>);

Y
Yan Chunwei 已提交
366 367
  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
368
  DECL_ARGUMENT_FIELD(trt_engine_memory_sharing, TrtEngineMemorySharing, bool);
369

Y
Yan Chunwei 已提交
370 371 372 373
  // Indicate which kind of sort algorithm is used for operators, the memory
  // optimization relays on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

374
  // The program transformed by IR analysis phase.
W
Wilber 已提交
375 376
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program,
                             IrAnalyzedProgram,
377 378 379
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
380

W
Wilber 已提交
381
  // Only used in paddle-lite subgraph.
W
Wilber 已提交
382 383
  DECL_ARGUMENT_FIELD(cpu_math_library_num_threads,
                      CpuMathLibraryNumThreads,
W
Wilber 已提交
384 385
                      int);

J
jianghaicheng 已提交
386 387 388
  // ipu related
  DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool);
  DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int);
389
  DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int);
J
jianghaicheng 已提交
390 391
  DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool);
  DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int);
392 393 394
  DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool);
  DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int);
  DECL_ARGUMENT_FIELD(ipu_available_memory_proportion,
W
Wilber 已提交
395 396
                      IpuAvailableMemoryProportion,
                      float);
397
  DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool);
398 399 400 401 402 403
  DECL_ARGUMENT_FIELD(ipu_custom_ops_info,
                      IpuCustomOpsInfo,
                      std::vector<std::vector<std::string>>);
  DECL_ARGUMENT_FIELD(ipu_custom_patterns,
                      IpuCustomPatterns,
                      std::vector<std::vector<std::string>>);
404 405 406
  DECL_ARGUMENT_FIELD(ipu_enable_model_runtime_executor,
                      IpuEnableModelRuntimeExecutor,
                      bool);
J
jianghaicheng 已提交
407

408 409
  // mixed precision related
  DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
410 411 412
  DECL_ARGUMENT_FIELD(mixed_black_list,
                      MixedBlackList,
                      std::unordered_set<std::string>);
413
  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
414
  DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
415
  DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);
416

417 418 419
  // cinn compiler related
  DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);

420 421 422 423
  // custom device
  DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
  DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
  DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);
424 425 426
  DECL_ARGUMENT_FIELD(enable_custom_device_mixed,
                      EnableCustomDeviceMixed,
                      bool);
427

428
 private:
429
  std::unordered_set<std::string> valid_fields_;
Y
Yan Chunwei 已提交
430 431
};

// Enforces that the Argument pointed to by `argument__` has the field
// `fieldname__` set; otherwise raises PreconditionNotMet naming the field.
// `argument__` is parenthesized so that arguments such as `&argument` expand
// to `(&argument)->Has(...)` instead of `&(argument->Has(...))`.
// The expansion already ends with a semicolon; it is kept for compatibility
// with existing call sites.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE_EQ(                                  \
      (argument__)->Has(#fieldname__),                \
      true,                                           \
      platform::errors::PreconditionNotMet(           \
          "the argument field [%s] should be set", #fieldname__));

}  // namespace analysis
}  // namespace inference
}  // namespace paddle