argument.h 21.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields needed by either Passes or PassManagers
 * are contained in Argument.
 *
 * TODO(Superjomn) Find a better way to organize the fields as their number
 * grows.
 */

G
gongweibao 已提交
24 25
#pragma once

26
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

#include "paddle/phi/common/data_type.h"
Y
Yan Chunwei 已提交
39 40 41 42

namespace paddle {
namespace inference {
namespace analysis {
43

44
#ifdef PADDLE_WITH_DNNL
// Scales for variables to be quantized, keyed by variable name. Each entry
// pairs a boolean flag with a scale tensor; the flag's meaning is defined by
// the quantization passes that fill the map (presumably signedness of the
// data -- NOTE(review): confirm against those passes).
using VarQuantScale = std::unordered_map<std::string,
                                         std::pair<bool, phi::DenseTensor>>;
#endif
Y
Yan Chunwei 已提交
48 49 50 51 52 53 54

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clearness.
 */
struct Argument {
Y
Yan Chunwei 已提交
55
  Argument() = default;
56 57 58 59
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  using fusion_statis_t = std::unordered_map<std::string, int>;
60
  using input_shape_t = std::map<std::string, std::vector<int>>;
61 62

  bool Has(const std::string& key) const { return valid_fields_.count(key); }
63 64 65
  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
66 67 68 69 70 71 72 73 74 75
  void PartiallyRelease() {
    if (Has("model_program_path")) {
      if (Has("model_from_memory") && model_from_memory()) {
        model_program_path().clear();
        model_program_path().shrink_to_fit();
        model_params_path().clear();
        model_params_path().shrink_to_fit();
      }
    }
  }
76

77 78 79 80
#define DECL_ARGUMENT_FIELD(field__, Field, type__)                      \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
W
Wilber 已提交
81 82
        Has(#field__),                                                   \
        true,                                                            \
83 84 85 86 87 88 89 90 91 92 93
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(const type__& x) {                                     \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
94 95
  type__ field__##_;

Z
zhupengyang 已提交
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
#define DECL_POINTER_ARGUMENT_FIELD(field__, Field, type__)              \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
        Has(#field__),                                                   \
        true,                                                            \
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(type__ x) {                                            \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
  type__ field__##_;

115 116 117
#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() { return Has(#field__); }

W
Wilber 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)                  \
 public:                                                                    \
  type__& field__() {                                                       \
    PADDLE_ENFORCE_NOT_NULL(                                                \
        field__##_,                                                         \
        platform::errors::PreconditionNotMet("filed should not be null.")); \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return *static_cast<type__*>(field__##_.get());                         \
  }                                                                         \
  void Set##Field(type__* x) {                                              \
    field__##_ =                                                            \
        unique_ptr_t(x, [](void* x) { delete static_cast<type__*>(x); });   \
    valid_fields_.insert(#field__);                                         \
  }                                                                         \
  void Set##Field##NotOwned(type__* x) {                                    \
    valid_fields_.insert(#field__);                                         \
137
    field__##_ = unique_ptr_t(x, [](void* x UNUSED) {});                    \
W
Wilber 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
  }                                                                         \
  DECL_ARGUMENT_FIELD_VALID(field__);                                       \
  type__* field__##_ptr() {                                                 \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return static_cast<type__*>(field__##_.get());                          \
  }                                                                         \
  type__* Release##Field() {                                                \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    valid_fields_.erase(#field__);                                          \
    return static_cast<type__*>(field__##_.release());                      \
  }                                                                         \
                                                                            \
 private:                                                                   \
157 158
  unique_ptr_t field__##_;

159
  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
160
  DECL_ARGUMENT_FIELD(root_predictor_id, RootPredictorID, int);
161 162 163 164 165
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
T
Tao Luo 已提交
166
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
167
  DECL_ARGUMENT_FIELD(save_optimized_model, SaveOptimizedModel, bool);
168
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
169
  DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
170

171 172 173
  // For JITLayer
  DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);

174 175 176 177 178
  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

Y
Yan Chunwei 已提交
179
  // The default program, loaded from disk.
180 181 182
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
W
Wilber 已提交
183 184
  DECL_ARGUMENT_FIELD(ir_analysis_passes,
                      IrAnalysisPasses,
185
                      std::vector<std::string>);
W
Wilber 已提交
186 187
  DECL_ARGUMENT_FIELD(analysis_passes,
                      AnalysisPasses,
Y
Yan Chunwei 已提交
188
                      std::vector<std::string>);
189

190 191 192
  // whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

193
  // Pass a set of op types to enable its mkldnn kernel
W
Wilber 已提交
194 195
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
                      MKLDNNEnabledOpTypes,
196
                      std::unordered_set<std::string>);
197 198
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
199

200
#ifdef PADDLE_WITH_DNNL
201
  // A set of op types to enable their quantized kernels
W
Wilber 已提交
202 203
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
                      QuantizeEnabledOpTypes,
204 205 206
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
W
Wilber 已提交
207 208
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids,
                      QuantizeExcludedOpIds,
209 210
                      std::unordered_set<int>);

211 212
  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
213 214

  // A set of op types to enable their bfloat16 kernels
W
Wilber 已提交
215 216
  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types,
                      Bfloat16EnabledOpTypes,
217
                      std::unordered_set<std::string>);
B
baoachun 已提交
218 219

  DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
220
#endif
221

Y
Yan Chunwei 已提交
222
  // Passed from config.
223
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
224
  DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool);
225
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
S
superjomn 已提交
226
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
227

228 229 230 231
  // Usually use for trt dynamic shape.
  // TRT will select the best kernel according to opt shape
  // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
  // run fp16.
232 233 234
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
235
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
236

237
  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
238 239
  DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
  DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
240
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
241
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
242
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
M
ming1753 已提交
243
  DECL_ARGUMENT_FIELD(trt_mark_output, TRTMarkOutput, bool);
244
  DECL_ARGUMENT_FIELD(trt_mark_output_with_id, TRTMarkOutputWithId, bool);
M
ming1753 已提交
245 246 247
  DECL_ARGUMENT_FIELD(trt_output_tensor_names,
                      TRTOutputTensorNames,
                      std::vector<std::string>);
W
Wilber 已提交
248 249
  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
                      TensorRtDisabledOPs,
250
                      std::vector<std::string>);
251
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int);
W
Wilber 已提交
252 253
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
                      TensorRtUseStaticEngine,
N
nhzlx 已提交
254
                      bool);
255
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
W
Wilber 已提交
256
  DECL_ARGUMENT_FIELD(tensorrt_use_cuda_graph, TensorRtUseCudaGraph, bool);
257
  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
258
  DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
W
Wilber 已提交
259 260
  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
                      TensorRtTransformerPosid,
261
                      std::string);
W
Wilber 已提交
262 263
  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid,
                      TensorRtTransformerMaskid,
264
                      std::string);
265
  DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
W
Wilber 已提交
266 267 268 269
                      TensorRtShapeRangeInfoPath,
                      std::string);
  DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape,
                      TensorRtTunedDynamicShape,
270 271
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
W
Wilber 已提交
272 273
                      TensorRtAllowBuildAtRuntime,
                      bool);
274
  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
L
Leo Chen 已提交
275 276 277
  DECL_ARGUMENT_FIELD(tensorrt_use_explicit_quantization,
                      TensorRtUseExplicitQuantization,
                      bool);
278

D
denglin-github 已提交
279 280 281
  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
D
denglin-github 已提交
282 283 284 285 286 287 288 289
  DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool);
  DECL_ARGUMENT_FIELD(dlnne_weight_share_mode,
                      DlnneWeightShareMode,
                      std::string);
  DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs,
                      DlnneDisableNodesByOutputs,
                      std::unordered_set<std::string>);
  DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
290
  DECL_ARGUMENT_FIELD(dlnne_precision_mode, DlnnePrecisionMode, int);
D
denglin-github 已提交
291 292 293 294 295

  using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
  DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
                      DlnneInputShapeDict,
                      dlnne_input_shape_type);
D
denglin-github 已提交
296 297
  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);

W
Wilber 已提交
298 299
  DECL_ARGUMENT_FIELD(lite_passes_filter,
                      LitePassesFilter,
石晓伟 已提交
300 301
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
302
  DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode, int);
303 304 305
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
W
Wilber 已提交
306 307
  DECL_ARGUMENT_FIELD(xpu_locked, XpuLocked, bool);
  DECL_ARGUMENT_FIELD(xpu_precision, XpuPrecision, std::string);
308
  DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
Z
zhupengyang 已提交
309 310 311 312 313
  // XpuConfig
  DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
  DECL_ARGUMENT_FIELD(xpu_l3_size, XpuL3Size, size_t);
  DECL_POINTER_ARGUMENT_FIELD(xpu_l3_ptr, XpuL3Ptr, void*);
  DECL_ARGUMENT_FIELD(xpu_l3_autotune_size, XpuL3AutotuneSize, size_t);
314
  DECL_ARGUMENT_FIELD(xpu_context_gm_size, XpuContextGmSize, int);
315
  DECL_POINTER_ARGUMENT_FIELD(xpu_context, XpuContext, void*);
Z
zhupengyang 已提交
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
  DECL_POINTER_ARGUMENT_FIELD(xpu_stream, XpuStream, void*);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_level, XpuConvAutotuneLevel, int);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_file, XpuConvAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_conv_autotune_file_writeback,
                      XpuConvAutotuneFileWriteback,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_level, XpuFcAutotuneLevel, int);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_file, XpuFcAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_fc_autotune_file_writeback,
                      XpuFcAutotuneFileWriteback,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_gemm_compute_precision, XpuGemmComputePrecision, int);
  DECL_ARGUMENT_FIELD(xpu_transformer_softmax_optimize_level,
                      XpuTransformerSoftmaxOptimizeLevel,
                      int);
  DECL_ARGUMENT_FIELD(xpu_transformer_encoder_adaptive_seqlen,
                      XpuTransformerEncoderAdaptiveSeqlen,
                      bool);
  DECL_ARGUMENT_FIELD(xpu_quant_post_static_gelu_out_threshold,
                      XpuQuantPostStaticGeluOutThreshold,
                      float);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_activation_method,
                      XpuQuantPostDynamicActivationMethod,
                      int);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_weight_precision,
                      XpuQuantPostDynamicWeightPrecision,
Z
zhupengyang 已提交
342 343
                      int);
  DECL_ARGUMENT_FIELD(xpu_quant_post_dynamic_op_types,
344
                      XpuQuantPostDynamicOpTypes,
Z
zhupengyang 已提交
345
                      std::vector<std::string>);
Z
zhupengyang 已提交
346 347 348 349
  DECL_ARGUMENT_FIELD(xpu_lite_l3_locked, XpuLiteL3Locked, bool);
  DECL_ARGUMENT_FIELD(xpu_lite_enable_multi_stream,
                      XpuLiteEnableMultiStream,
                      bool);
石晓伟 已提交
350

351 352
  DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);

353
  DECL_ARGUMENT_FIELD(use_nnadapter, UseNNAdapter, bool);
W
Wilber 已提交
354 355
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_dir,
                      NNAdapterModelCacheDir,
356
                      std::string);
W
Wilber 已提交
357 358
  DECL_ARGUMENT_FIELD(nnadapter_device_names,
                      NNAdapterDeviceNames,
359
                      std::vector<std::string>);
W
Wilber 已提交
360 361
  DECL_ARGUMENT_FIELD(nnadapter_context_properties,
                      NNAdapterContextProperties,
362 363
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_buffer,
W
Wilber 已提交
364 365
                      NNAdapterSubgraphPartitionConfigBuffer,
                      std::string);
366
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_path,
W
Wilber 已提交
367 368 369 370
                      NNAdapterSubgraphPartitionConfigPath,
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_token,
                      NNAdapterModelCacheToken,
371
                      std::vector<std::string>);
W
Wilber 已提交
372 373
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_buffer,
                      NNAdapterModelCacheBuffer,
374 375
                      std::vector<std::vector<char>>);

Y
Yan Chunwei 已提交
376 377
  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
378
  DECL_ARGUMENT_FIELD(trt_engine_memory_sharing, TrtEngineMemorySharing, bool);
379

Y
Yan Chunwei 已提交
380 381 382 383
  // Indicate which kind of sort algorithm is used for operators, the memory
  // optimization relays on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

384
  // The program transformed by IR analysis phase.
W
Wilber 已提交
385 386
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program,
                             IrAnalyzedProgram,
387 388 389
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
390

W
Wilber 已提交
391
  // Only used in paddle-lite subgraph.
W
Wilber 已提交
392 393
  DECL_ARGUMENT_FIELD(cpu_math_library_num_threads,
                      CpuMathLibraryNumThreads,
W
Wilber 已提交
394 395
                      int);

J
jianghaicheng 已提交
396 397 398
  // ipu related
  DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool);
  DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int);
399
  DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int);
J
jianghaicheng 已提交
400 401
  DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool);
  DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int);
402 403 404
  DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool);
  DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int);
  DECL_ARGUMENT_FIELD(ipu_available_memory_proportion,
W
Wilber 已提交
405 406
                      IpuAvailableMemoryProportion,
                      float);
407
  DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool);
408 409 410 411 412 413
  DECL_ARGUMENT_FIELD(ipu_custom_ops_info,
                      IpuCustomOpsInfo,
                      std::vector<std::vector<std::string>>);
  DECL_ARGUMENT_FIELD(ipu_custom_patterns,
                      IpuCustomPatterns,
                      std::vector<std::vector<std::string>>);
414 415 416
  DECL_ARGUMENT_FIELD(ipu_enable_model_runtime_executor,
                      IpuEnableModelRuntimeExecutor,
                      bool);
J
jianghaicheng 已提交
417

418 419
  // mixed precision related
  DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
420 421 422
  DECL_ARGUMENT_FIELD(mixed_black_list,
                      MixedBlackList,
                      std::unordered_set<std::string>);
423 424 425
  DECL_ARGUMENT_FIELD(mixed_white_list,
                      MixedWhiteList,
                      std::unordered_set<std::string>);
426
  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
427
  DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
428
  DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);
429

430 431 432
  // cinn compiler related
  DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);

433 434 435 436
  // custom device
  DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
  DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
  DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);
437 438 439
  DECL_ARGUMENT_FIELD(enable_custom_device_mixed,
                      EnableCustomDeviceMixed,
                      bool);
440

441
 private:
442
  std::unordered_set<std::string> valid_fields_;
Y
Yan Chunwei 已提交
443 444
};

445
// Fails a PADDLE_ENFORCE check (PreconditionNotMet) unless the field
// `fieldname__` has been set on the Argument pointed to by `argument__`.
// `argument__` is parenthesized so pointer expressions such as `&argument`
// expand correctly instead of binding as `&(argument->Has(...))`.
// The expansion keeps its trailing semicolon, as before, for compatibility
// with call sites that may omit their own.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE_EQ(                                  \
      (argument__)->Has(#fieldname__),                \
      true,                                           \
      platform::errors::PreconditionNotMet(           \
          "the argument field [%s] should be set", #fieldname__));
Y
Yan Chunwei 已提交
451 452 453 454

}  // namespace analysis
}  // namespace inference
}  // namespace paddle