argument.h 10.2 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * This file defines the class Argument, which is the input and output of the
 * analysis module. All the fields that are needed by either Passes or
 * PassManagers are contained in Argument.
 *
 * TODO(Superjomn) Find a better way to contain the fields when this grows too
 * big.
 */

G
gongweibao 已提交
24 25
#pragma once

26
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
N
nhzlx 已提交
33

34
#include "paddle/fluid/framework/ir/graph.h"
Y
Yan Chunwei 已提交
35
#include "paddle/fluid/framework/program_desc.h"
36
#include "paddle/fluid/framework/scope.h"
N
nhzlx 已提交
37
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
38
#include "paddle/fluid/platform/variant.h"
Y
Yan Chunwei 已提交
39 40 41 42

namespace paddle {
namespace inference {
namespace analysis {
43

44
using framework::ir::Graph;
45 46

#ifdef PADDLE_WITH_MKLDNN
47 48
// Maps a variable name to a (bool, LoDTensor) pair describing its quantization
// scale. Aliased so the type can be passed as a single macro argument despite
// the commas in its template argument list.
// NOTE(review): the bool presumably flags a property of the scale (e.g.
// signedness) -- confirm against the pass that fills this map.
using VarQuantScale =
    std::unordered_map<std::string, std::pair<bool, framework::LoDTensor>>;
49
#endif
Y
Yan Chunwei 已提交
50 51 52 53 54 55 56

/*
 * The argument definition of both Pass and PassManagers.
 *
 * All the fields should be registered here for clarity.
 */
struct Argument {
  Argument() = default;
  // Convenience constructor: equivalent to default-constructing and then
  // calling SetModelDir(model_dir).
  explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }

  // Type-erased pointer whose deleter is chosen when the field is set, so the
  // same storage can hold either an owned or a borrowed object.
  using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
  // Aliases for template types containing commas, so they can be passed as a
  // single argument to the DECL_ARGUMENT_FIELD macro.
  using fusion_statis_t = std::unordered_map<std::string, int>;
  using input_shape_t = std::map<std::string, std::vector<int>>;

  // Returns true when the field called `key` (the accessor name, e.g.
  // "model_dir") has been set and not released since.
  bool Has(const std::string& key) const {
    return valid_fields_.count(key) != 0;
  }

  // If we set the model using config.SetModelBuffer,
  // the model and parameter will occupy additional CPU resources.
  // Use this interface to release these resources.
  //
  // Nothing happens unless model_program_path is set and model_from_memory is
  // set to true. The params string is only touched when it has been set, so
  // this best-effort cleanup never trips the "no such field" enforce.
  void PartiallyRelease() {
    if (!Has("model_program_path")) return;
    if (!Has("model_from_memory") || !model_from_memory()) return;

    // clear() alone keeps the capacity; shrink_to_fit() asks the string to
    // give the buffer back.
    model_program_path().clear();
    model_program_path().shrink_to_fit();
    if (Has("model_params_path")) {
      model_params_path().clear();
      model_params_path().shrink_to_fit();
    }
  }

// Declares `bool <field>_valid() const`, reporting whether the field was set.
#define DECL_ARGUMENT_FIELD_VALID(field__) \
  bool field__##_valid() const { return Has(#field__); }

// Declares a by-value field of type `type__` together with:
//   type__& <field>()         - accessor; enforces that the field was set.
//   void Set<Field>(const&)   - copies the value in and marks the field set.
//   bool <field>_valid()      - whether the field was set.
//   type__* <field>_ptr()     - address of the storage, with no validity check.
// The storage is value-initialized so that reading through <field>_ptr() (or
// moving the whole Argument) before the setter ran never touches an
// indeterminate value.
#define DECL_ARGUMENT_FIELD(field__, Field, type__)          \
 public:                                                     \
  type__& field__() {                                        \
    PADDLE_ENFORCE(Has(#field__), "There is no such field"); \
    return field__##_;                                       \
  }                                                          \
  void Set##Field(const type__& x) {                         \
    field__##_ = x;                                          \
    valid_fields_.insert(#field__);                          \
  }                                                          \
  DECL_ARGUMENT_FIELD_VALID(field__)                         \
  type__* field__##_ptr() { return &field__##_; }            \
                                                             \
 private:                                                    \
  type__ field__##_{};

// Declares a pointer-held field of type `type__` together with:
//   type__& <field>()              - accessor; enforces non-null and set.
//   void Set<Field>(type__*)       - stores the pointer with a deleting
//                                    deleter and marks the field set.
//   void Set<Field>NotOwned(type__*) - stores the pointer with a no-op deleter.
//   bool <field>_valid()           - whether the field was set.
//   type__* <field>_ptr()          - raw pointer; enforces that the field is
//                                    set.
//   type__* Release<Field>()       - marks the field unset and returns the raw
//                                    pointer without running the deleter.
#define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__)         \
 public:                                                           \
  type__& field__() {                                              \
    PADDLE_ENFORCE_NOT_NULL(field__##_);                           \
    PADDLE_ENFORCE(Has(#field__));                                 \
    return *static_cast<type__*>(field__##_.get());                \
  }                                                                \
  void Set##Field(type__* x) {                                     \
    field__##_ = unique_ptr_t(                                     \
        x, [](void* ptr) { delete static_cast<type__*>(ptr); });   \
    valid_fields_.insert(#field__);                                \
  }                                                                \
  void Set##Field##NotOwned(type__* x) {                           \
    valid_fields_.insert(#field__);                                \
    field__##_ = unique_ptr_t(x, [](void*) {});                    \
  }                                                                \
  DECL_ARGUMENT_FIELD_VALID(field__)                               \
  type__* field__##_ptr() {                                        \
    PADDLE_ENFORCE(Has(#field__));                                 \
    return static_cast<type__*>(field__##_.get());                 \
  }                                                                \
  type__* Release##Field() {                                       \
    PADDLE_ENFORCE(Has(#field__));                                 \
    valid_fields_.erase(#field__);                                 \
    return static_cast<type__*>(field__##_.release());             \
  }                                                                \
                                                                   \
 private:                                                          \
  unique_ptr_t field__##_;

  DECL_ARGUMENT_FIELD(predictor_id, PredictorID, int);
  // Model path
  DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
  // Model specified with program and parameters files.
  DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
  DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
  DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
  DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
  DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);

  // The overall graph to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
  // The overall Scope to work on.
  DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);

  // The default program, loaded from disk.
  DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);

  // The ir passes to perform in analysis phase.
  DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(analysis_passes, AnalysisPasses,
                      std::vector<std::string>);

  // Whether to mute all logs in inference.
  DECL_ARGUMENT_FIELD(disable_logs, DisableLogs, bool);

  // Pass a set of op types to enable their mkldnn kernels.
  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
                      std::unordered_set<std::string>);
  // The cache capacity of different input shapes for mkldnn.
  DECL_ARGUMENT_FIELD(mkldnn_cache_capacity, MkldnnCacheCapacity, int);

#ifdef PADDLE_WITH_MKLDNN
  // A set of op types to enable their quantized kernels
  DECL_ARGUMENT_FIELD(quantize_enabled_op_types, QuantizeEnabledOpTypes,
                      std::unordered_set<std::string>);

  // A set of op IDs to exclude from enabling their quantized kernels
  DECL_ARGUMENT_FIELD(quantize_excluded_op_ids, QuantizeExcludedOpIds,
                      std::unordered_set<int>);

  // Scales for variables to be quantized
  DECL_ARGUMENT_FIELD(quant_var_scales, QuantVarScales, VarQuantScale);
#endif

  // Passed from config.
  DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
  DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
  DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);

  // Usually used for trt dynamic shape.
  // TRT will select the best kernel according to opt shape.
  // Setting disable_trt_plugin_fp16 to true means that TRT plugin will not
  // run fp16. Note that its setter is named SetCloseTrtPluginFp16.
  DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
  DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);

  DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
  DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
  DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
  DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
  DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode,
                      AnalysisConfig::Precision);
  DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine,
                      bool);
  DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
  DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool);

  DECL_ARGUMENT_FIELD(lite_passes_filter, LitePassesFilter,
                      std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
  DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode,
                      AnalysisConfig::Precision);
  DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

  DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
  DECL_ARGUMENT_FIELD(xpu_l3_workspace_size, XpuL3WorkspaceSize, int);

  // Memory optimized related.
  DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);

  // Indicate which kind of sort algorithm is used for operators, the memory
  // optimization relies on the sort algorithm.
  DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);

  // The program transformed by IR analysis phase.
  DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
                             framework::proto::ProgramDesc);

  DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);

 private:
  // Names of the fields that are currently set; consulted by Has().
  std::unordered_set<std::string> valid_fields_;
};

226 227 228
// Enforces that the field `fieldname__` has been set on `argument__`, which
// must be a pointer-like expression whose target provides Has(std::string).
// The argument is parenthesized so that expressions such as `&arg` bind to
// `->` as a whole. The trailing semicolon is part of the macro (kept for
// existing call sites), so avoid using it as the unbraced body of an
// if/else.
#define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
  PADDLE_ENFORCE((argument__)->Has(#fieldname__),     \
                 "the argument field [%s] should be set", #fieldname__);
Y
Yan Chunwei 已提交
229 230 231 232

}  // namespace analysis
}  // namespace inference
}  // namespace paddle