提交 81903811 编写于 作者: Y Yan Chunwei 提交者: Tao Luo

Feature/engine refactor (#10497)

* init refactor

* init

* update some comment

* fix build

* fix errorrr

* fix bug

* fix comment

* update
上级 c7c62e07
......@@ -19,6 +19,9 @@ limitations under the License. */
namespace paddle {
namespace inference {
struct Buffer;
enum class DeviceType { UNK = -1, CPU, GPU };
/*
* EngineBase is the base class of all inference engines. An inference engine
* takes a paddle program as input, and outputs the result in fluid Tensor
......@@ -45,8 +48,20 @@ class EngineBase {
// Execute the engine, that will run the inference network.
virtual void Execute(int batch_size) = 0;
// Return the IO buffer that allocated in engine. One can read/write directly
// on the buffer. If the buffer's buffer is nullptr, one can also allocate
// memory and maintain it outside the engine.
virtual Buffer& buffer(const std::string& name) = 0;
virtual ~EngineBase() {}
}; // class EngineBase
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
int max_size; // buffer allocated space.
int size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
} // namespace inference
} // namespace paddle
nv_library(tensorrt_engine SRCS engine.cc)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
set(ENGINE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/engine.cc)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(convert)
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
nv_test(test_trt_activation_op SRCS test_activation_op.cc ${ENGINE_FILE} activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op)
nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
......@@ -30,16 +30,24 @@ void TensorRTEngine::Build(const DescType& paddle_model) {
}
void TensorRTEngine::Execute(int batch_size) {
infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
std::vector<void*> buffers;
for (auto& buf : buffers_) {
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated");
PADDLE_ENFORCE_GT(buf.max_size, 0);
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
buffers.push_back(buf.buffer);
}
infer_context_->enqueue(batch_size, buffers.data(), *stream_, nullptr);
cudaStreamSynchronize(*stream_);
}
TensorRTEngine::~TensorRTEngine() {
// clean buffer
for (auto& buffer : buffers_) {
if (buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
buffer = nullptr;
for (auto& buf : buffers_) {
if (buf.buffer != nullptr) {
PADDLE_ENFORCE_EQ(0, cudaFree(buf.buffer));
buf.buffer = nullptr;
buf.max_size = 0;
}
}
}
......@@ -59,7 +67,7 @@ void TensorRTEngine::FreezeNetwork() {
infer_context_.reset(infer_engine_->createExecutionContext());
// allocate GPU buffers.
buffers_.resize(buffer_sizes_.size(), nullptr);
buffers_.resize(buffer_sizes_.size());
for (auto& item : buffer_sizes_) {
if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
......@@ -67,7 +75,11 @@ void TensorRTEngine::FreezeNetwork() {
infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset));
}
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
PADDLE_ENFORCE_EQ(0, cudaMalloc(&buf.buffer, item.second));
buf.size = buf.max_size = item.second;
buf.device = DeviceType::GPU;
}
}
......@@ -113,7 +125,7 @@ void TensorRTEngine::DeclareOutput(const std::string& name) {
}
void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
return buffer(name);
return buffer(name).buffer;
}
void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
......@@ -123,11 +135,13 @@ void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
PADDLE_ENFORCE(it != buffer_sizes_.end());
PADDLE_ENFORCE_GT(it->second, 0);
PADDLE_ENFORCE_GE(max_size, it->second);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, it->second,
cudaMemcpyDeviceToHost, *stream_));
}
void*& TensorRTEngine::buffer(const std::string& name) {
Buffer& TensorRTEngine::buffer(const std::string& name) {
PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
auto it = buffer_sizes_.find(name);
PADDLE_ENFORCE(it != buffer_sizes_.end());
......@@ -137,10 +151,12 @@ void*& TensorRTEngine::buffer(const std::string& name) {
void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
size_t size) {
void* buf = buffer(name);
cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_);
PADDLE_ENFORCE_EQ(
0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
auto& buf = buffer(name);
PADDLE_ENFORCE_NOT_NULL(buf.buffer);
PADDLE_ENFORCE_LE(size, buf.max_size, "buffer is too small");
PADDLE_ENFORCE(buf.device == DeviceType::GPU);
PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(buf.buffer, data, size,
cudaMemcpyHostToDevice, *stream_));
}
void TensorRTEngine::SetITensor(const std::string& name,
......
......@@ -87,7 +87,9 @@ class TensorRTEngine : public EngineBase {
// these memory directly for acceleration, for example, output the converted
// data directly to the buffer to save data copy overhead.
// NOTE this should be used after calling `FreezeNetwork`.
void*& buffer(const std::string& name);
Buffer& buffer(const std::string& name) override;
cudaStream_t* stream() { return stream_; }
// Fill an input from CPU memory with name and size.
void SetInputFromCPU(const std::string& name, void* data, size_t size);
......@@ -116,7 +118,7 @@ class TensorRTEngine : public EngineBase {
cudaStream_t* stream_;
nvinfer1::ILogger& logger_;
std::vector<void*> buffers_;
std::vector<Buffer> buffers_;
// max data size for the buffers.
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册