tensorrt_subgraph_pass.cc 16.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"

17
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
18
#include "paddle/fluid/framework/ir/subgraph_detector.h"
19
#include "paddle/fluid/framework/op_version_registry.h"
20
#include "paddle/fluid/inference/analysis/helper.h"
N
nhzlx 已提交
21 22
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
23
#include "paddle/fluid/inference/tensorrt/op_teller.h"
24 25 26 27 28 29 30

namespace paddle {
namespace inference {
namespace analysis {

using framework::ir::Node;

31 32 33
void analysis::TensorRtSubgraphPass::ApplyImpl(
    framework::ir::Graph *graph) const {
  framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
34 35 36
  auto enable_int8 = Get<bool>("enable_int8");
  auto use_calib_mode = Get<bool>("use_calib_mode");
  bool no_calib_int8 = enable_int8 && !(use_calib_mode);
37
  auto trt_disabled_ops = Get<std::vector<std::string>>("trt_disabled_ops");
38
  auto teller = [&](const framework::ir::Node *node) {
39
    if (!node->IsOp() || !node->Op()) return false;
40 41 42 43 44 45
    if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(),
             node->Op()->Type()) != trt_disabled_ops.end()) {
      VLOG(3) << node->Op()->Type().c_str()
              << " is diabled by config in TensorRT";
      return false;
    }
46 47
    return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
                                             no_calib_int8);
48
  };
49

50 51 52
  framework::ir::SubGraphFuser fuser(
      graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/,
      "tensorrt_engine");
53 54
  fuser();

55 56 57 58 59 60
  std::vector<std::string> graph_param_names =
      ExtractParameters(graph->Nodes());
  // those parameter already exist in trt, and should not have another copy in
  // fluid.
  std::vector<std::string> repetitive_params;

61
  for (auto *node : graph->Nodes()) {
62
    if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
63
      CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
64 65

      std::unordered_set<const Node *> nodes2remove(
66 67
          framework::ir::Agent(node).subgraph()->begin(),
          framework::ir::Agent(node).subgraph()->end());
68
      framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
69 70 71 72 73
    }
  }

  std::unordered_set<const Node *> nodes2remove;
  for (auto *node : graph->Nodes()) {
74
    if (node->IsOp() && framework::ir::Agent(node).deleted()) {
75 76 77
      nodes2remove.insert(node);
    }
  }
78
  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
79 80
  graph->Set(framework::ir::kRepetitiveParamAttr,
             new std::vector<std::string>(repetitive_params));
81 82
}

// Builds a deterministic key identifying a TensorRT engine by hashing the
// concatenation of all input names, all output names and the predictor id.
// The inputs/outputs arrive as std::set (not unordered_set), so iteration
// order — and therefore the key — is stable across runs.
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
                              const std::set<std::string> &engine_outputs,
                              const std::string &predictor_id) {
  std::string engine_hash_key;
  // Iterate by const reference: the original copied every name string.
  for (const auto &name : engine_inputs) {
    engine_hash_key += name;
  }
  for (const auto &name : engine_outputs) {
    engine_hash_key += name;
  }
  engine_hash_key += predictor_id;
  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
  return engine_key;
}

// Converts the fused subgraph attached to `node` into a single
// `tensorrt_engine` op (rewriting node->Op() in place) and creates — or
// deserializes from cache — the TensorRT engine that will execute it.
//
// Parameters:
//   node              - the fused placeholder node; its Agent holds the
//                       original subgraph nodes.
//   graph             - the whole graph, used to map var names to var nodes.
//   graph_params      - names of all parameter variables in the graph.
//   repetitive_params - out: parameter names owned exclusively by this
//                       subgraph (their fluid copies may be removed later).
void TensorRtSubgraphPass::CreateTensorRTOp(
    framework::ir::Node *node, framework::ir::Graph *graph,
    const std::vector<std::string> &graph_params,
    std::vector<std::string> *repetitive_params) const {
  auto *op_desc = node->Op();
  auto &subgraph = *framework::ir::Agent(node).subgraph();
  PADDLE_ENFORCE_EQ(subgraph.empty(), false,
                    platform::errors::PreconditionNotMet(
                        "The subgraph should not be empty."));

  framework::ProgramDesc *program_desc =
      Get<framework::ProgramDesc *>("program");
  // Add new block for TensorRTEngineOP
  const framework::BlockDesc &main_block =
      program_desc->Block(framework::kRootBlockIndex);
  // const framework::BlockDesc& main_block = program_desc->Block(0);
  framework::BlockDesc *new_block = program_desc->AppendBlock(main_block);

  // A fake block desc.  It is not attached to the program; its serialized
  // proto is stored in the op's "subgraph" attribute below.
  framework::proto::BlockDesc block_proto;
  framework::BlockDesc block_desc(nullptr, &block_proto);
  block_desc.Proto()->set_parent_idx(-1);
  block_desc.Proto()->set_idx(0);
  LOG(INFO) << "---  detect a sub-graph with " << subgraph.size() << " nodes";

  // Copy every subgraph op into both the real sub-block and the fake block.
  for (auto *node : subgraph) {
    auto *new_block_op = new_block->AppendOp();
    auto *op = block_desc.AppendOp();
    *new_block_op->Proto() = *node->Op()->Proto();
    *op->Proto() = *node->Op()->Proto();
  }

  // Then, we will use the input_names_with_id and output_names_with_id to
  // generate the engine key.
  // So, We use set instead of unordered_set here to ensure that the engine key
  // is unique.
  std::set<std::string> input_names;
  std::set<std::string> input_names_with_id;
  std::vector<std::string> params;
  // if we delete fluid copy of params shared by more than 1 ops, there will be
  // problem, so we filter them out.
  std::vector<std::string> params_not_shared;

  // The node->inputs contains input tensors and parameters.
  for (auto *x : node->inputs) {
    input_names.insert(x->Name());
    input_names_with_id.insert(x->Name() + std::to_string(x->id()));
    if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
      params.push_back(x->Name());
    }
    // A parameter consumed by at most one op is safe to drop from fluid.
    if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0 &&
        x->outputs.size() <= 1) {
      params_not_shared.push_back(x->Name());
    }
  }

  std::set<std::string> output_names;
  std::set<std::string> output_names_with_id;
  // Rank (number of dims) of each original output, used at runtime to
  // restore output shapes.
  std::vector<int> origin_output_dims;
  for (auto *x : node->outputs) {
    output_names.insert(x->Name());
    output_names_with_id.insert(x->Name() + std::to_string(x->id()));
    origin_output_dims.push_back(x->Var()->GetShape().size());
  }

  // Maps original output names to their renamed counterparts (filled by
  // RenameAndGetOutputs below).
  std::unordered_map<std::string, std::string> output_name_map;
  std::unordered_map<std::string, framework::ir::Node *> graph_var_map;

  for (framework::ir::Node *node : graph->Nodes()) {
    if (node->IsVar() && node->Var()) {
      graph_var_map[node->Name()] = node;
    }
  }
  auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
  bool enable_fp16 = false;
  if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true;
  auto enable_int8 = Get<bool>("enable_int8");
  auto use_calib_mode = Get<bool>("use_calib_mode");
  auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
  // Dynamic-shape configuration (empty maps mean static shape mode).
  auto min_input_shape =
      Get<std::map<std::string, std::vector<int>>>("min_input_shape");
  auto max_input_shape =
      Get<std::map<std::string, std::vector<int>>>("max_input_shape");
  auto opt_input_shape =
      Get<std::map<std::string, std::vector<int>>>("optim_input_shape");

  // The following procedure is used to rename all the intermediate
  // variables and the output variables of the subgraph.
  // Why we do this?
  // During the transition from fluid OP to tensorrt OP, we map
  // the input and output Tensor(fluid data structure) of fluid OP
  // to the corresponding ITensor (trt data structure) through the
  // Tensor name. When we set up ITensor for an variable, we must
  // ensure that it has not been set before.
  // If there is variable in the fluid graph, which is not only the
  // input of a OP, but also the output of a Op, there will be problems.
  // So we have to rename the variable in the subgraph to make sure
  // it is either an OP's input or an OP's output.
  RenameAndGetOutputs(subgraph_nodes, &block_desc, input_names_with_id,
                      &output_names_with_id, &output_names, &output_name_map,
                      graph_var_map, !enable_int8);

  // When tensorrt engine runs at the end of the operation,
  // output_mapping help us copy the data from the renamed ITensor
  // to Tensor.
  std::vector<std::string> output_mapping;
  for (auto name : output_names) {
    PADDLE_ENFORCE_NE(output_name_map.count(name), 0,
                      platform::errors::PreconditionNotMet(
                          "The output_name_map should have %s", name));
    output_mapping.push_back(output_name_map[name]);
  }
  PADDLE_ENFORCE_EQ(output_mapping.empty(), false,
                    platform::errors::PreconditionNotMet(
                        "The output_mapping should not be empty."));
  PADDLE_ENFORCE_EQ(
      !block_desc.Proto()->vars().empty(), true,
      platform::errors::PreconditionNotMet("the block has no var-desc"));

  // Set attrs
  op_desc->SetType("tensorrt_engine");
  op_desc->SetInput(
      "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));

  op_desc->SetOutput(
      "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));

  op_desc->SetBlockAttr("sub_block", new_block);
  op_desc->SetAttr("subgraph", block_desc.Proto()->SerializeAsString());
  op_desc->SetAttr("max_batch_size", Get<int>("max_batch_size"));
  op_desc->SetAttr("workspace_size", Get<int>("workspace_size"));
  op_desc->SetAttr("gpu_id", Get<int>("gpu_device_id"));
  op_desc->SetAttr("output_name_mapping", output_mapping);
  op_desc->SetAttr("origin_output_dims", origin_output_dims);
  op_desc->SetAttr("parameters", params);

  // we record all inputs' shapes in attr to check if they are consistent
  // with the real inputs' shapes retrieved from scope when trt runs.
  for (auto *x : node->inputs) {
    if (x->IsVar() && x->Var()) {
      framework::VarDesc *var = x->Var();
      op_desc->SetAttr(var->Name() + "_shape", var->GetShape());
    }
  }

  auto use_static_engine = Get<bool>("use_static_engine");
  // TODO(NHZlX)
  // There are models with the same structure but the different parameters,
  // when running in the 'use_serialize' mode, there is a bug.
  // NOTE: the key mixes pre-rename input ids with post-rename output ids,
  // since RenameAndGetOutputs (above) already updated output_names_with_id.
  auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
                                      std::to_string(0));
  auto predictor_id = Get<int>("predictor_id");

  // Get "" when there is no cached calibration table data.
  bool load_from_memory = Get<bool>("model_from_memory");
  std::string calibration_data = "";
  if (enable_int8 && use_calib_mode) {
    calibration_data = GetTrtCalibTableData(
        Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
  }
  op_desc->SetAttr("calibration_data", calibration_data);
  op_desc->SetAttr("enable_int8", enable_int8);
  op_desc->SetAttr("enable_fp16", enable_fp16);
  op_desc->SetAttr("use_calib_mode", use_calib_mode);
  op_desc->SetAttr("engine_key", engine_key);
  op_desc->SetAttr("predictor_id", predictor_id);

  std::string trt_engine_serialized_data = "";
  op_desc->SetAttr("engine_serialized_data", trt_engine_serialized_data);
  op_desc->Flush();

  std::unique_ptr<tensorrt::TRTInt8Calibrator> calibrator;
  if (enable_int8 && calibration_data.size() != 0) {
    calibrator.reset(new tensorrt::TRTInt8Calibrator(calibration_data));
    LOG(INFO) << "RUN Paddle TRT int8 calibration mode...";
  }
  // When in int8 mode and calibration_mode, the program just produce the
  // calibration table data.
  bool calibration_mode =
      (enable_int8 && calibration_data.size() == 0 && use_calib_mode);
  if (calibration_mode) {
    // calibration mode means generate int8 calibration table data process.
    return;
  }

  std::copy(params_not_shared.begin(), params_not_shared.end(),
            std::back_inserter(*repetitive_params));

  // Check trt version for dynamic shape input.

  if (min_input_shape.size() > 0 && TRT_VERSION < 6000) {
    LOG_FIRST_N(WARNING, 1) << "You are using the dynamic size input mode of "
                               "Paddle-TRT, but we found that the version of "
                               "the TensorRT is less than 6.0, so we use the "
                               "static shape mode instead.";
    min_input_shape = {};
    max_input_shape = {};
    opt_input_shape = {};
  }

  if (min_input_shape.size() > 0 && TRT_VERSION > 6000) {
    LOG_FIRST_N(WARNING, 1)
        << "The Paddle lib links the " << TRT_VERSION << " version TensorRT, "
        << "make sure the runtime TensorRT you are using is no less than this "
           "version, otherwise, there might be Segfault!";
  }

  // Setting the disable_trt_plugin_fp16 to true means that TRT plugin will not
  // run fp16.
  // When running fp16, the output accuracy of the model will be affected,
  // closing the plugin fp16 may bring some improvement on accuracy.
  bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
  tensorrt::TensorRTEngine *trt_engine =
      inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
          .Create(engine_key + std::to_string(predictor_id),
                  Get<int>("max_batch_size"), Get<int>("workspace_size"),
                  precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
                  min_input_shape, max_input_shape, opt_input_shape,
                  disable_trt_plugin_fp16);
  trt_engine->SetUseOSS(Get<bool>("use_oss"));
  trt_engine->SetUseDLA(Get<bool>("trt_use_dla"));
  trt_engine->SetDLACore(Get<int>("trt_dla_core"));

  trt_engine->SetWithErnie(
      graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
      graph->Has(framework::ir::kMultiheadMatmulPass));

  bool need_serialize = (use_static_engine && !load_from_memory);
  if (need_serialize) {
    trt_engine_serialized_data = GetTrtEngineSerializedData(
        Get<std::string>("model_opt_cache_dir"), engine_key);
    // we can load the engine info serialized before from the disk.
    if (!trt_engine_serialized_data.empty()) {
      trt_engine->Deserialize(trt_engine_serialized_data);
      LOG(INFO) << "Load TRT Optimized Info from "
                << GetTrtEngineSerializedPath(
                       Get<std::string>("model_opt_cache_dir"), engine_key);
      return;
    }
  }

  // the following code will NOT run in following situation:
  // 1. calibration mode (generate trt int8 calibration table data)
  // 2. already load serialized trt engine info.
  LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
               "kernel etc). This process may cost a lot of time.";

  auto *scope = param_scope();
  framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
  std::unordered_set<std::string> param_set(params.begin(), params.end());
  inference::Singleton<inference::tensorrt::OpConverter>::Global()
      .ConvertBlockToTRTEngine(
          &block_desc_temp, *scope,
          std::vector<std::string>(input_names.begin(), input_names.end()),
          param_set, output_mapping, trt_engine);

  if (need_serialize) {
    // NOTE(review): serialized_engine_data appears to be never released —
    // nvinfer1::IHostMemory normally requires destroy(); confirm whether
    // TensorRT frees it elsewhere, otherwise this leaks per engine build.
    nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
    trt_engine_serialized_data =
        std::string((const char *)serialized_engine_data->data(),
                    serialized_engine_data->size());
    SaveTrtEngineSerializedDataToFile(
        GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
                                   engine_key),
        trt_engine_serialized_data);
  }
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle

// Register the pass with the pass registry, declaring the attributes that
// must be set on the pass before it can run.
REGISTER_PASS(tensorrt_subgraph_pass,
              paddle::inference::analysis::TensorRtSubgraphPass)
    .RequirePassAttr("max_batch_size")
    .RequirePassAttr("workspace_size")
    .RequirePassAttr("min_subgraph_size");

// Declare the op versions this pass is compatible with:
// EQ(op, v) requires exactly version v; LE(op, v) accepts any version <= v.
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("pool2d", 0)
            .EQ("relu", 0)
            .EQ("softmax", 0)
            .EQ("sigmoid", 0)
            .EQ("hard_swish", 0)
            .LE("depthwise_conv2d", 1)
            .EQ("batch_norm", 0)
            .EQ("concat", 0)
            .EQ("tanh", 0)
            .EQ("pad", 0)
            .LE("elementwise_add", 1)
            .LE("elementwise_mul", 1)
            .EQ("prelu", 0)
            .LE("conv2d_transpose", 2)
            .LE("leaky_relu", 1)
            .EQ("fc", 0)
            .EQ("shuffle_channel", 0)
            .EQ("swish", 0)
            .EQ("split", 0)
            .LE("instance_norm", 1)
            .EQ("gelu", 0)
            .EQ("layer_norm", 0)
            .EQ("scale", 0)
            .LE("matmul", 1));