/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <glog/logging.h>

#include <string>
#include <unordered_set>
#include <vector>

#include "cuda_runtime_api.h"  // NOLINT
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Batch size most recently recorded by SetRuntimeBatch() (Execute() records
// it after every enqueue). This is a static data member, so a single value is
// shared by every TensorRTEngine instance in the process.
int TensorRTEngine::runtime_batch_{1};

void TensorRTEngine::InitNetwork() {
  freshDeviceId();
  infer_builder_.reset(createInferBuilder(&logger_));

  if (with_dynamic_shape_) {
37
    infer_network_.reset(infer_builder_->createNetworkV2(
38 39 40
        1U << static_cast<int>(
            nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
  } else {
41
    infer_network_.reset(infer_builder_->createNetworkV2(0U));
42
  }
43 44 45

  infer_builder_config_.reset(infer_builder_->createBuilderConfig());
  optim_profile_ = infer_builder_->createOptimizationProfile();
Y
Yan Chunwei 已提交
46 47
}

void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
                             cudaStream_t stream) {
N
nhzlx 已提交
50
  freshDeviceId();
51 52 53 54 55 56 57
  auto infer_context = context();
  if (!with_dynamic_shape()) {
    infer_context->enqueue(batch_size, buffers->data(), stream, nullptr);
  } else {
#if IS_TRT_VERSION_GE(6000)
    infer_context->enqueueV2(buffers->data(), stream, nullptr);
#endif
58
  }
N
nhzlx 已提交
59 60 61
  SetRuntimeBatch(batch_size);
}

void TensorRTEngine::FreezeNetwork() {
N
nhzlx 已提交
63
  freshDeviceId();
64
  VLOG(3) << "TRT to freeze network";
65 66 67 68 69 70 71
  PADDLE_ENFORCE_NOT_NULL(infer_builder_,
                          platform::errors::InvalidArgument(
                              "Inference builder of TRT is null. Please make "
                              "sure you call InitNetwork first."));
  PADDLE_ENFORCE_NOT_NULL(network(),
                          platform::errors::InvalidArgument(
                              "Call InitNetwork first to initialize network."));
Y
Yan Chunwei 已提交
72 73
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
74 75
  infer_builder_config_->setMaxWorkspaceSize(max_workspace_);

Z
Zhaolong Xing 已提交
76 77 78
  bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
  if (enable_fp16) {
    bool support_fp16 = infer_builder_->platformHasFastFp16();
79
    infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
Z
Zhaolong Xing 已提交
80 81 82
    if (!support_fp16) {
      LOG(INFO) << "You specify FP16 mode, but the hardware do not support "
                   "FP16 speed up, use FP32 instead.";
83 84
    } else {
      LOG(INFO) << "Run Paddle-TRT FP16 mode";
Z
Zhaolong Xing 已提交
85 86 87
    }
  }

88
  bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
Z
Zhaolong Xing 已提交
89
  if (enable_int8) {
90 91 92
    infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
    infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kINT8);

93
    if (calibrator_) {
94
      infer_builder_config_->setInt8Calibrator(calibrator_);
95
    } else {
96
      infer_builder_config_->setInt8Calibrator(nullptr);
97 98 99 100 101 102 103 104 105

#if IS_TRT_VERSION_GE(5000)
      for (auto &quant_range : quant_dynamic_range_) {
        auto tensor = quant_range.first;
        float range = quant_range.second;
        tensor->setDynamicRange(-range, range);
      }

      std::unordered_set<nvinfer1::ITensor *> all_t;
106 107
      for (int i = 0; i < network()->getNbLayers(); i++) {
        auto layer = network()->getLayer(i);
108 109 110 111
        for (int j = 0; j < layer->getNbOutputs(); j++) {
          all_t.insert(layer->getOutput(j));
        }
      }
112

113 114
      for (int i = 0; i < network()->getNbInputs(); i++) {
        all_t.insert(network()->getInput(i));
115 116 117 118
      }

      for (auto &t : all_t) {
        if (!quant_dynamic_range_.count(t)) {
T
tianshuo78520a 已提交
119 120 121
          VLOG(3) << "We are in trt int8 mode(not calibration), scale not set"
                  << " for tensor " << t->getName()
                  << ", this might be ok when trt does not need this range";
122 123
        }
      }
124

125
#if IS_TRT_VERSION_GE(5122)
126 127 128 129 130 131 132 133 134 135
      auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool {
        for (int j = 0; j < layer->getNbInputs(); j++) {
          auto *temp_in = layer->getInput(j);
          if (!temp_in->dynamicRangeIsSet()) {
            VLOG(1) << "Layer(Name: " << layer->getName()
                    << ") is set to float32 because its input("
                    << temp_in->getName() << ") doesn't have dynamic range.";
            return false;
          }
        }
136 137
        for (int j = 0; j < layer->getNbOutputs(); j++) {
          auto *temp_out = layer->getOutput(j);
138 139 140 141 142 143 144 145 146 147 148
          if (temp_out->isNetworkOutput()) {
            VLOG(1) << "Layer(Name: " << layer->getName()
                    << ") is set to float32 because its output("
                    << temp_out->getName() << ") is the output of the network.";
            return false;
          }
          if (!temp_out->dynamicRangeIsSet()) {
            VLOG(1) << "Layer(Name: " << layer->getName()
                    << ") is set to float32 because its output("
                    << temp_out->getName() << ") doesn't have dynamic range.";
            return false;
149 150
          }
        }
151 152 153 154 155 156 157 158 159 160 161
        return true;
      };
      // If a layer's output is the network's output, or not all of its inputs
      // and outputs have scales,
      // this layer's precision and output type are set to float32.
      // This step has no effect if this layer is fused during TRT optimization.
      for (int i = 0; i < network()->getNbLayers(); i++) {
        auto layer = network()->getLayer(i);
        if (!is_layer_int8(layer)) {
          layer->setPrecision(nvinfer1::DataType::kFLOAT);
        }
162
      }
163 164 165 166 167
#else
      LOG(WARNING) << "If your TensorRT version is lower than 5.1.2.2, you "
                      "must provide quantization scales for all tensors using "
                      "TRT to run.";
#endif
168 169
#endif
    }
N
nhzlx 已提交
170
  }
Y
Yan Chunwei 已提交
171

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
  if (use_dla_) {
    if (!enable_int8 && !enable_fp16) {
      LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you "
                      "set float32, so DLA is not used.";
    } else if (infer_builder_->getNbDLACores() == 0) {
      LOG(WARNING)
          << "TensorRT DLA is set by config, but your device does not have "
             "DLA, so DLA is not used.";
    } else {
      if (dla_core_ < 0 || dla_core_ >= infer_builder_->getNbDLACores()) {
        dla_core_ = 0;
        LOG(WARNING) << "Invalid DLACore, must be 0 < DLACore < "
                     << infer_builder_->getNbDLACores() << ", but got "
                     << dla_core_ << ", so use use 0 as default.";
      }
187 188 189
      infer_builder_config_->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
      infer_builder_config_->setDLACore(dla_core_);
      infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
190 191 192 193 194
      LOG(INFO) << "TensorRT DLA enabled in FreezeNetwork(), DLACore "
                << dla_core_;
    }
  }

195 196
  if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
197
    LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
198
    for (auto &input : min_input_shape_) {
199 200 201 202
      VLOG(4) << "TRT dynamic_shape set " << input.first
              << " min: " << Vec2Str(input.second)
              << ", max: " << Vec2Str(max_input_shape_[input.first])
              << ", opt: " << Vec2Str(optim_input_shape_[input.first]);
203 204 205 206 207 208 209 210 211 212
      optim_profile_->setDimensions(
          input.first.c_str(), nvinfer1::OptProfileSelector::kMIN,
          Vec2TRT_Dims(input.second, input.first, true));
      optim_profile_->setDimensions(
          input.first.c_str(), nvinfer1::OptProfileSelector::kMAX,
          Vec2TRT_Dims(max_input_shape_[input.first], input.first, true));
      optim_profile_->setDimensions(
          input.first.c_str(), nvinfer1::OptProfileSelector::kOPT,
          Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
    }
213
    infer_builder_config_->addOptimizationProfile(optim_profile_);
214 215 216 217 218 219
    if (WithFp16() && disable_trt_plugin_fp16()) {
      LOG(INFO) << "NOTE: In order to achieve higher accuracy, you have "
                   "disabled the fp16 mode of TRT Plugin,\n"
                << "you can reopen it with "
                   "'config.SetDynamicShapeInfo(min_shape, max_shape, "
                   "opt_shape, false /*disable_trt_plugin_fp16*/)'";
220
    }
221 222
#endif
  }
223 224

#if IS_TRT_VERSION_LT(8000)
225 226
  infer_engine_.reset(infer_builder_->buildEngineWithConfig(
      *network(), *infer_builder_config_));
227 228 229 230 231 232 233
#else
  infer_ptr<nvinfer1::IHostMemory> plan(infer_builder_->buildSerializedNetwork(
      *network(), *infer_builder_config_));
  infer_ptr<nvinfer1::IRuntime> runtime(createInferRuntime(&logger_));
  infer_engine_.reset(
      runtime->deserializeCudaEngine(plan->data(), plan->size()));
#endif
234

235 236 237 238
  PADDLE_ENFORCE_NOT_NULL(
      infer_engine_, platform::errors::Fatal(
                         "Build TensorRT cuda engine failed! Please recheck "
                         "you configurations related to paddle-TensorRT."));
Y
Yan Chunwei 已提交
239 240
}

// Adds a network input named `name` with the given data type and dimensions,
// registers the resulting tensor under `name` in the engine's tensor map, and
// returns it. Requires InitNetwork() to have created the network.
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims &dims) {
  PADDLE_ENFORCE_EQ(network() != nullptr, true,
                    platform::errors::InvalidArgument(
                        "The TRT network should be initialized first."));
  nvinfer1::ITensor *input = network()->addInput(name.c_str(), dtype, dims);

  // Validate what TensorRT handed back before exposing it to converters.
  PADDLE_ENFORCE_NOT_NULL(
      input, platform::errors::InvalidArgument("Adding input %s failed in "
                                               "TensorRT inference network. "
                                               "Please recheck your input.",
                                               name));
  PADDLE_ENFORCE_EQ(input->isNetworkInput(), true,
                    platform::errors::InvalidArgument(
                        "Input %s is not the input of TRT inference network. "
                        "Please recheck your input.",
                        name));

  TensorRTEngine::SetITensor(name, input);
  return input;
}

// Marks output number `offset` of `layer` as a network output called `name`
// and registers it under `name` in the engine's tensor map.
//
// Reports an error if the layer output is null, is also a network input, or
// was not accepted as a network output by markOutput().
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
                                   const std::string &name) {
  auto *output = layer->getOutput(offset);
  // Check for null before registering: SetITensor() also rejects null
  // tensors, so when this check came after it, this message was unreachable.
  PADDLE_ENFORCE_NOT_NULL(
      output, platform::errors::InvalidArgument(
                  "The output %s of TRT engine should not be null.", name));
  SetITensor(name, output);
  output->setName(name.c_str());
  PADDLE_ENFORCE_EQ(output->isNetworkInput(), false,
                    platform::errors::InvalidArgument(
                        "The output %s of TRT engine should not be the input "
                        "of the network at the same time.",
                        name));
  network()->markOutput(*output);
  PADDLE_ENFORCE_EQ(
      output->isNetworkOutput(), true,
      platform::errors::InvalidArgument(
          "The output %s of TRT engine should be the output of the network.",
          name));
}

void TensorRTEngine::DeclareOutput(const std::string &name) {
  auto *output = TensorRTEngine::GetITensor(name);
285 286 287
  PADDLE_ENFORCE_NOT_NULL(
      output, platform::errors::InvalidArgument(
                  "The output %s of TRT engine should not be null.", name));
L
Luo Tao 已提交
288
  output->setName(name.c_str());
289 290 291 292 293
  PADDLE_ENFORCE_EQ(output->isNetworkInput(), false,
                    platform::errors::InvalidArgument(
                        "The output %s of TRT engine should not be the input "
                        "of the network at the same time.",
                        name));
294
  network()->markOutput(*output);
L
Luo Tao 已提交
295 296
}

void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
299 300 301 302 303 304 305
  PADDLE_ENFORCE_NOT_NULL(
      tensor, platform::errors::InvalidArgument(
                  "Tensor named %s of TRT engine should not be null.", name));
  PADDLE_ENFORCE_EQ(
      0, itensor_map_.count(name),
      platform::errors::InvalidArgument(
          "Tensor named %s of TRT engine should not be duplicated", name));
L
Luo Tao 已提交
306 307 308
  itensor_map_[name] = tensor;
}

// Returns the tensor registered under `name` in itensor_map_.
// Reports a NotFound error if no tensor has been registered with that name.
nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  // Single lookup: the iterator serves both the existence check and the
  // result, instead of count() followed by a second operator[] lookup.
  auto it = itensor_map_.find(name);
  PADDLE_ENFORCE_EQ(it != itensor_map_.end(), true,
                    platform::errors::NotFound(
                        "Tensor named %s is not found in TRT engine", name));
  return it->second;
}

// Stores `batch_size` in the static runtime_batch_ (an int), so values that do
// not fit in an int are narrowed.
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
  const int narrowed_batch = static_cast<int>(batch_size);
  runtime_batch_ = narrowed_batch;
}

float *TensorRTEngine::GetWeightCPUData(const std::string &name,
                                        framework::Tensor *weight_tensor,
                                        bool enable_int8,
                                        const std::vector<float> &scale) {
324 325
  static int name_suffix_counter = 0;
  std::string name_suffix = std::to_string(name_suffix_counter);
P
Pei Yang 已提交
326 327
  std::string splitter = "__";
  std::string name_with_suffix = name + splitter + name_suffix;
328
  platform::CPUPlace cpu_place;
329 330 331 332 333
  PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), 0,
                    platform::errors::AlreadyExists(
                        "The weight named %s is set into the weight map "
                        "twice in TRT OP converter.",
                        name_with_suffix));
334 335 336 337 338 339
  weight_map[name_with_suffix].reset(new framework::Tensor());
  weight_map[name_with_suffix]->Resize(weight_tensor->dims());
  TensorCopySync(*weight_tensor, cpu_place, weight_map[name_with_suffix].get());
  float *weight_data =
      weight_map[name_with_suffix]->mutable_data<float>(cpu_place);
  name_suffix_counter += 1;
340 341 342
  return weight_data;
}

// Returns the batch size most recently stored by SetRuntimeBatch().
int TensorRTEngine::GetRuntimeBatch() {
  const int current_batch = runtime_batch_;
  return current_batch;
}

// Appends `plugin` to owned_plugin_ (presumably the engine's owning list, so
// the plugin outlives network construction -- confirm in the header) and adds
// a plugin layer over the `num_inputs` tensors in `inputs` to the network.
// Returns the new layer as created by addPluginV2().
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin(
    nvinfer1::ITensor *const *inputs, int num_inputs,
    plugin::PluginTensorRT *plugin) {
  owned_plugin_.emplace_back(plugin);
  nvinfer1::IPluginV2Layer *plugin_layer =
      network()->addPluginV2(inputs, num_inputs, *plugin);
  return plugin_layer;
}

// IPluginV2Ext counterpart of AddPlugin(): appends `plugin` to
// owned_plugin_v2ext_ (presumably an owning list -- confirm in the header) and
// adds a plugin layer over the `num_inputs` tensors in `inputs` to the network.
nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
    nvinfer1::ITensor *const *inputs, int num_inputs,
    plugin::PluginTensorRTV2Ext *plugin) {
  owned_plugin_v2ext_.emplace_back(plugin);
  nvinfer1::IPluginV2Layer *plugin_layer =
      network()->addPluginV2(inputs, num_inputs, *plugin);
  return plugin_layer;
}

void TensorRTEngine::freshDeviceId() {
  int count;
  cudaGetDeviceCount(&count);
362 363 364 365
  PADDLE_ENFORCE_LT(device_id_, count,
                    platform::errors::OutOfRange(
                        "Device id %d exceeds the current device count: %d.",
                        device_id_, count));
L
Leo Chen 已提交
366
  platform::SetDeviceId(device_id_);
N
nhzlx 已提交
367 368
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle