/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
#include <string>
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Batch size recorded by SetRuntimeBatch() (Execute() calls it after each run)
// and used by GetOutputInGPU/GetOutputInCPU to size the output copy.
// NOTE(review): this is a static data member, so the value is shared by every
// TensorRTEngine in the process; engines running concurrently can overwrite
// each other's batch size.
int TensorRTEngine::runtime_batch_{1};

void TensorRTEngine::Build(const DescType &paddle_model) {
Y
Yan Chunwei 已提交
32 33 34 35
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
36 37 38
  batch_size_ = batch_size;
  std::vector<void *> buffers;
  for (auto &buf : buffers_) {
Y
Yan Chunwei 已提交
39 40 41 42 43
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
44
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
45
  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
Y
Yan Chunwei 已提交
46
  cudaStreamSynchronize(*stream_);
47
  SetRuntimeBatch(batch_size);
Y
Yan Chunwei 已提交
48 49 50
}

/// Release every GPU buffer owned by the engine.
///
/// Destructors are implicitly noexcept in C++11 unless declared otherwise. A
/// throwing PADDLE_ENFORCE here would therefore call std::terminate. CUDA
/// errors are logged instead, and cleanup continues with the remaining
/// buffers so that one failure does not leak the rest.
TensorRTEngine::~TensorRTEngine() {
  // Let pending work on the stream finish before freeing memory it may use.
  // stream_ may be unset (Execute() explicitly checks it), so guard the
  // dereference.
  if (stream_ != nullptr) {
    auto sync_err = cudaStreamSynchronize(*stream_);
    if (sync_err != cudaSuccess) {
      LOG(ERROR) << "cudaStreamSynchronize failed in ~TensorRTEngine, error "
                 << "code " << static_cast<int>(sync_err);
    }
  }

  // clean buffer
  for (auto &buf : buffers_) {
    if (buf.device != DeviceType::GPU || buf.buffer == nullptr) continue;
    auto free_err = cudaFree(buf.buffer);
    if (free_err != cudaSuccess) {
      LOG(ERROR) << "cudaFree failed in ~TensorRTEngine, error code "
                 << static_cast<int>(free_err);
    }
    buf.buffer = nullptr;
    buf.max_size = 0;
  }
}

void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
77
  buffers_.resize(buffer_sizes_.size());
78 79 80
  for (auto &item : buffer_sizes_) {
    // The output buffers are not set in the network building phrase, need to
    // infer from the TesorRT network.
Y
Yan Chunwei 已提交
81 82
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
83
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
84 85
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
86
                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
87
      PADDLE_ENFORCE_GT(item.second, 0);
Y
Yan Chunwei 已提交
88
    }
89 90 91

    auto &buf = buffer(item.first);
    buf.max_size = item.second * max_batch_;
Y
Yan Chunwei 已提交
92
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
N
nhzlx 已提交
93

94
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
95
    buf.size = 0;
N
nhzlx 已提交
96
    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
Y
Yan Chunwei 已提交
97
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
98 99 100
  }
}

nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims &dims) {
  // An input name may be declared only once, and only after the network has
  // been initialized.
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  nvinfer1::ITensor *tensor =
      infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(tensor, "infer network add input %s failed", name);

  // Buffer size = element size x product of dims x max_batch_.
  // NOTE(review): presumably `dims` excludes the batch dimension, since the
  // product is scaled by max_batch_ here.
  const auto elem_bytes = kDataTypeSize[static_cast<int>(dtype)];
  const auto elem_count = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = elem_bytes * elem_count * max_batch_;

  PADDLE_ENFORCE(tensor->isNetworkInput());
  TensorRTEngine::SetITensor(name, tensor);
  return tensor;
}

/// Mark output `offset` of `layer` as a network output named `name`.
///
/// The tensor is registered under `name` (retrievable via GetITensor) and its
/// buffer size is recorded as 0; FreezeNetwork computes the real size once the
/// engine has been built.
///
/// @throws EnforceNotMet if `name` is already declared, the network is not
///         initialized, `layer` is null, `offset` yields no tensor, or the
///         tensor cannot be marked as a network output.
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
                                   const std::string &name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
  PADDLE_ENFORCE(layer != nullptr, "layer for output %s is null", name);

  auto *output = layer->getOutput(offset);
  // Check the tensor before registering it so the error names this call.
  PADDLE_ENFORCE(output != nullptr, "layer has no output #%d for %s", offset,
                 name);
  SetITensor(name, output);
  output->setName(name.c_str());
  PADDLE_ENFORCE(!output->isNetworkInput());
  infer_network_->markOutput(*output);
  PADDLE_ENFORCE(output->isNetworkOutput());
  // output buffers' size can only be decided later; 0 marks it for
  // FreezeNetwork to fill in.
  buffer_sizes_[name] = 0;
}

void TensorRTEngine::DeclareOutput(const std::string &name) {
L
Luo Tao 已提交
135 136 137
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

138
  auto *output = TensorRTEngine::GetITensor(name);
L
Luo Tao 已提交
139 140
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
141
  PADDLE_ENFORCE(!output->isNetworkInput());
L
Luo Tao 已提交
142 143 144 145 146 147
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided latter, set zero here to mark this
  // and will reset latter.
  buffer_sizes_[name] = 0;
}

void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
Y
Yan Chunwei 已提交
149
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
150 151
}

void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst) {
153
  // determine data size
N
nhzlx 已提交
154
  auto *output = TensorRTEngine::GetITensor(name);
155 156 157 158 159
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

160 161 162
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
163
  PADDLE_ENFORCE_LE(dst_size, it->second);
164
  auto &buf = buffer(name);
165
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
166
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
167 168 169 170
                                    cudaMemcpyDeviceToDevice, *stream_),
                    0);
}

void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst) {
Y
Yan Chunwei 已提交
172
  // determine data size
173

N
nhzlx 已提交
174
  auto *output = TensorRTEngine::GetITensor(name);
175 176 177 178
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];
Y
Yan Chunwei 已提交
179 180 181
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
182
  PADDLE_ENFORCE_LE(dst_size, it->second);
N
nhzlx 已提交
183
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
184
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
185
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
Y
Yan Chunwei 已提交
186 187 188
                                       cudaMemcpyDeviceToHost, *stream_));
}

/// Return the binding buffer for `name`, located by TensorRT binding index.
///
/// @throws EnforceNotMet if FreezeNetwork has not built the engine yet, if
///         `name` was never declared, or if the engine has no binding of that
///         name (getBindingIndex returns -1) or its index is out of range.
Buffer &TensorRTEngine::buffer(const std::string &name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end(), "no buffer declared for %s", name);
  const int slot_offset = infer_engine_->getBindingIndex(name.c_str());
  // Guard against indexing buffers_ with -1 or a stale/oversized index.
  PADDLE_ENFORCE_GE(slot_offset, 0, "no TensorRT binding named %s", name);
  PADDLE_ENFORCE_LT(static_cast<size_t>(slot_offset), buffers_.size(),
                    "binding index of %s is out of range", name);
  return buffers_[slot_offset];
}

void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
Y
Yan Chunwei 已提交
198
                                     size_t size) {
199
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
200
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
201 202
  PADDLE_ENFORCE_NOT_NULL(data);
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
203 204
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
205
  buf.size = size;
Y
Yan Chunwei 已提交
206 207
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, *stream_));
Y
Yan Chunwei 已提交
208 209
}

void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
211
                                     size_t size) {
212 213
  auto &buf = buffer(name);
  buf.size = size;
214 215 216 217 218 219 220
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, *stream_));
}

void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
L
Luo Tao 已提交
223
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
224
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
225 226 227 228
                    name);
  itensor_map_[name] = tensor;
}

nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  // Single lookup; an unknown name is an error rather than a default-inserted
  // entry.
  auto found = itensor_map_.find(name);
  PADDLE_ENFORCE(found != itensor_map_.end(), "no ITensor %s", name);
  return found->second;
}

void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
  // runtime_batch_ is an int; the narrowing from size_t is spelled out.
  runtime_batch_ = static_cast<int>(batch_size);
}

int TensorRTEngine::GetRuntimeBatch() {
  // Value last stored by SetRuntimeBatch(); runtime_batch_ is a static member,
  // so it is shared across engine instances.
  const int batch = runtime_batch_;
  return batch;
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle