You need to sign in or sign up before continuing.
engine.cc 8.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
A
Abhinav Arora 已提交
20
#include <string>
Y
Yan Chunwei 已提交
21
#include "paddle/fluid/inference/analysis/helper.h"
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

29
void TensorRTEngine::Build(const DescType &paddle_model) {
Y
Yan Chunwei 已提交
30 31 32 33
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
34 35 36
  batch_size_ = batch_size;
  std::vector<void *> buffers;
  for (auto &buf : buffers_) {
Y
Yan Chunwei 已提交
37 38 39 40 41
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
42
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
43
  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
Y
Yan Chunwei 已提交
44 45 46 47
  cudaStreamSynchronize(*stream_);
}

// Waits for outstanding work on the stream (when one is set) and releases
// every GPU buffer recorded in buffers_.
//
// CUDA failures are logged rather than raised: a destructor is noexcept
// unless declared otherwise, so an exception escaping from here would
// terminate the process.
TensorRTEngine::~TensorRTEngine() {
  // stream_ is null-checked wherever else it is used, so it may legitimately
  // be unset here; never dereference it blindly.
  if (stream_ != nullptr) {
    const cudaError_t sync_status = cudaStreamSynchronize(*stream_);
    if (sync_status != cudaSuccess) {
      LOG(ERROR) << "cudaStreamSynchronize failed while destroying engine: "
                 << cudaGetErrorString(sync_status);
    }
  }

  // clean buffer
  for (auto &buf : buffers_) {
    if (buf.device != DeviceType::GPU || buf.buffer == nullptr) {
      continue;
    }
    const cudaError_t free_status = cudaFree(buf.buffer);
    if (free_status != cudaSuccess) {
      LOG(ERROR) << "cudaFree failed while destroying engine: "
                 << cudaGetErrorString(free_status);
    }
    // Reset the slot either way so the pointer is never reused.
    buf.buffer = nullptr;
    buf.max_size = 0;
  }
}

void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
74
  buffers_.resize(buffer_sizes_.size());
75 76 77
  for (auto &item : buffer_sizes_) {
    // The output buffers are not set in the network building phrase, need to
    // infer from the TesorRT network.
Y
Yan Chunwei 已提交
78 79
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
80
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
81 82
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
Y
Yan Chunwei 已提交
83
                    analysis::AccuDims(dims.d, dims.nbDims);
84
      PADDLE_ENFORCE_GT(item.second, 0);
Y
Yan Chunwei 已提交
85
    }
86 87 88

    auto &buf = buffer(item.first);
    buf.max_size = item.second * max_batch_;
Y
Yan Chunwei 已提交
89
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
90 91 92 93
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, buf.max_size));
    PADDLE_ENFORCE_LE(buf.max_size, 1 << 30);  // 10G
    // buf.size will changed in the runtime.
    buf.size = 0;
Y
Yan Chunwei 已提交
94
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
95 96 97
  }
}

98
// Registers a new network input and returns the tensor TensorRT created for
// it.
//
// name:  input name; must not already be present in buffer_sizes_.
// dtype: element type of the input.
// dims:  dimensions of the input.
//
// Records the input's byte size (element size times the accumulated
// dimensions) in buffer_sizes_ and stores the tensor under `name` via
// SetITensor. Fails an enforcement check on a duplicate name, a missing
// network, or when TensorRT does not return a network-input tensor.
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims &dims) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  auto *tensor = infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(tensor, "infer network add input %s failed", name);

  // Byte size = bytes per element * number of elements described by dims.
  const auto element_bytes = kDataTypeSize[static_cast<int>(dtype)];
  const auto element_count = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = element_bytes * element_count;

  PADDLE_ENFORCE(tensor->isNetworkInput());
  SetITensor(name, tensor);
  return tensor;
}

114 115
// Marks the `offset`-th output of `layer` as a network output called `name`.
//
// layer:  layer whose output is exported; must not be null.
// offset: index of the output within the layer.
// name:   output name; must not already be present in buffer_sizes_.
//
// The tensor is stored under `name` via SetITensor, renamed to `name`, and
// marked as a network output. Its buffer size is recorded as 0 because it can
// only be determined later, once the engine is built. Fails an enforcement
// check on a duplicate name, a null layer, a missing network, a missing
// output tensor, or when the tensor is a network input.
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset,
                                   const std::string &name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  PADDLE_ENFORCE(layer != nullptr, "layer of output %s should not be null",
                 name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  auto *output = layer->getOutput(offset);
  // Validate the tensor before it is registered or dereferenced.
  PADDLE_ENFORCE(output != nullptr);
  SetITensor(name, output);
  output->setName(name.c_str());
  PADDLE_ENFORCE(!output->isNetworkInput());
  infer_network_->markOutput(*output);
  PADDLE_ENFORCE(output->isNetworkOutput());
  // An output buffer's size can only be decided later; record zero here as a
  // marker and fill in the real size afterwards.
  buffer_sizes_[name] = 0;
}

131
void TensorRTEngine::DeclareOutput(const std::string &name) {
L
Luo Tao 已提交
132 133 134
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

135
  auto *output = TensorRTEngine::GetITensor(name);
L
Luo Tao 已提交
136 137
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
138
  PADDLE_ENFORCE(!output->isNetworkInput());
L
Luo Tao 已提交
139 140 141 142 143 144
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided latter, set zero here to mark this
  // and will reset latter.
  buffer_sizes_[name] = 0;
}

145
void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
Y
Yan Chunwei 已提交
146
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
147 148
}

149
void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
150 151 152 153 154 155
                                    size_t max_size) {
  // determine data size
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_GE(max_size, it->second);
156
  auto &buf = buffer(name);
157 158 159 160 161 162
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
                                    cudaMemcpyDeviceToDevice, *stream_),
                    0);
}

163
void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
Y
Yan Chunwei 已提交
164
                                    size_t max_size) {
165 166 167 168 169 170 171 172 173 174
  VLOG(4) << "get output in cpu";
  auto &buf = buffer(name);

  // Update needed buffer size.
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  auto dims = infer_engine_->getBindingDimensions(slot_offset);
  buf.size = kDataTypeSize[static_cast<int>(
                 infer_engine_->getBindingDataType(slot_offset))] *
             analysis::AccuDims(dims.d, dims.nbDims);
  PADDLE_ENFORCE_LE(buf.size, buf.max_size);
Y
Yan Chunwei 已提交
175
  // determine data size
Y
Yan Chunwei 已提交
176
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
177 178 179 180
  // DEBUG
  memset(dst, 0, buf.size);
  PADDLE_ENFORCE_EQ(
      0, cudaMemcpy(dst, buf.buffer, buf.size, cudaMemcpyDeviceToHost));
Y
Yan Chunwei 已提交
181 182
}

183
// Returns the buffer slot that belongs to the binding called `name`.
//
// The slot is selected by the engine's binding index for `name`. Fails an
// enforcement check when the engine has not been built, when `name` was never
// declared, when the engine has no binding with that name, or when the
// binding index falls outside buffers_.
Buffer &TensorRTEngine::buffer(const std::string &name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());

  const int slot_offset = infer_engine_->getBindingIndex(name.c_str());
  // getBindingIndex yields a negative value for an unknown name; indexing
  // buffers_ with it would read out of bounds.
  PADDLE_ENFORCE_GE(slot_offset, 0, "no TensorRT binding named %s", name);
  PADDLE_ENFORCE_LT(static_cast<size_t>(slot_offset), buffers_.size(),
                    "binding index of %s is out of range", name);
  return buffers_[slot_offset];
}

191
void TensorRTEngine::SetInputFromCPU(const std::string &name, const void *data,
Y
Yan Chunwei 已提交
192
                                     size_t size) {
193
  auto &buf = buffer(name);
Y
Yan Chunwei 已提交
194
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
195 196
  PADDLE_ENFORCE_NOT_NULL(data);
  PADDLE_ENFORCE_NOT_NULL(stream_);
Y
Yan Chunwei 已提交
197 198
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
199
  buf.size = size;
Y
Yan Chunwei 已提交
200 201
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, *stream_));
Y
Yan Chunwei 已提交
202 203
}

204
void TensorRTEngine::SetInputFromGPU(const std::string &name, const void *data,
205
                                     size_t size) {
206 207
  auto &buf = buffer(name);
  buf.size = size;
208 209 210 211 212 213 214
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, *stream_));
}

215 216
void TensorRTEngine::SetITensor(const std::string &name,
                                nvinfer1::ITensor *tensor) {
L
Luo Tao 已提交
217
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
218
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
219 220 221 222
                    name);
  itensor_map_[name] = tensor;
}

223
// Looks up the tensor registered under `name` in the ITensor map.
//
// Fails an enforcement check ("no ITensor <name>") when nothing is registered
// under that name.
nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) {
  const auto found = itensor_map_.find(name);
  PADDLE_ENFORCE(found != itensor_map_.end(), "no ITensor %s", name);
  return found->second;
}

Y
Yan Chunwei 已提交
228 229 230
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle