engine.cc 8.7 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
A
Abhinav Arora 已提交
20
#include <string>
Y
Yan Chunwei 已提交
21
#include "paddle/fluid/inference/analysis/helper.h"
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

29 30
// Batch size passed to the most recent Execute() call (recorded through
// SetRuntimeBatch). This is a static data member, so the single value is
// shared by every TensorRTEngine instance in the process.
int TensorRTEngine::runtime_batch_{1};

Y
Yan Chunwei 已提交
31 32 33 34 35
// Building an engine straight from a Paddle model description is not
// supported; every call fails with an enforce error. Networks are instead
// assembled through InitNetwork/DeclareInput/DeclareOutput/FreezeNetwork.
void TensorRTEngine::Build(const DescType& /*paddle_model*/) {
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
Y
Yan Chunwei 已提交
36 37 38 39 40 41 42 43
  std::vector<void*> buffers;
  for (auto& buf : buffers_) {
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
Y
Yan Chunwei 已提交
44
  cudaStreamSynchronize(*stream_);
45
  SetRuntimeBatch(batch_size);
Y
Yan Chunwei 已提交
46 47 48
}

// Waits for pending work on the engine's stream, then frees the GPU buffers
// that FreezeNetwork() allocated.
//
// Destructors are implicitly noexcept (unless declared otherwise), so throwing
// from here via PADDLE_ENFORCE would call std::terminate. CUDA failures are
// therefore logged, and cleanup continues with the remaining buffers.
TensorRTEngine::~TensorRTEngine() {
  cudaError_t status = cudaStreamSynchronize(*stream_);
  if (status != cudaSuccess) {
    LOG(ERROR) << "cudaStreamSynchronize failed in ~TensorRTEngine: "
               << cudaGetErrorString(status);
  }
  // clean buffer
  for (auto& buf : buffers_) {
    if (buf.device == DeviceType::GPU && buf.buffer != nullptr) {
      status = cudaFree(buf.buffer);
      if (status != cudaSuccess) {
        LOG(ERROR) << "cudaFree failed in ~TensorRTEngine: "
                   << cudaGetErrorString(status);
      }
      buf.buffer = nullptr;
      buf.max_size = 0;
    }
  }
}

void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
75
  buffers_.resize(buffer_sizes_.size());
Y
Yan Chunwei 已提交
76 77 78
  for (auto& item : buffer_sizes_) {
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
79
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
80 81
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
82
                    analysis::AccuDims(dims.d, dims.nbDims) * max_batch_;
Y
Yan Chunwei 已提交
83
    }
Y
Yan Chunwei 已提交
84 85
    auto& buf = buffer(item.first);
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
86
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second * max_batch_));
87 88
    VLOG(4) << "buffer malloc " << item.first << " " << item.second << " "
            << buf.buffer;
89 90
    buf.size = item.second;
    buf.max_size = item.second * max_batch_;
Y
Yan Chunwei 已提交
91
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
92 93 94 95 96
  }
}

// Adds a network input named `name` with element type `dtype` and dims
// `dims`, registers it as an ITensor, and records the byte size its GPU
// buffer will need so FreezeNetwork() can allocate it.
//
// The recorded size is element size * product(dims) * max_batch_, so `dims`
// presumably excludes the batch dimension (TODO confirm against callers).
// Fails on a duplicate name or if InitNetwork() has not been called.
nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims& dims) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  auto* input = infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(input, "infer network add input %s failed", name);

  // Bytes needed for a full max-size batch of this input.
  const auto element_bytes = kDataTypeSize[static_cast<int>(dtype)];
  const auto sample_elements = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = element_bytes * sample_elements * max_batch_;

  PADDLE_ENFORCE(input->isNetworkInput());
  TensorRTEngine::SetITensor(name, input);
  return input;
}

// Marks output number `offset` of `layer` as a network output named `name`
// and registers it as an ITensor.
//
// The tensor is validated before it is registered, so a null layer or an
// out-of-range `offset` is reported directly instead of surfacing from inside
// SetITensor. The output's buffer size is recorded as 0 and resolved later by
// FreezeNetwork().
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                                   const std::string& name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  PADDLE_ENFORCE(layer != nullptr, "layer for output %s is null", name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  // getOutput() returns nullptr when `offset` is out of range.
  auto* output = layer->getOutput(offset);
  PADDLE_ENFORCE(output != nullptr, "layer has no output #%d for %s", offset,
                 name);
  SetITensor(name, output);

  output->setName(name.c_str());
  PADDLE_ENFORCE(!output->isNetworkInput());
  infer_network_->markOutput(*output);
  PADDLE_ENFORCE(output->isNetworkOutput());
  // output buffers' size can only be decided later, set zero here to mark this
  // and will reset later.
  buffer_sizes_[name] = 0;
}

L
Luo Tao 已提交
128 129 130 131 132 133 134
void TensorRTEngine::DeclareOutput(const std::string& name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);

  auto* output = TensorRTEngine::GetITensor(name);
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
135
  PADDLE_ENFORCE(!output->isNetworkInput());
L
Luo Tao 已提交
136 137 138 139 140 141
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided latter, set zero here to mark this
  // and will reset latter.
  buffer_sizes_[name] = 0;
}

Y
Yan Chunwei 已提交
142
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
Y
Yan Chunwei 已提交
143
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
144 145
}

146
void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst) {
147
  // determine data size
148 149 150 151 152 153
  auto* output = TensorRTEngine::GetITensor(name);
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];

154 155 156
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
157
  PADDLE_ENFORCE_LE(dst_size, it->second);
158 159
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
160
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
161 162 163 164
                                    cudaMemcpyDeviceToDevice, *stream_),
                    0);
}

165
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst) {
Y
Yan Chunwei 已提交
166
  // determine data size
167 168 169 170 171 172

  auto* output = TensorRTEngine::GetITensor(name);
  nvinfer1::Dims dims = output->getDimensions();
  auto dim_size = analysis::AccuDims(dims.d, dims.nbDims);
  size_t dst_size = dim_size * runtime_batch_ *
                    kDataTypeSize[static_cast<int>(output->getType())];
Y
Yan Chunwei 已提交
173 174 175
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
176
  PADDLE_ENFORCE_LE(dst_size, it->second);
Y
Yan Chunwei 已提交
177 178
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
179
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
Y
Yan Chunwei 已提交
180 181 182
                                       cudaMemcpyDeviceToHost, *stream_));
}

Y
Yan Chunwei 已提交
183
// Returns the Buffer record for the declared input/output `name`.
//
// `buffers_` is indexed by the TensorRT binding slot of the name. The engine
// must be built (FreezeNetwork) and the name declared. getBindingIndex()
// returns -1 for names that are not bindings of the built engine, so the slot
// is range-checked before indexing instead of reading out of bounds.
Buffer& TensorRTEngine::buffer(const std::string& name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
  PADDLE_ENFORCE(slot_offset >= 0 &&
                     static_cast<size_t>(slot_offset) < buffers_.size(),
                 "no TensorRT binding slot for %s", name);
  return buffers_[slot_offset];
}

191
void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
Y
Yan Chunwei 已提交
192
                                     size_t size) {
Y
Yan Chunwei 已提交
193 194 195 196 197 198
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, *stream_));
Y
Yan Chunwei 已提交
199 200
}

201 202 203 204 205 206 207 208 209 210
void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
                                     size_t size) {
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, *stream_));
}

L
Luo Tao 已提交
211 212 213
void TensorRTEngine::SetITensor(const std::string& name,
                                nvinfer1::ITensor* tensor) {
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
214
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
215 216 217 218 219
                    name);
  itensor_map_[name] = tensor;
}

// Looks up the ITensor registered under `name`; fails if none exists.
nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
  PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
  // Presence is enforced above, so at() cannot throw here.
  return itensor_map_.at(name);
}

224 225 226 227 228 229
// Records `batch_size` as the current runtime batch. The value is stored in
// the static int member `runtime_batch_`, so it is shared by all engines.
void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
  TensorRTEngine::runtime_batch_ = static_cast<int>(batch_size);
}

// Returns the batch size last recorded by SetRuntimeBatch().
int TensorRTEngine::GetRuntimeBatch() {
  return TensorRTEngine::runtime_batch_;
}

Y
Yan Chunwei 已提交
230 231 232
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle