engine.cc 7.6 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/engine.h"

#include <NvInfer.h>
#include <cuda.h>
#include <glog/logging.h>
A
Abhinav Arora 已提交
20
#include <string>
Y
Yan Chunwei 已提交
21
#include "paddle/fluid/inference/analysis/helper.h"
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28 29 30 31 32 33
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Builds the engine from a Paddle model description.
//
// This entry point is a placeholder: `paddle_model` is never read and every
// call fails the enforce check below with the message "not implemented".
void TensorRTEngine::Build(const DescType& paddle_model) {
  PADDLE_ENFORCE(false, "not implemented");
}

void TensorRTEngine::Execute(int batch_size) {
Y
Yan Chunwei 已提交
34 35 36 37 38 39 40 41
  std::vector<void*> buffers;
  for (auto& buf : buffers_) {
    PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
    PADDLE_ENFORCE_GT(buf.max_size, 0);
    PADDLE_ENFORCE(buf.device == DeviceType::GPU);
    buffers.push_back(buf.buffer);
  }
  infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
Y
Yan Chunwei 已提交
42 43 44 45 46
  cudaStreamSynchronize(*stream_);
}

// Frees every allocated entry of buffers_ and resets its bookkeeping.
//
// A failing cudaFree is logged instead of being enforced: an exception that
// escapes a destructor terminates the program, and bailing out on the first
// failure would also leave the remaining buffers unreleased.
TensorRTEngine::~TensorRTEngine() {
  for (auto& buf : buffers_) {
    if (buf.buffer == nullptr) {
      continue;
    }
    const cudaError_t status = cudaFree(buf.buffer);
    if (status != cudaSuccess) {
      LOG(ERROR) << "cudaFree failed while destroying TensorRTEngine: "
                 << cudaGetErrorString(status);
    }
    buf.buffer = nullptr;
    buf.max_size = 0;
  }
}

void TensorRTEngine::FreezeNetwork() {
  PADDLE_ENFORCE(infer_builder_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  PADDLE_ENFORCE(infer_network_ != nullptr,
                 "Call InitNetwork first to initialize network.");
  // build engine.
  infer_builder_->setMaxBatchSize(max_batch_);
  infer_builder_->setMaxWorkspaceSize(max_workspace_);

  infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
  PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");

  infer_context_.reset(infer_engine_->createExecutionContext());

  // allocate GPU buffers.
Y
Yan Chunwei 已提交
71
  buffers_.resize(buffer_sizes_.size());
Y
Yan Chunwei 已提交
72 73 74
  for (auto& item : buffer_sizes_) {
    if (item.second == 0) {
      auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
Y
Yan Chunwei 已提交
75
      auto dims = infer_engine_->getBindingDimensions(slot_offset);
Y
Yan Chunwei 已提交
76 77
      item.second = kDataTypeSize[static_cast<int>(
                        infer_engine_->getBindingDataType(slot_offset))] *
Y
Yan Chunwei 已提交
78
                    analysis::AccuDims(dims.d, dims.nbDims);
Y
Yan Chunwei 已提交
79
    }
Y
Yan Chunwei 已提交
80 81 82 83 84
    auto& buf = buffer(item.first);
    CHECK(buf.buffer == nullptr);  // buffer should be allocated only once.
    PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
    buf.size = buf.max_size = item.second;
    buf.device = DeviceType::GPU;
Y
Yan Chunwei 已提交
85 86 87 88 89
  }
}

// Adds a network input called `name` with the given data type and dimensions.
//
// Records the input's byte size (element size times the product of `dims`)
// in buffer_sizes_, registers the tensor under `name` via SetITensor and
// returns it. Fails if `name` was already declared, if the network has not
// been initialized, or if TensorRT refuses to add the input.
nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
                                                nvinfer1::DataType dtype,
                                                const nvinfer1::Dims& dims) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
                    name);
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  nvinfer1::ITensor* input =
      infer_network_->addInput(name.c_str(), dtype, dims);
  PADDLE_ENFORCE(input, "infer network add input %s failed", name);

  // bytes = size of one element * number of elements described by dims.
  const auto bytes_per_element = kDataTypeSize[static_cast<int>(dtype)];
  const auto element_count = analysis::AccuDims(dims.d, dims.nbDims);
  buffer_sizes_[name] = bytes_per_element * element_count;

  TensorRTEngine::SetITensor(name, input);
  return input;
}

// Marks output number `offset` of `layer` as a network output called `name`.
//
// The tensor is registered under `name` via SetITensor, renamed, and marked
// as an output of infer_network_. Its buffer size is stored as 0 because it is
// only known once the engine has been built. Fails if `name` is already
// declared, if the network or `layer` is null, or if the layer has no output
// at `offset`.
void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
                                   const std::string& name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  // Validate everything that is dereferenced below before mutating any state.
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
  PADDLE_ENFORCE(layer != nullptr, "layer of output %s should not be null",
                 name);

  auto* output = layer->getOutput(offset);
  PADDLE_ENFORCE(output != nullptr, "layer has no output at offset %d",
                 offset);
  SetITensor(name, output);
  output->setName(name.c_str());
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided later, set zero here to mark this
  // and will reset later.
  buffer_sizes_[name] = 0;
}

L
Luo Tao 已提交
118 119 120 121 122 123 124 125 126 127 128 129 130
// Marks the already-registered tensor `name` as a network output.
//
// The tensor is looked up with GetITensor, renamed to `name` and marked as an
// output of infer_network_. Its buffer size is stored as 0 because it is only
// known once the engine has been built. Fails if `name` is already declared as
// a buffer, if the network is null, or if no tensor is registered as `name`.
void TensorRTEngine::DeclareOutput(const std::string& name) {
  PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
                    name);
  // infer_network_ is dereferenced below; check it like DeclareInput does.
  PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");

  auto* output = TensorRTEngine::GetITensor(name);
  PADDLE_ENFORCE(output != nullptr);
  output->setName(name.c_str());
  infer_network_->markOutput(*output);
  // output buffers' size can only be decided later, set zero here to mark this
  // and will reset later.
  buffer_sizes_[name] = 0;
}

Y
Yan Chunwei 已提交
131
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
Y
Yan Chunwei 已提交
132
  return buffer(name).buffer;
Y
Yan Chunwei 已提交
133 134
}

135 136 137 138 139 140 141 142 143 144 145 146 147 148
void TensorRTEngine::GetOutputInGPU(const std::string& name, void* dst,
                                    size_t max_size) {
  // determine data size
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_GE(max_size, it->second);
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
  PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, it->second,
                                    cudaMemcpyDeviceToDevice, *stream_),
                    0);
}

Y
Yan Chunwei 已提交
149 150 151 152 153 154 155
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
                                    size_t max_size) {
  // determine data size
  auto it = buffer_sizes_.find(name);
  PADDLE_ENFORCE(it != buffer_sizes_.end());
  PADDLE_ENFORCE_GT(it->second, 0);
  PADDLE_ENFORCE_GE(max_size, it->second);
Y
Yan Chunwei 已提交
156 157 158
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second,
Y
Yan Chunwei 已提交
159 160 161
                                       cudaMemcpyDeviceToHost, *stream_));
}

Y
Yan Chunwei 已提交
162
// Returns the entry of buffers_ that backs the engine binding called `name`.
//
// Requires the engine to be built (FreezeNetwork) and `name` to be declared in
// buffer_sizes_. The slot is the engine's binding index for `name`; it is
// range-checked because the engine reports a negative index for names it does
// not know, which would otherwise index buffers_ out of bounds.
Buffer& TensorRTEngine::buffer(const std::string& name) {
  PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
  PADDLE_ENFORCE(buffer_sizes_.count(name) != 0, "no buffer declared for %s",
                 name);
  const int slot_offset = infer_engine_->getBindingIndex(name.c_str());
  PADDLE_ENFORCE(
      slot_offset >= 0 && static_cast<size_t>(slot_offset) < buffers_.size(),
      "no valid TensorRT binding for %s", name);
  return buffers_[slot_offset];
}

170
void TensorRTEngine::SetInputFromCPU(const std::string& name, const void* data,
Y
Yan Chunwei 已提交
171
                                     size_t size) {
Y
Yan Chunwei 已提交
172 173 174 175 176 177
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyHostToDevice, *stream_));
Y
Yan Chunwei 已提交
178 179
}

180 181 182 183 184 185 186 187 188 189
void TensorRTEngine::SetInputFromGPU(const std::string& name, const void* data,
                                     size_t size) {
  auto& buf = buffer(name);
  PADDLE_ENFORCE_NOT_NULL(buf.buffer);
  PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
  PADDLE_ENFORCE(buf.device == DeviceType::GPU);
  PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
                                       cudaMemcpyDeviceToDevice, *stream_));
}

L
Luo Tao 已提交
190 191 192
void TensorRTEngine::SetITensor(const std::string& name,
                                nvinfer1::ITensor* tensor) {
  PADDLE_ENFORCE(tensor != nullptr);
Y
Yan Chunwei 已提交
193
  PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
L
Luo Tao 已提交
194 195 196 197 198
                    name);
  itensor_map_[name] = tensor;
}

// Returns the tensor registered under `name` in itensor_map_.
// Fails with "no ITensor <name>" if nothing was registered under that name.
nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
  PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
  // Presence was just verified, so the iterator is valid to dereference.
  return itensor_map_.find(name)->second;
}

Y
Yan Chunwei 已提交
203 204 205
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle