argument.h 17.2 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields that are needed by either Passes or
 * PassManagers are contained in Argument.
 *
 * TODO(Superjomn) Find a better way to contain the fields when they grow too
 * numerous.
 */

G
gongweibao 已提交
24 25
#pragma once

26
#include <cstdint>
#include <functional>
#include <map>
N
nhzlx 已提交
27
#include <memory>
G
gongweibao 已提交
28
#include <string>
N
nhzlx 已提交
29 30
#include <unordered_map>
#include <unordered_set>
31
#include <utility>
32
#include <vector>
N
nhzlx 已提交
33

34
#include "paddle/fluid/framework/ir/graph.h"
Y
Yan Chunwei 已提交
35
#include "paddle/fluid/framework/program_desc.h"
36
#include "paddle/fluid/framework/scope.h"
N
nhzlx 已提交
37
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
38

39
#include "paddle/phi/common/data_type.h"
Y
Yan Chunwei 已提交
40 41 42 43

namespace paddle {
namespace inference {
namespace analysis {
44

45
using framework::ir::Graph;
46 47

#ifdef PADDLE_WITH_MKLDNN
48 49
// Maps a string key to a (bool, LoDTensor) pair; used as the type of the
// quant_var_scales field. NOTE(review): the meaning of the bool flag is not
// visible in this file - confirm against the quantization passes.
using VarQuantScale =
    std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
50
#endif
Y
Yan Chunwei 已提交
51 52 53 54 55 56 57

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clarity.
 */
struct Argument {
Y
Yan Chunwei 已提交
58
  Argument() = default;
59 60 61 62
  // Convenience constructor: records `model_dir` via SetModelDir().
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  // Type-erased unique_ptr whose deleter is supplied at construction; backing
  // storage for fields declared with DECL_ARGUMENT_UNIQUE_FIELD.
  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  // Maps a string key to an int; type of the fusion_statis field.
  using fusion_statis_t = std::unordered_map<std::string, int>;
63
  // Maps a string key to a vector of ints; type of the min/max/optim
  // input-shape fields.
  using input_shape_t = std::map<std::string, std::vector<int>>;
64 65

  // Returns true when `key` names a field that is currently marked as set.
  bool Has(const std::string& key) const {
    return valid_fields_.find(key) != valid_fields_.end();
  }
66 67 68
  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
69 70 71 72 73 74 75 76 77 78
  void PartiallyRelease() {
    if (Has("model_program_path")) {
      if (Has("model_from_memory") && model_from_memory()) {
        model_program_path().clear();
        model_program_path().shrink_to_fit();
        model_params_path().clear();
        model_params_path().shrink_to_fit();
      }
    }
  }
79

80 81 82 83
// Declares a plain value field on Argument. For a field `foo` with setter
// suffix `Foo` and type `T`, the expansion generates:
//   T& foo()              - accessor; enforces that the field has been set.
//   void SetFoo(const T&) - stores a copy and marks the field as set.
//   bool foo_valid()      - generated by DECL_ARGUMENT_FIELD_VALID.
//   T* foo_ptr()          - address of the storage; performs no "set" check.
// The private backing member `foo_` is value-initialized so that foo_ptr()
// never exposes an indeterminate value before the first SetFoo() call.
// The expansion leaves the access level at `private`.
#define DECL_ARGUMENT_FIELD(field__, Field, type__)                      \
 public:                                                                 \
  type__& field__() {                                                    \
    PADDLE_ENFORCE_EQ(                                                   \
        Has(#field__),                                                   \
        true,                                                            \
        platform::errors::PreconditionNotMet("There is no such field")); \
    return field__##_;                                                   \
  }                                                                      \
  void Set##Field(const type__& x) {                                     \
    field__##_ = x;                                                      \
    valid_fields_.insert(#field__);                                      \
  }                                                                      \
  DECL_ARGUMENT_FIELD_VALID(field__);                                    \
  type__* field__##_ptr() { return &field__##_; }                        \
                                                                         \
 private:                                                                \
  type__ field__##_{};

// Generates `bool <field>_valid() const`, which reports whether the field's
// name is currently recorded as set, i.e. whether Has("<field>") is true.
// The method is const so it can be called through a const Argument.
#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() const { return Has(#field__); }

W
Wilber 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
// Declares a pointer-backed field on Argument, stored type-erased in a
// `unique_ptr_t`. For a field `foo` with setter suffix `Foo` and type `T`,
// the expansion generates:
//   T& foo()                - dereferences the pointer; enforces that it is
//                             non-null and that the field is marked as set.
//   void SetFoo(T*)         - stores the pointer with a deleter that deletes
//                             it as a T, and marks the field as set.
//   void SetFooNotOwned(T*) - stores the pointer with a no-op deleter, and
//                             marks the field as set.
//   bool foo_valid()        - generated by DECL_ARGUMENT_FIELD_VALID.
//   T* foo_ptr()            - raw pointer; enforces that the field is set.
//   T* ReleaseFoo()         - unmarks the field and returns the raw pointer
//                             without running the deleter.
// The expansion leaves the access level at `private`.
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)                  \
 public:                                                                    \
  type__& field__() {                                                       \
    PADDLE_ENFORCE_NOT_NULL(                                                \
        field__##_,                                                         \
        platform::errors::PreconditionNotMet("field should not be null.")); \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return *static_cast<type__*>(field__##_.get());                         \
  }                                                                         \
  void Set##Field(type__* x) {                                              \
    field__##_ = unique_ptr_t(                                              \
        x, [](void* owned) { delete static_cast<type__*>(owned); });        \
    valid_fields_.insert(#field__);                                         \
  }                                                                         \
  void Set##Field##NotOwned(type__* x) {                                    \
    valid_fields_.insert(#field__);                                         \
    field__##_ = unique_ptr_t(x, [](void*) {});                             \
  }                                                                         \
  DECL_ARGUMENT_FIELD_VALID(field__);                                       \
  type__* field__##_ptr() {                                                 \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    return static_cast<type__*>(field__##_.get());                          \
  }                                                                         \
  type__* Release##Field() {                                                \
    PADDLE_ENFORCE_EQ(                                                      \
        Has(#field__),                                                      \
        true,                                                               \
        platform::errors::PreconditionNotMet("There is no such field"));    \
    valid_fields_.erase(#field__);                                          \
    return static_cast<type__*>(field__##_.release());                      \
  }                                                                         \
                                                                            \
 private:                                                                   \
  unique_ptr_t field__##_;

143
  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
144 145 146 147 148
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
T
Tao Luo 已提交
149
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
150
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
151
  DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
152 153 154 155 156 157

  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

Y
Yan Chunwei 已提交
158
  // The default program, loaded from disk.
159 160 161
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
W
Wilber 已提交
162 163
  DECL_ARGUMENT_FIELD(ir_analysis_passes,
                      IrAnalysisPasses,
164
                      std::vector<std::string>);
W
Wilber 已提交
165 166
  DECL_ARGUMENT_FIELD(analysis_passes,
                      AnalysisPasses,
Y
Yan Chunwei 已提交
167
                      std::vector<std::string>);
168

169 170 171
  // whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

172
  // Pass a set of op types to enable their mkldnn kernels
W
Wilber 已提交
173 174
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types,
                      MKLDNNEnabledOpTypes,
175
                      std::unordered_set<std::string>);
176 177
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);
178

179
#ifdef PADDLE_WITH_MKLDNN
180
  // A set of op types to enable their quantized kernels
W
Wilber 已提交
181 182
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types,
                      QuantizeEnabledOpTypes,
183 184 185
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
W
Wilber 已提交
186 187
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids,
                      QuantizeExcludedOpIds,
188 189
                      std::unordered_set<int>);

190 191
  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
192 193

  // A set of op types to enable their bfloat16 kernels
W
Wilber 已提交
194 195
  DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types,
                      Bfloat16EnabledOpTypes,
196
                      std::unordered_set<std::string>);
B
baoachun 已提交
197 198

  DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
199
#endif
200

Y
Yan Chunwei 已提交
201
  // Passed from config.
202
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
203
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
S
superjomn 已提交
204
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
205

206 207 208 209
  // Usually used for TRT dynamic shape.
  // TRT will select the best kernel according to the opt shape.
  // Setting disable_trt_plugin_fp16 to true means that TRT plugins will not
  // run in fp16.
210 211 212
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
213
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
214

215
  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
216 217
  DECL_ARGUMENT_FIELD(tensorrt_use_dla, TensorRtUseDLA, bool);
  DECL_ARGUMENT_FIELD(tensorrt_dla_core, TensorRtDLACore, int);
218
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
219
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int64_t);
220
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
W
Wilber 已提交
221 222
  DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
                      TensorRtDisabledOPs,
223
                      std::vector<std::string>);
W
Wilber 已提交
224 225
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode,
                      TensorRtPrecisionMode,
226
                      AnalysisConfig::Precision);
W
Wilber 已提交
227 228
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
                      TensorRtUseStaticEngine,
N
nhzlx 已提交
229
                      bool);
230
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
231
  DECL_ARGUMENT_FIELD(tensorrt_use_varseqlen, TensorRtUseOSS, bool);
232
  DECL_ARGUMENT_FIELD(tensorrt_with_interleaved, TensorRtWithInterleaved, bool);
W
Wilber 已提交
233 234
  DECL_ARGUMENT_FIELD(tensorrt_transformer_posid,
                      TensorRtTransformerPosid,
235
                      std::string);
W
Wilber 已提交
236 237
  DECL_ARGUMENT_FIELD(tensorrt_transformer_maskid,
                      TensorRtTransformerMaskid,
238
                      std::string);
239
  DECL_ARGUMENT_FIELD(tensorrt_shape_range_info_path,
W
Wilber 已提交
240 241 242 243
                      TensorRtShapeRangeInfoPath,
                      std::string);
  DECL_ARGUMENT_FIELD(tensorrt_tuned_dynamic_shape,
                      TensorRtTunedDynamicShape,
244 245
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_allow_build_at_runtime,
W
Wilber 已提交
246 247
                      TensorRtAllowBuildAtRuntime,
                      bool);
248
  DECL_ARGUMENT_FIELD(tensorrt_use_inspector, TensorRtUseInspector, bool);
249

D
denglin-github 已提交
250 251 252
  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
D
denglin-github 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
  DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool);
  DECL_ARGUMENT_FIELD(dlnne_weight_share_mode,
                      DlnneWeightShareMode,
                      std::string);
  DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs,
                      DlnneDisableNodesByOutputs,
                      std::unordered_set<std::string>);
  DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
  DECL_ARGUMENT_FIELD(dlnne_precision_mode,
                      DlnnePrecisionMode,
                      AnalysisConfig::Precision);

  using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
  DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
                      DlnneInputShapeDict,
                      dlnne_input_shape_type);
D
denglin-github 已提交
269 270
  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);

W
Wilber 已提交
271 272
  DECL_ARGUMENT_FIELD(lite_passes_filter,
                      LitePassesFilter,
石晓伟 已提交
273 274
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
W
Wilber 已提交
275 276
  DECL_ARGUMENT_FIELD(lite_precision_mode,
                      LitePrecisionMode,
石晓伟 已提交
277
                      AnalysisConfig::Precision);
278 279 280 281
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
  DECL_ARGUMENT_FIELD(xpu_l3_workspace_size, XpuL3WorkspaceSize, int);
W
Wilber 已提交
282 283 284 285 286
  DECL_ARGUMENT_FIELD(xpu_locked, XpuLocked, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune, XpuAutotune, bool);
  DECL_ARGUMENT_FIELD(xpu_autotune_file, XpuAutotuneFile, std::string);
  DECL_ARGUMENT_FIELD(xpu_precision, XpuPrecision, std::string);
  DECL_ARGUMENT_FIELD(xpu_adaptive_seqlen, XpuAdaptiveSeqlen, bool);
287
  DECL_ARGUMENT_FIELD(xpu_device_id, XpuDeviceId, int);
石晓伟 已提交
288

289
  DECL_ARGUMENT_FIELD(use_nnadapter, UseNNAdapter, bool);
W
Wilber 已提交
290 291
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_dir,
                      NNAdapterModelCacheDir,
292
                      std::string);
W
Wilber 已提交
293 294
  DECL_ARGUMENT_FIELD(nnadapter_device_names,
                      NNAdapterDeviceNames,
295
                      std::vector<std::string>);
W
Wilber 已提交
296 297
  DECL_ARGUMENT_FIELD(nnadapter_context_properties,
                      NNAdapterContextProperties,
298 299
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_buffer,
W
Wilber 已提交
300 301
                      NNAdapterSubgraphPartitionConfigBuffer,
                      std::string);
302
  DECL_ARGUMENT_FIELD(nnadapter_subgraph_partition_config_path,
W
Wilber 已提交
303 304 305 306
                      NNAdapterSubgraphPartitionConfigPath,
                      std::string);
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_token,
                      NNAdapterModelCacheToken,
307
                      std::vector<std::string>);
W
Wilber 已提交
308 309
  DECL_ARGUMENT_FIELD(nnadapter_model_cache_buffer,
                      NNAdapterModelCacheBuffer,
310 311
                      std::vector<std::vector<char>>);

Y
Yan Chunwei 已提交
312 313
  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
314
  DECL_ARGUMENT_FIELD(trt_engine_memory_sharing, TrtEngineMemorySharing, bool);
315

Y
Yan Chunwei 已提交
316 317 318 319
  // Indicates which kind of sort algorithm is used for operators; the memory
  // optimization relies on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

320
  // The program transformed by IR analysis phase.
W
Wilber 已提交
321 322
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program,
                             IrAnalyzedProgram,
323 324 325
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
326

W
Wilber 已提交
327
  // Only used in paddle-lite subgraph.
W
Wilber 已提交
328 329
  DECL_ARGUMENT_FIELD(cpu_math_library_num_threads,
                      CpuMathLibraryNumThreads,
W
Wilber 已提交
330 331
                      int);

J
jianghaicheng 已提交
332 333 334
  // ipu related
  DECL_ARGUMENT_FIELD(use_ipu, UseIpu, bool);
  DECL_ARGUMENT_FIELD(ipu_device_num, IpuDeviceNum, int);
335
  DECL_ARGUMENT_FIELD(ipu_micro_batch_size, IpuMicroBatchSize, int);
J
jianghaicheng 已提交
336 337
  DECL_ARGUMENT_FIELD(ipu_enable_pipelining, IpuEnablePipelining, bool);
  DECL_ARGUMENT_FIELD(ipu_batches_per_step, IpuBatchesPerStep, int);
338 339 340
  DECL_ARGUMENT_FIELD(ipu_enable_fp16, IpuEnableFp16, bool);
  DECL_ARGUMENT_FIELD(ipu_replica_num, IpuReplicaNum, int);
  DECL_ARGUMENT_FIELD(ipu_available_memory_proportion,
W
Wilber 已提交
341 342
                      IpuAvailableMemoryProportion,
                      float);
343
  DECL_ARGUMENT_FIELD(ipu_enable_half_partial, IpuEnableHalfPartial, bool);
344 345 346 347 348 349
  DECL_ARGUMENT_FIELD(ipu_custom_ops_info,
                      IpuCustomOpsInfo,
                      std::vector<std::vector<std::string>>);
  DECL_ARGUMENT_FIELD(ipu_custom_patterns,
                      IpuCustomPatterns,
                      std::vector<std::vector<std::string>>);
J
jianghaicheng 已提交
350

351 352 353 354
  // npu related
  DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool);
  DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int);

355 356
  // mixed precision related
  DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
357 358 359
  DECL_ARGUMENT_FIELD(mixed_black_list,
                      MixedBlackList,
                      std::unordered_set<std::string>);
360

361
 private:
362
  // Names of the fields currently marked as set: the generated setters insert
  // into it, the generated Release*() methods erase from it, and Has()
  // queries it.
  std::unordered_set<std::string> valid_fields_;
Y
Yan Chunwei 已提交
363 364
};

365
// Checks, via PADDLE_ENFORCE_EQ, that `fieldname__` is marked as set on the
// Argument that `argument__` points to; the failure carries a
// PreconditionNotMet error naming the field.
// `argument__` is parenthesized so that pointer expressions such as `&arg`
// expand correctly. The expansion already ends with a semicolon.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE_EQ(                                  \
      (argument__)->Has(#fieldname__),                \
      true,                                           \
      platform::errors::PreconditionNotMet(           \
          "the argument field [%s] should be set", #fieldname__));
Y
Yan Chunwei 已提交
371 372 373 374

}  // namespace analysis
}  // namespace inference
}  // namespace paddle