engine.cc 10.0 KB
Newer Older
Y
Yan Chunwei 已提交
1 2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
A
Abhinav Arora 已提交
20
#include <string>
Y
Yan Chunwei 已提交
21
#include "paddle/fluid/inference/analysis/helper.h"
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

29 30
// Batch size recorded by the most recent Execute() call (via SetRuntimeBatch)
// and used by GetOutputInCPU/GetOutputInGPU to size their copies.
// Being a static data member, the value is shared by every TensorRTEngine.
// NOTE(review): several engines executing concurrently would race on this
// value; confirm engines are never run in parallel.
int TensorRTEngine::runtime_batch_{1};

31
void TensorRTEngine::Build(const DescType &paddle_model) {
Y
Yan Chunwei 已提交
32 33 34 35
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
N
nhzlx 已提交
36
  freshDeviceId();
37 38 39
  batch_size_ = batch_size;
  std::vector<void *> buffers;
  for (auto &buf : buffers_) {
Y
Yan Chunwei 已提交
40 41 42 43 44
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
45
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
46
  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
Y
Yan Chunwei 已提交
47
  cudaStreamSynchronize(*stream_);
48
  SetRuntimeBatch(batch_size);
Y
Yan Chunwei 已提交
49 50 51
}

// Waits for outstanding work on the engine's stream and releases every GPU
// buffer allocated by FreezeNetwork().
//
// Destructors are implicitly noexcept, so a throwing PADDLE_ENFORCE here would
// call std::terminate; CUDA failures are logged instead of thrown. The stream
// is only synchronized when one was supplied (stream_ may be null).
TensorRTEngine::~TensorRTEngine() {
  if (stream_ != nullptr) {
    const cudaError_t sync_status = cudaStreamSynchronize(*stream_);
    if (sync_status != cudaSuccess) {
      LOG(ERROR) << "cudaStreamSynchronize failed in ~TensorRTEngine: "
                 << cudaGetErrorString(sync_status);
    }
  }
  // clean buffer
  for (auto &buf : buffers_) {
    if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
      const cudaError_t free_status = cudaFree(buf.buffer);
      if (free_status != cudaSuccess) {
        LOG(ERROR) << "cudaFree failed in ~TensorRTEngine: "
                   << cudaGetErrorString(free_status);
      }
      buf.buffer = nullptr;
      buf.max_size = 0;
    }
  }
}

void TensorRTEngine::FreezeNetwork() {
64
  VLOG(3) << "TRT to freeze network";
N
nhzlx 已提交
65
  freshDeviceId();
Y
Yan Chunwei 已提交
66 67 68 69 70 71 72
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);
N
nhzlx 已提交
73
  if (enable_int8_) {
N
nhzlx 已提交
74 75 76 77 78 79
    infer_builder_->setInt8Mode(true);
    PADDLE_ENFORCE(
        calibrator_ != nullptr,
        "The precision mode is 'INT8', the calibrator should not be nullptr");
    infer_builder_->setInt8Calibrator(calibrator_);
  }
Y
Yan Chunwei 已提交
80 81 82 83 84 85 86

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
87
  buffers_.resize(buffer_sizes_.size());
88 89 90
  for (auto &item : buffer_sizes_) {
    // The output buffers are not set in the network building phrase, need to
    // infer from the TesorRT network.
Y
Yan Chunwei 已提交
91 92
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
93
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
94 95
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
96
                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
97
      PADDLE_ENFORCE_GT(item.second, 0);
Y
Yan Chunwei 已提交
98
    }
99 100 101

    auto &buf = buffer(item.first);
    buf.max_size = item.second * max_batch_;
Y
Yan Chunwei 已提交
102
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
N
nhzlx 已提交
103

104
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
105
    buf.size = 0;
N
nhzlx 已提交
106
    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
Y
Yan Chunwei 已提交
107
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
108 109 110
  }
}

111
// Adds a network input called `name` with element type `dtype` and shape
// `dims`. It records the input's device-buffer size in bytes (scaled by
// max_batch_) and registers the resulting ITensor under `name`.
// Returns the new input tensor; throws on duplicate names or a missing network.
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims &dims) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  nvinfer1::ITensor *tensor =
      infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(tensor, "infer network add input %s failed", name);

  // Byte capacity for up to max_batch_ samples of this input.
  const auto elem_bytes = kDataTypeSize[static_cast<int>(dtype)];
  const auto elems_per_sample = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = elem_bytes * elems_per_sample * max_batch_;

  PADDLE_ENFORCE(tensor->isNetworkInput());
  SetITensor(name, tensor);
  return tensor;
}

127 128
// Marks output number `offset` of `layer` as a network output called `name`.
// The tensor is renamed to `name` and registered so GetITensor(name) finds it.
// Its buffer size is recorded as 0 because it can only be computed once the
// engine is built; FreezeNetwork fills in the real size.
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
                                   const std::string &name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
  PADDLE_ENFORCE(layer != nullptr, "layer for output %s is null", name);

  auto *output = layer->getOutput(offset);
  // Validate before registering, so a null tensor never reaches itensor_map_.
  PADDLE_ENFORCE(output != nullptr, "layer has no output %d for %s", offset,
                 name);
  SetITensor(name, output);
  output->setName(name.c_str());
  PADDLE_ENFORCE(!output->isNetworkInput());
  infer_network_->markOutput(*output);
  PADDLE_ENFORCE(output->isNetworkOutput());
  // output buffers' size can only be decided later, set zero here to mark this
  // and it will be reset later.
  buffer_sizes_[name] = 0;
}

N
nhzlx 已提交
144 145 146 147
// True when `name` was already declared as an input or an output, i.e. it has
// an entry in buffer_sizes_.
bool TensorRTEngine::HasDeclared(const std::string &name) {
  const auto found = buffer_sizes_.find(name);
  return found != buffer_sizes_.end();
}

148
void TensorRTEngine::DeclareOutput(const std::string &name) {
L
Luo Tao 已提交
149 150 151
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

152
  auto *output = TensorRTEngine::GetITensor(name);
L
Luo Tao 已提交
153 154
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
155
  PADDLE_ENFORCE(!output->isNetworkInput());
L
Luo Tao 已提交
156 157 158 159 160 161
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided latter, set zero here to mark this
  // and will reset latter.
  buffer_sizes_[name] = 0;
}

162
void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
Y
Yan Chunwei 已提交
163
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
164 165
}

N
nhzlx 已提交
166 167
void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
                                    size_t max_size) {
168
  // determine data size
N
nhzlx 已提交
169
  auto *output = TensorRTEngine::GetITensor(name);
170 171 172 173 174
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

175 176 177
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
178
  PADDLE_ENFORCE_LE(dst_size, it->second);
N
nhzlx 已提交
179
  PADDLE_ENFORCE_GE(max_size, dst_size);
180
  auto &buf = buffer(name);
181
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
182
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
183 184 185 186
                                    cudaMemcpyDeviceToDevice, *stream_),
                    0);
}

N
nhzlx 已提交
187 188
void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
                                    size_t max_size) {
Y
Yan Chunwei 已提交
189
  // determine data size
190

N
nhzlx 已提交
191
  auto *output = TensorRTEngine::GetITensor(name);
192 193 194 195
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];
Y
Yan Chunwei 已提交
196 197 198
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
199
  PADDLE_ENFORCE_LE(dst_size, it->second);
N
nhzlx 已提交
200
  PADDLE_ENFORCE_GE(max_size, dst_size);
N
nhzlx 已提交
201
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
202
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
203
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
Y
Yan Chunwei 已提交
204 205 206
                                       cudaMemcpyDeviceToHost, *stream_));
}

207
// Returns the binding buffer for the declared input/output `name`.
// buffers_ is indexed by TensorRT binding index, so the engine must already be
// built (FreezeNetwork) for the name-to-slot lookup to work.
Buffer &TensorRTEngine::buffer(const std::string &name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
                 name);
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  // getBindingIndex returns -1 for names unknown to the engine; never index
  // buffers_ out of range.
  PADDLE_ENFORCE(slot_offset >= 0 &&
                     static_cast<size_t>(slot_offset) < buffers_.size(),
                 "no TensorRT binding slot for buffer named %s", name);
  return buffers_[slot_offset];
}

216
void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
Y
Yan Chunwei 已提交
217
                                     size_t size) {
218
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
219
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
220 221
  PADDLE_ENFORCE_NOT_NULL(data);
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
222 223
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
224
  buf.size = size;
Y
Yan Chunwei 已提交
225 226
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, *stream_));
Y
Yan Chunwei 已提交
227 228
}

229
void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
230
                                     size_t size) {
231 232
  auto &buf = buffer(name);
  buf.size = size;
233 234 235 236 237 238 239
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, *stream_));
}

240 241
void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
L
Luo Tao 已提交
242
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
243
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
244 245 246 247
                    name);
  itensor_map_[name] = tensor;
}

248
// Looks up the ITensor registered under `name` via SetITensor.
// Throws if nothing was registered under that name.
nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  auto found = itensor_map_.find(name);
  PADDLE_ENFORCE(found != itensor_map_.end(), "no ITensor %s", name);
  return found->second;
}

253 254 255 256 257 258
// Stores the batch size of the last execution in the static runtime_batch_.
// The conversion from size_t to int is written out explicitly; it is the same
// conversion the plain assignment performed implicitly.
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
  runtime_batch_ = static_cast<int>(batch_size);
}

// Returns the batch size recorded by the most recent SetRuntimeBatch call.
int TensorRTEngine::GetRuntimeBatch() {
  const int last_batch = runtime_batch_;
  return last_batch;
}

N
nhzlx 已提交
259 260 261 262 263 264 265
void TensorRTEngine::freshDeviceId() {
  int count;
  cudaGetDeviceCount(&count);
  PADDLE_ENFORCE_LT(device_, count);
  cudaSetDevice(device_);
}

N
nhzlx 已提交
266
// Adds `plugin` as a layer that consumes `num_inputs` tensors from `inputs`.
// The engine takes ownership of `plugin` (it is stored in owned_plugin_).
// Returns the layer created by the network.
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
    nvinfer1::ITensor *const *inputs, int num_inputs,
    plugin::PluginTensorRT *plugin) {
  PADDLE_ENFORCE(plugin != nullptr, "plugin should not be nullptr");
  // Take ownership before any further check, so the plugin is still released
  // with the engine if a later enforce throws.
  owned_plugin_.emplace_back(plugin);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
  return infer_network_->addPluginExt(inputs, num_inputs, *plugin);
}

Y
Yan Chunwei 已提交
273 274 275
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle