argument.h 17.8 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields that are needed by either Passes or
 * PassManagers are contained in Argument.
 *
 * TODO(Superjomn) Find a better way to hold the fields when they grow too
 * numerous.
 */

G
gongweibao 已提交
24 25
#pragma once

26
#include <map>
N
nhzlx 已提交
27
#include <memory>
G
gongweibao 已提交
28
#include <string>
N
nhzlx 已提交
29 30
#include <unordered_map>
#include <unordered_set>
31
#include <utility>
32
#include <vector>
N
nhzlx 已提交
33

34
#include "paddle/fluid/framework/ir/graph.h"
Y
Yan Chunwei 已提交
35
#include "paddle/fluid/framework/program_desc.h"
36
#include "paddle/fluid/framework/scope.h"
N
nhzlx 已提交
37
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
38

39
#include "paddle/phi/common/data_type.h"
Y
Yan Chunwei 已提交
40 41 42 43

namespace paddle {
namespace inference {
namespace analysis {
44

45
#ifdef PADDLE_WITH_MKLDNN
46
using VarQuantScale =
47
    std::unordered_map<std::string, std::pair<bool, phi::DenseTensor>>;
48
#endif
Y
Yan Chunwei 已提交
49 50 51 52 53 54 55

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clarity.
 */
struct Argument {
Y
Yan Chunwei 已提交
56
  Argument() = default;
57 58 59 60
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  // Type-erased owning pointer: stores a void* together with a std::function
  // deleter, so a single member type can hold objects of different types.
  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  // String-keyed integer table (presumably fusion name -> number of times it
  // was applied; confirm against the passes that fill it).
  using fusion_statis_t = std::unordered_map<std::string, int>;
61
  using input_shape_t = std::map<std::string, std::vector<int>>;
62 63

  // Reports whether `key` is currently present in valid_fields_, i.e. whether
  // the field with that name has been marked as set.
  bool Has(const std::string& key) const {
    return valid_fields_.find(key) != valid_fields_.end();
  }
64 65 66
  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
67 68 69 70 71 72 73 74 75 76
  void PartiallyRelease() {
    if (Has("model_program_path")) {
      if (Has("model_from_memory") && model_from_memory()) {
        model_program_path().clear();
        model_program_path().shrink_to_fit();
        model_params_path().clear();
        model_params_path().shrink_to_fit();
      }
    }
  }
77

78 79 80 81
// Declares a value-typed field together with its accessors.
//
// DECL_ARGUMENT_FIELD(foo, Foo, T) generates:
//   T& foo()                 - returns the stored value; raises
//                              PreconditionNotMet (naming the field) when the
//                              field has not been set.
//   void SetFoo(const T& x)  - copies `x` into the field and marks it as set.
//   bool foo_valid()         - generated through DECL_ARGUMENT_FIELD_VALID.
//   T* foo_ptr()             - address of the storage; no validity check.
//   T foo_                   - the private backing storage.
//
// The expansion ends inside a `private:` section, so whatever follows the
// macro invocation is private until another access specifier appears.
#define DECL_ARGUMENT_FIELD(field__, Field, type__)                 \
 public:                                                            \
  type__& field__() {                                               \
    PADDLE_ENFORCE_EQ(Has(#field__),                                \
                      true,                                         \
                      platform::errors::PreconditionNotMet(         \
                          "There is no such field: %s", #field__)); \
    return field__##_;                                              \
  }                                                                 \
  void Set##Field(const type__& x) {                                \
    field__##_ = x;                                                 \
    valid_fields_.insert(#field__);                                 \
  }                                                                 \
  DECL_ARGUMENT_FIELD_VALID(field__);                               \
  type__* field__##_ptr() { return &field__##_; }                   \
                                                                    \
 private:                                                           \
  type__ field__##_;

// Generates `bool <field>_valid() const`, which reports whether the field has
// been marked as set by forwarding the stringified field name to Has().
// The accessor is const (Has() is const too) so it can be called through a
// const reference to the enclosing object.
#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() const { return Has(#field__); }

W
Wilber 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
// Declares a pointer-held field (stored in a type-erased unique_ptr_t)
// together with its accessors.
//
// DECL_ARGUMENT_UNIQUE_FIELD(foo, Foo, T) generates:
//   T& foo()                   - dereferences the stored pointer. Raises
//                                PreconditionNotMet if the field is not set,
//                                or if it is set but holds a null pointer.
//   void SetFoo(T* x)          - takes ownership of `x` (deleted as a T when
//                                the field is reset or destroyed) and marks
//                                the field as set.
//   void SetFooNotOwned(T* x)  - stores `x` with a no-op deleter and marks the
//                                field as set; `x` is never deleted here.
//   bool foo_valid()           - generated through DECL_ARGUMENT_FIELD_VALID.
//   T* foo_ptr()               - the stored pointer (may be null); raises
//                                PreconditionNotMet if the field is not set.
//   T* ReleaseFoo()            - marks the field as unset and returns the
//                                stored pointer without running the deleter;
//                                raises PreconditionNotMet if not set.
//   unique_ptr_t foo_          - the private backing storage.
//
// In foo() the "is set" check runs before the null check: a never-set or
// released field always holds a null pointer, so with the opposite order the
// "no such field" diagnostic could never be produced.
//
// The expansion ends inside a `private:` section, so whatever follows the
// macro invocation is private until another access specifier appears.
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)                \
 public:                                                                  \
  type__& field__() {                                                     \
    PADDLE_ENFORCE_EQ(Has(#field__),                                      \
                      true,                                               \
                      platform::errors::PreconditionNotMet(               \
                          "There is no such field: %s", #field__));       \
    PADDLE_ENFORCE_NOT_NULL(field__##_,                                   \
                            platform::errors::PreconditionNotMet(         \
                                "field [%s] should not be null.",         \
                                #field__));                               \
    return *static_cast<type__*>(field__##_.get());                       \
  }                                                                       \
  void Set##Field(type__* x) {                                            \
    field__##_ = unique_ptr_t(                                            \
        x, [](void* owned) { delete static_cast<type__*>(owned); });      \
    valid_fields_.insert(#field__);                                       \
  }                                                                       \
  void Set##Field##NotOwned(type__* x) {                                  \
    valid_fields_.insert(#field__);                                       \
    field__##_ = unique_ptr_t(x, [](void*) {});                           \
  }                                                                       \
  DECL_ARGUMENT_FIELD_VALID(field__);                                     \
  type__* field__##_ptr() {                                               \
    PADDLE_ENFORCE_EQ(Has(#field__),                                      \
                      true,                                               \
                      platform::errors::PreconditionNotMet(               \
                          "There is no such field: %s", #field__));       \
    return static_cast<type__*>(field__##_.get());                        \
  }                                                                       \
  type__* Release##Field() {                                              \
    PADDLE_ENFORCE_EQ(Has(#field__),                                      \
                      true,                                               \
                      platform::errors::PreconditionNotMet(               \
                          "There is no such field: %s", #field__));       \
    valid_fields_.erase(#field__);                                        \
    return static_cast<type__*>(field__##_.release());                    \
  }                                                                       \
                                                                          \
 private:                                                                 \
  unique_ptr_t field__##_;

141
  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
142
  DECL_ARGUMENT_FIELD(root_predictor_id, RootPredictorID, int);
143 144 145 146 147
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
T
Tao Luo 已提交
148
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
149
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
150
  DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
151

152 153 154
  // For JITLayer
  DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);

155 156 157 158 159
  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

Y
Yan Chunwei 已提交
160
  // The default program, loaded from disk.
161 162 163
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
W
Wilber 已提交
164 165
  DECL_ARGUMENT_FIELD(ir_analysis_passes,
                      IrAnalysisPasses,
166
                      std::vector<std::string>);
W
Wilber 已提交
167 168
  DECL_ARGUMENT_FIELD(analysis_passes,
                      AnalysisPasses,
Y
Yan Chunwei 已提交
169
                      std::vector<std::string>);
170

171 172 173
  // whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

174
  // A set of op types whose mkldnn kernels should be enabled.
W
Wilber 已提交
175 176
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
                      MKLDNNEnabledOpTypes,
177
                      std::unordered_set<std::string>);
178 179
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
180

181
#ifdef PADDLE_WITH_MKLDNN
182
  // A set of op types to enable their quantized kernels
W
Wilber 已提交
183 184
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
                      QuantizeEnabledOpTypes,
185 186 187
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
W
Wilber 已提交
188 189
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids,
                      QuantizeExcludedOpIds,
190 191
                      std::unordered_set<int>);

192 193
  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
194 195

  // A set of op types to enable their bfloat16 kernels
W
Wilber 已提交
196 197
  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types,
                      Bfloat16EnabledOpTypes,
198
                      std::unordered_set<std::string>);
B
baoachun 已提交
199 200

  DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
201
#endif
202

Y
Yan Chunwei 已提交
203
  // Passed from config.
204
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
205
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
S
superjomn 已提交
206
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
207

208 209 210 211
  // Usually used for TRT dynamic shape.
  // TRT will select the best kernel according to the opt shape.
  // Setting disable_trt_plugin_fp16 to true means that TRT plugins will not
  // run in fp16.
212 213 214
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
215
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
216

217
  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
218 219
  DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
  DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
220
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
221
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
222
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
W
Wilber 已提交
223 224
  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
                      TensorRtDisabledOPs,
225
                      std::vector<std::string>);
W
Wilber 已提交
226 227
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode,
                      TensorRtPrecisionMode,
228
                      AnalysisConfig::Precision);
W
Wilber 已提交
229 230
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
                      TensorRtUseStaticEngine,
N
nhzlx 已提交
231
                      bool);
232
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
233
  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
234
  DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
W
Wilber 已提交
235 236
  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
                      TensorRtTransformerPosid,
237
                      std::string);
W
Wilber 已提交
238 239
  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid,
                      TensorRtTransformerMaskid,
240
                      std::string);
241
  DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
W
Wilber 已提交
242 243 244 245
                      TensorRtShapeRangeInfoPath,
                      std::string);
  DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape,
                      TensorRtTunedDynamicShape,
246 247
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
W
Wilber 已提交
248 249
                      TensorRtAllowBuildAtRuntime,
                      bool);
250
  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
251

D
denglin-github 已提交
252 253 254
  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
D
denglin-github 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270
  DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool);
  DECL_ARGUMENT_FIELD(dlnne_weight_share_mode,
                      DlnneWeightShareMode,
                      std::string);
  DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs,
                      DlnneDisableNodesByOutputs,
                      std::unordered_set<std::string>);
  DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
  DECL_ARGUMENT_FIELD(dlnne_precision_mode,
                      DlnnePrecisionMode,
                      AnalysisConfig::Precision);

  using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
  DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
                      DlnneInputShapeDict,
                      dlnne_input_shape_type);
D
denglin-github 已提交
271 272
  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);

W
Wilber 已提交
273 274
  DECL_ARGUMENT_FIELD(lite_passes_filter,
                      LitePassesFilter,
石晓伟 已提交
275 276
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
W
Wilber 已提交
277 278
  DECL_ARGUMENT_FIELD(lite_precision_mode,
                      LitePrecisionMode,
石晓伟 已提交
279
                      AnalysisConfig::Precision);
280 281 282 283
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
  DECL_ARGUMENT_FIELD(xpu_l3_workspace_size, XpuL3WorkspaceSize, int);
W
Wilber 已提交
284 285 286 287 288
  DECL_ARGUMENT_FIELD(xpu_locked, XpuLocked, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune, XpuAutotune, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune_file, XpuAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_precision, XpuPrecision, std::string);
  DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
289
  DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
290
  DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
石晓伟 已提交
291

292 293
  DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);

294
  DECL_ARGUMENT_FIELD(use_nnadapter, UseNNAdapter, bool);
W
Wilber 已提交
295 296
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_dir,
                      NNAdapterModelCacheDir,
297
                      std::string);
W
Wilber 已提交
298 299
  DECL_ARGUMENT_FIELD(nnadapter_device_names,
                      NNAdapterDeviceNames,
300
                      std::vector<std::string>);
W
Wilber 已提交
301 302
  DECL_ARGUMENT_FIELD(nnadapter_context_properties,
                      NNAdapterContextProperties,
303 304
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_buffer,
W
Wilber 已提交
305 306
                      NNAdapterSubgraphPartitionConfigBuffer,
                      std::string);
307
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_path,
W
Wilber 已提交
308 309 310 311
                      NNAdapterSubgraphPartitionConfigPath,
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_token,
                      NNAdapterModelCacheToken,
312
                      std::vector<std::string>);
W
Wilber 已提交
313 314
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_buffer,
                      NNAdapterModelCacheBuffer,
315 316
                      std::vector<std::vector<char>>);

Y
Yan Chunwei 已提交
317 318
  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
319
  DECL_ARGUMENT_FIELD(trt_engine_memory_sharing, TrtEngineMemorySharing, bool);
320

Y
Yan Chunwei 已提交
321 322 323 324
  // Indicates which kind of sort algorithm is used for operators; the memory
  // optimization relies on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

325
  // The program transformed by IR analysis phase.
W
Wilber 已提交
326 327
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program,
                             IrAnalyzedProgram,
328 329 330
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
331

W
Wilber 已提交
332
  // Only used in paddle-lite subgraph.
W
Wilber 已提交
333 334
  DECL_ARGUMENT_FIELD(cpu_math_library_num_threads,
                      CpuMathLibraryNumThreads,
W
Wilber 已提交
335 336
                      int);

J
jianghaicheng 已提交
337 338 339
  // ipu related
  DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool);
  DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int);
340
  DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int);
J
jianghaicheng 已提交
341 342
  DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool);
  DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int);
343 344 345
  DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool);
  DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int);
  DECL_ARGUMENT_FIELD(ipu_available_memory_proportion,
W
Wilber 已提交
346 347
                      IpuAvailableMemoryProportion,
                      float);
348
  DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool);
349 350 351 352 353 354
  DECL_ARGUMENT_FIELD(ipu_custom_ops_info,
                      IpuCustomOpsInfo,
                      std::vector<std::vector<std::string>>);
  DECL_ARGUMENT_FIELD(ipu_custom_patterns,
                      IpuCustomPatterns,
                      std::vector<std::vector<std::string>>);
355 356 357
  DECL_ARGUMENT_FIELD(ipu_enable_model_runtime_executor,
                      IpuEnableModelRuntimeExecutor,
                      bool);
J
jianghaicheng 已提交
358

359 360 361 362
  // npu related
  DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool);
  DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int);

363 364
  // mixed precision related
  DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
365 366 367
  DECL_ARGUMENT_FIELD(mixed_black_list,
                      MixedBlackList,
                      std::unordered_set<std::string>);
368 369
  DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
  DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
370

371 372 373
  // cinn compiler related
  DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);

374
 private:
375
  std::unordered_set<std::string> valid_fields_;
Y
Yan Chunwei 已提交
376 377
};

378
// Raises PreconditionNotMet, naming the field, unless
// `argument__->Has("<fieldname__>")` is true.
//
//   argument__  - any expression that yields a pointer-like object exposing
//                 Has(); it is parenthesized so that expressions such as
//                 `&arg` bind before the `->`.
//   fieldname__ - the bare field identifier; it is stringified.
//
// NOTE: the expansion already ends with a semicolon. It is kept so that
// existing call sites, with or without their own `;`, keep compiling.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE_EQ(                                  \
      (argument__)->Has(#fieldname__),                \
      true,                                           \
      platform::errors::PreconditionNotMet(           \
          "the argument field [%s] should be set", #fieldname__));
Y
Yan Chunwei 已提交
384 385 386 387

}  // namespace analysis
}  // namespace inference
}  // namespace paddle