engine.cc 9.9 KB
Newer Older
Y
Yan Chunwei 已提交
1 2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

N
nhzlx 已提交
3 4
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License.
Y
Yan Chunwei 已提交
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
A
Abhinav Arora 已提交
20
#include <string>
Y
Yan Chunwei 已提交
21
#include "paddle/fluid/inference/analysis/helper.h"
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

29 30
// Out-of-class definition of the static member holding the batch size stored
// by SetRuntimeBatch() (Execute() calls it after every run); starts at 1.
// Being a static member, the value is shared by all TensorRTEngine objects.
int TensorRTEngine::runtime_batch_ = 1;

31
void TensorRTEngine::Build(const DescType &paddle_model) {
Y
Yan Chunwei 已提交
32 33 34 35
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
N
nhzlx 已提交
36
  freshDeviceId();
37 38 39
  batch_size_ = batch_size;
  std::vector<void *> buffers;
  for (auto &buf : buffers_) {
Y
Yan Chunwei 已提交
40 41 42 43 44
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
N
nhzlx 已提交
45 46
  infer_context_->enqueue(batch_size, buffers.data(), stream_, nullptr);
  cudaStreamSynchronize(stream_);
47
  SetRuntimeBatch(batch_size);
Y
Yan Chunwei 已提交
48 49 50
}

// Waits for outstanding work on stream_ and releases every GPU buffer held in
// buffers_.
//
// CUDA failures are logged instead of enforced: PADDLE_ENFORCE signals
// failure by throwing, and an exception leaving a destructor terminates the
// process (destructors are implicitly noexcept).
TensorRTEngine::~TensorRTEngine() {
  // Pending kernels may still be using the buffers freed below.
  const cudaError_t sync_status = cudaStreamSynchronize(stream_);
  if (sync_status != cudaSuccess) {
    LOG(ERROR) << "cudaStreamSynchronize failed in ~TensorRTEngine: "
               << cudaGetErrorString(sync_status);
  }

  // clean buffer
  for (auto &buf : buffers_) {
    if (buf.device != DeviceType::GPU || buf.buffer == nullptr) {
      continue;
    }
    const cudaError_t free_status = cudaFree(buf.buffer);
    if (free_status != cudaSuccess) {
      LOG(ERROR) << "cudaFree failed in ~TensorRTEngine: "
                 << cudaGetErrorString(free_status);
    }
    // Reset the slot either way so it is never freed a second time.
    buf.buffer = nullptr;
    buf.max_size = 0;
  }
}

void TensorRTEngine::FreezeNetwork() {
63
  VLOG(3) << "TRT to freeze network";
N
nhzlx 已提交
64
  freshDeviceId();
Y
Yan Chunwei 已提交
65 66 67 68 69 70 71
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);
N
nhzlx 已提交
72
  if (enable_int8_) {
N
nhzlx 已提交
73 74 75 76 77 78
    infer_builder_->setInt8Mode(true);
    PADDLE_ENFORCE(
        calibrator_ != nullptr,
        "The precision mode is 'INT8', the calibrator should not be nullptr");
    infer_builder_->setInt8Calibrator(calibrator_);
  }
Y
Yan Chunwei 已提交
79 80 81 82 83 84 85

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
86
  buffers_.resize(buffer_sizes_.size());
87 88 89
  for (auto &item : buffer_sizes_) {
    // The output buffers are not set in the network building phrase, need to
    // infer from the TesorRT network.
Y
Yan Chunwei 已提交
90 91
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
92
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
93 94
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
95
                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
96
      PADDLE_ENFORCE_GT(item.second, 0);
Y
Yan Chunwei 已提交
97
    }
98 99 100

    auto &buf = buffer(item.first);
    buf.max_size = item.second * max_batch_;
Y
Yan Chunwei 已提交
101
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
N
nhzlx 已提交
102

103
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
104
    buf.size = 0;
N
nhzlx 已提交
105
    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
Y
Yan Chunwei 已提交
106
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
107 108 109
  }
}

110
// Adds a network input called `name` with element type `dtype` and shape
// `dims`, records its buffer size (element size * element count *
// max_batch_) in buffer_sizes_, and registers the tensor under `name`.
//
// Returns the tensor created by the network. Fails an enforce when `name`
// was already declared, when the network is not initialized, or when the
// network cannot add the input.
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims &dims) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  nvinfer1::ITensor *tensor =
      infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(tensor, "infer network add input %s failed", name);

  // Bytes needed to hold this input at the maximum batch size.
  const auto elem_bytes = kDataTypeSize[static_cast<int>(dtype)];
  const auto elem_count = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = elem_bytes * elem_count * max_batch_;

  PADDLE_ENFORCE(tensor->isNetworkInput());
  TensorRTEngine::SetITensor(name, tensor);
  return tensor;
}

126 127
// Marks output number `offset` of `layer` as a network output called `name`.
//
// The tensor is registered under `name`, renamed, and marked as an output of
// the network. Its entry in buffer_sizes_ is set to zero as a "size not known
// yet" marker; the real size is filled in later.
// Fails an enforce when `name` was already declared, when the layer output
// is null, or when the tensor is a network input.
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
                                   const std::string &name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

  nvinfer1::ITensor *out_tensor = layer->getOutput(offset);
  SetITensor(name, out_tensor);
  PADDLE_ENFORCE(nullptr != out_tensor);

  out_tensor->setName(name.c_str());
  PADDLE_ENFORCE(!out_tensor->isNetworkInput());

  infer_network_->markOutput(*out_tensor);
  PADDLE_ENFORCE(out_tensor->isNetworkOutput());

  // Zero means "to be decided later".
  buffer_sizes_[name] = 0;
}

N
nhzlx 已提交
143 144 145 146
// Tells whether `name` already has an entry in buffer_sizes_, i.e. whether it
// was declared as an input or output.
bool TensorRTEngine::HasDeclared(const std::string &name) {
  return buffer_sizes_.find(name) != buffer_sizes_.end();
}

147
void TensorRTEngine::DeclareOutput(const std::string &name) {
L
Luo Tao 已提交
148 149 150
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

151
  auto *output = TensorRTEngine::GetITensor(name);
L
Luo Tao 已提交
152 153
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
154
  PADDLE_ENFORCE(!output->isNetworkInput());
L
Luo Tao 已提交
155 156 157 158 159 160
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided latter, set zero here to mark this
  // and will reset latter.
  buffer_sizes_[name] = 0;
}

161
void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
Y
Yan Chunwei 已提交
162
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
163 164
}

N
nhzlx 已提交
165 166
void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
                                    size_t max_size) {
167
  // determine data size
N
nhzlx 已提交
168
  auto *output = TensorRTEngine::GetITensor(name);
169 170 171 172 173
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

174 175 176
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
177
  PADDLE_ENFORCE_LE(dst_size, it->second);
N
nhzlx 已提交
178
  PADDLE_ENFORCE_GE(max_size, dst_size);
179
  auto &buf = buffer(name);
180
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
181
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
N
nhzlx 已提交
182
                                    cudaMemcpyDeviceToDevice, stream_),
183 184 185
                    0);
}

N
nhzlx 已提交
186 187
void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
                                    size_t max_size) {
Y
Yan Chunwei 已提交
188
  // determine data size
189

N
nhzlx 已提交
190
  auto *output = TensorRTEngine::GetITensor(name);
191 192 193 194
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];
Y
Yan Chunwei 已提交
195 196 197
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
198
  PADDLE_ENFORCE_LE(dst_size, it->second);
N
nhzlx 已提交
199
  PADDLE_ENFORCE_GE(max_size, dst_size);
N
nhzlx 已提交
200
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
201
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
202
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
N
nhzlx 已提交
203
                                       cudaMemcpyDeviceToHost, stream_));
Y
Yan Chunwei 已提交
204 205
}

206
// Returns the buffer slot that backs the binding called `name`.
//
// The slot position is the binding index the built engine reports for
// `name`. Fails an enforce when the engine has not been built yet, when
// `name` was never declared, or when the engine's binding index does not
// address an element of buffers_.
Buffer &TensorRTEngine::buffer(const std::string &name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end(), "tried to access buffer named %s",
                 name);
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  // getBindingIndex() yields -1 for a name the engine does not know; using
  // that as an index would read outside buffers_.
  PADDLE_ENFORCE(
      slot_offset >= 0 && static_cast<size_t>(slot_offset) < buffers_.size(),
      "no valid engine binding for buffer named %s", name);
  return buffers_[slot_offset];
}

215
void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
Y
Yan Chunwei 已提交
216
                                     size_t size) {
217
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
218
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
219
  PADDLE_ENFORCE_NOT_NULL(data);
Y
Yan Chunwei 已提交
220 221
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
222
  buf.size = size;
Y
Yan Chunwei 已提交
223
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
N
nhzlx 已提交
224
                                       cudaMemcpyHostToDevice, stream_));
Y
Yan Chunwei 已提交
225 226
}

227
void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
228
                                     size_t size) {
229 230
  auto &buf = buffer(name);
  buf.size = size;
231 232 233 234
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
N
nhzlx 已提交
235
                                       cudaMemcpyDeviceToDevice, stream_));
236 237
}

238 239
void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
L
Luo Tao 已提交
240
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
241
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
242 243 244 245
                    name);
  itensor_map_[name] = tensor;
}

246
// Looks up the tensor registered under `name` in itensor_map_.
// Fails an enforce when no tensor carries that name.
nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  auto found = itensor_map_.find(name);
  PADDLE_ENFORCE(found != itensor_map_.end(), "no ITensor %s", name);
  return found->second;
}

251 252 253 254 255 256
// Stores `batch_size` in runtime_batch_. The member is an int, so the value
// is narrowed from size_t on assignment.
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
  runtime_batch_ = static_cast<int>(batch_size);
}

// Returns the batch size most recently stored by SetRuntimeBatch().
int TensorRTEngine::GetRuntimeBatch() {
  const int current_batch = runtime_batch_;
  return current_batch;
}

N
nhzlx 已提交
257 258 259 260 261 262 263
void TensorRTEngine::freshDeviceId() {
  int count;
  cudaGetDeviceCount(&count);
  PADDLE_ENFORCE_LT(device_, count);
  cudaSetDevice(device_);
}

N
nhzlx 已提交
264
// Adds `plugin` to the network as a plugin layer fed by the `num_inputs`
// tensors in `inputs`, and returns the layer created by the network.
//
// The plugin pointer is appended to owned_plugin_ (presumably transferring
// ownership to the engine -- confirm against the member's declaration).
// Fails an enforce when `plugin` is null or the network is not initialized.
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
    nvinfer1::ITensor *const *inputs, int num_inputs,
    plugin::PluginTensorRT *plugin) {
  // `plugin` is dereferenced below; reject null before storing it.
  PADDLE_ENFORCE(plugin != nullptr, "plugin should not be null");
  // Store the plugin first so it stays tracked even if the next check fails.
  owned_plugin_.emplace_back(plugin);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
  return infer_network_->addPluginExt(inputs, num_inputs, *plugin);
}

Y
Yan Chunwei 已提交
271 272 273
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle