argument.h 17.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields needed by either Passes or PassManagers are
 * contained in Argument.
 *
 * TODO(Superjomn) Find a better way to hold the fields when the struct grows
 * too big.
 */

G
gongweibao 已提交
24 25
#pragma once

26
#include <map>
N
nhzlx 已提交
27
#include <memory>
G
gongweibao 已提交
28
#include <string>
N
nhzlx 已提交
29 30
#include <unordered_map>
#include <unordered_set>
31
#include <utility>
32
#include <vector>
N
nhzlx 已提交
33

34
#include "paddle/fluid/framework/ir/graph.h"
Y
Yan Chunwei 已提交
35
#include "paddle/fluid/framework/program_desc.h"
36
#include "paddle/fluid/framework/scope.h"
N
nhzlx 已提交
37
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
38

39
#include "paddle/phi/common/data_type.h"
Y
Yan Chunwei 已提交
40 41 42 43

namespace paddle {
namespace inference {
namespace analysis {
44

45
#ifdef PADDLE_WITH_MKLDNN
// Table keyed by variable name whose values are (bool, DenseTensor) pairs.
// Only declared when Paddle is built with PADDLE_WITH_MKLDNN.
// NOTE(review): judging by the name this holds per-variable quantization
// scales; the meaning of the bool flag is not visible in this file (presumably
// a signed/unsigned marker) -- TODO confirm against the quantization passes.
using VarQuantScale = std::unordered_map<  //
    std::string,
    std::pair<bool, phi::DenseTensor>>;
#endif
Y
Yan Chunwei 已提交
49 50 51 52 53 54 55

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clearness.
 */
struct Argument {
Y
Yan Chunwei 已提交
56
  Argument() = default;
57 58 59 60
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  using fusion_statis_t = std::unordered_map<std::string, int>;
61
  using input_shape_t = std::map<std::string, std::vector<int>>;
62 63

  bool Has(const std::string& key) const { return valid_fields_.count(key); }
64 65 66
  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
67 68 69 70 71 72 73 74 75 76
  void PartiallyRelease() {
    if (Has("model_program_path")) {
      if (Has("model_from_memory") && model_from_memory()) {
        model_program_path().clear();
        model_program_path().shrink_to_fit();
        model_params_path().clear();
        model_params_path().shrink_to_fit();
      }
    }
  }
77

78 79 80 81
#define DECL_ARGUMENT_FIELD(field__, Field, type__)                      \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
W
Wilber 已提交
82 83
        Has(#field__),                                                   \
        true,                                                            \
84 85 86 87 88 89 90 91 92 93 94
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(const type__& x) {                                     \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
95 96 97 98 99
  type__ field__##_;

#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() { return Has(#field__); }

W
Wilber 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)                  \
 public:                                                                    \
  type__& field__() {                                                       \
    PADDLE_ENFORCE_NOT_NULL(                                                \
        field__##_,                                                         \
        platform::errors::PreconditionNotMet("filed should not be null.")); \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return *static_cast<type__*>(field__##_.get());                         \
  }                                                                         \
  void Set##Field(type__* x) {                                              \
    field__##_ =                                                            \
        unique_ptr_t(x, [](void* x) { delete static_cast<type__*>(x); });   \
    valid_fields_.insert(#field__);                                         \
  }                                                                         \
  void Set##Field##NotOwned(type__* x) {                                    \
    valid_fields_.insert(#field__);                                         \
    field__##_ = unique_ptr_t(x, [](void* x) {});                           \
  }                                                                         \
  DECL_ARGUMENT_FIELD_VALID(field__);                                       \
  type__* field__##_ptr() {                                                 \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return static_cast<type__*>(field__##_.get());                          \
  }                                                                         \
  type__* Release##Field() {                                                \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    valid_fields_.erase(#field__);                                          \
    return static_cast<type__*>(field__##_.release());                      \
  }                                                                         \
                                                                            \
 private:                                                                   \
139 140
  unique_ptr_t field__##_;

141
  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
142 143 144 145 146
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
T
Tao Luo 已提交
147
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
148
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
149
  DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
150

151 152 153
  // For JITLayer
  DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);

154 155 156 157 158
  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

Y
Yan Chunwei 已提交
159
  // The default program, loaded from disk.
160 161 162
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
W
Wilber 已提交
163 164
  DECL_ARGUMENT_FIELD(ir_analysis_passes,
                      IrAnalysisPasses,
165
                      std::vector<std::string>);
W
Wilber 已提交
166 167
  DECL_ARGUMENT_FIELD(analysis_passes,
                      AnalysisPasses,
Y
Yan Chunwei 已提交
168
                      std::vector<std::string>);
169

170 171 172
  // whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

173
  // Pass a set of op types to enable its mkldnn kernel
W
Wilber 已提交
174 175
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
                      MKLDNNEnabledOpTypes,
176
                      std::unordered_set<std::string>);
177 178
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
179

180
#ifdef PADDLE_WITH_MKLDNN
181
  // A set of op types to enable their quantized kernels
W
Wilber 已提交
182 183
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
                      QuantizeEnabledOpTypes,
184 185 186
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
W
Wilber 已提交
187 188
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids,
                      QuantizeExcludedOpIds,
189 190
                      std::unordered_set<int>);

191 192
  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
193 194

  // A set of op types to enable their bfloat16 kernels
W
Wilber 已提交
195 196
  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types,
                      Bfloat16EnabledOpTypes,
197
                      std::unordered_set<std::string>);
B
baoachun 已提交
198 199

  DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
200
#endif
201

Y
Yan Chunwei 已提交
202
  // Passed from config.
203
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
204
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
S
superjomn 已提交
205
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
206

207 208 209 210
  // Usually use for trt dynamic shape.
  // TRT will select the best kernel according to opt shape
  // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
  // run fp16.
211 212 213
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
214
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
215

216
  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
217 218
  DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
  DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
219
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
220
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
221
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
W
Wilber 已提交
222 223
  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
                      TensorRtDisabledOPs,
224
                      std::vector<std::string>);
W
Wilber 已提交
225 226
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode,
                      TensorRtPrecisionMode,
227
                      AnalysisConfig::Precision);
W
Wilber 已提交
228 229
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
                      TensorRtUseStaticEngine,
N
nhzlx 已提交
230
                      bool);
231
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
232
  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
233
  DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
W
Wilber 已提交
234 235
  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
                      TensorRtTransformerPosid,
236
                      std::string);
W
Wilber 已提交
237 238
  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid,
                      TensorRtTransformerMaskid,
239
                      std::string);
240
  DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
W
Wilber 已提交
241 242 243 244
                      TensorRtShapeRangeInfoPath,
                      std::string);
  DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape,
                      TensorRtTunedDynamicShape,
245 246
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
W
Wilber 已提交
247 248
                      TensorRtAllowBuildAtRuntime,
                      bool);
249
  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
250

D
denglin-github 已提交
251 252 253
  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
D
denglin-github 已提交
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
  DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool);
  DECL_ARGUMENT_FIELD(dlnne_weight_share_mode,
                      DlnneWeightShareMode,
                      std::string);
  DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs,
                      DlnneDisableNodesByOutputs,
                      std::unordered_set<std::string>);
  DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
  DECL_ARGUMENT_FIELD(dlnne_precision_mode,
                      DlnnePrecisionMode,
                      AnalysisConfig::Precision);

  using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
  DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
                      DlnneInputShapeDict,
                      dlnne_input_shape_type);
D
denglin-github 已提交
270 271
  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);

W
Wilber 已提交
272 273
  DECL_ARGUMENT_FIELD(lite_passes_filter,
                      LitePassesFilter,
石晓伟 已提交
274 275
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
W
Wilber 已提交
276 277
  DECL_ARGUMENT_FIELD(lite_precision_mode,
                      LitePrecisionMode,
石晓伟 已提交
278
                      AnalysisConfig::Precision);
279 280 281 282
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
  DECL_ARGUMENT_FIELD(xpu_l3_workspace_size, XpuL3WorkspaceSize, int);
W
Wilber 已提交
283 284 285 286 287
  DECL_ARGUMENT_FIELD(xpu_locked, XpuLocked, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune, XpuAutotune, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune_file, XpuAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_precision, XpuPrecision, std::string);
  DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
288
  DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
289
  DECL_ARGUMENT_FIELD(xpu_enable_multi_stream, XpuEnableMultiStream, bool);
石晓伟 已提交
290

291 292
  DECL_ARGUMENT_FIELD(use_opencl, UseOpenCL, bool);

293
  DECL_ARGUMENT_FIELD(use_nnadapter, UseNNAdapter, bool);
W
Wilber 已提交
294 295
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_dir,
                      NNAdapterModelCacheDir,
296
                      std::string);
W
Wilber 已提交
297 298
  DECL_ARGUMENT_FIELD(nnadapter_device_names,
                      NNAdapterDeviceNames,
299
                      std::vector<std::string>);
W
Wilber 已提交
300 301
  DECL_ARGUMENT_FIELD(nnadapter_context_properties,
                      NNAdapterContextProperties,
302 303
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_buffer,
W
Wilber 已提交
304 305
                      NNAdapterSubgraphPartitionConfigBuffer,
                      std::string);
306
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_path,
W
Wilber 已提交
307 308 309 310
                      NNAdapterSubgraphPartitionConfigPath,
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_token,
                      NNAdapterModelCacheToken,
311
                      std::vector<std::string>);
W
Wilber 已提交
312 313
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_buffer,
                      NNAdapterModelCacheBuffer,
314 315
                      std::vector<std::vector<char>>);

Y
Yan Chunwei 已提交
316 317
  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
318
  DECL_ARGUMENT_FIELD(trt_engine_memory_sharing, TrtEngineMemorySharing, bool);
319

Y
Yan Chunwei 已提交
320 321 322 323
  // Indicate which kind of sort algorithm is used for operators, the memory
  // optimization relays on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

324
  // The program transformed by IR analysis phase.
W
Wilber 已提交
325 326
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program,
                             IrAnalyzedProgram,
327 328 329
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
330

W
Wilber 已提交
331
  // Only used in paddle-lite subgraph.
W
Wilber 已提交
332 333
  DECL_ARGUMENT_FIELD(cpu_math_library_num_threads,
                      CpuMathLibraryNumThreads,
W
Wilber 已提交
334 335
                      int);

J
jianghaicheng 已提交
336 337 338
  // ipu related
  DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool);
  DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int);
339
  DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int);
J
jianghaicheng 已提交
340 341
  DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool);
  DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int);
342 343 344
  DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool);
  DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int);
  DECL_ARGUMENT_FIELD(ipu_available_memory_proportion,
W
Wilber 已提交
345 346
                      IpuAvailableMemoryProportion,
                      float);
347
  DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool);
348 349 350 351 352 353
  DECL_ARGUMENT_FIELD(ipu_custom_ops_info,
                      IpuCustomOpsInfo,
                      std::vector<std::vector<std::string>>);
  DECL_ARGUMENT_FIELD(ipu_custom_patterns,
                      IpuCustomPatterns,
                      std::vector<std::vector<std::string>>);
354 355 356
  DECL_ARGUMENT_FIELD(ipu_enable_model_runtime_executor,
                      IpuEnableModelRuntimeExecutor,
                      bool);
J
jianghaicheng 已提交
357

358 359 360 361
  // npu related
  DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool);
  DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int);

362 363
  // mixed precision related
  DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
364 365 366
  DECL_ARGUMENT_FIELD(mixed_black_list,
                      MixedBlackList,
                      std::unordered_set<std::string>);
367

368
 private:
369
  std::unordered_set<std::string> valid_fields_;
Y
Yan Chunwei 已提交
370 371
};

372
// Checks that the field `fieldname__` has been set on the Argument that
// `argument__` points to. If it has not, the PADDLE_ENFORCE_EQ check fails
// with a PreconditionNotMet error that names the field.
// `argument__` is parenthesized so an expression such as `&argument` can be
// passed directly without it parsing as `&(argument->Has(...))`.
// The expansion still ends with ';' as it always has, so existing call sites
// keep compiling unchanged.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE_EQ(                                  \
      (argument__)->Has(#fieldname__),                \
      true,                                           \
      platform::errors::PreconditionNotMet(           \
          "the argument field [%s] should be set", #fieldname__));
Y
Yan Chunwei 已提交
378 379 380 381

}  // namespace analysis
}  // namespace inference
}  // namespace paddle