From dd3c24229c69a2ca54eab9e4a76c2de810ca11c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Wed, 28 Aug 2019 14:17:32 +0800 Subject: [PATCH] fix multi-thread exec of trt, test=develop (#19379) --- paddle/fluid/inference/tensorrt/engine.cc | 11 ++++++++--- paddle/fluid/inference/tensorrt/engine.h | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 196cf2f89fb..fa8dc897905 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) { void TensorRTEngine::Execute(int batch_size, std::vector *buffers, cudaStream_t stream) { freshDeviceId(); + const std::thread::id tid = std::this_thread::get_id(); batch_size_ = batch_size; - infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr); + if (infer_context_.find(tid) == infer_context_.end()) { + PADDLE_ENFORCE_NOT_NULL( + infer_engine_, + "You should build engine first and then set the context."); + infer_context_[tid].reset(infer_engine_->createExecutionContext()); + } + infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr); cudaStreamSynchronize(stream); SetRuntimeBatch(batch_size); } @@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() { infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); - - infer_context_.reset(infer_engine_->createExecutionContext()); } nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 012c9fbb23e..19ec11017a4 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -128,7 +128,6 @@ class TensorRTEngine { &inference::Singleton::Global())); PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed when deserialize engine info.!"); - infer_context_.reset(infer_engine_->createExecutionContext()); } void SetRuntimeBatch(size_t batch_size); @@ -200,7 +199,8 @@ class TensorRTEngine { infer_ptr infer_builder_; infer_ptr infer_network_; infer_ptr infer_engine_; - infer_ptr infer_context_; + std::unordered_map> + infer_context_; infer_ptr ihost_memory_; std::unordered_map quant_dynamic_range_; }; // class TensorRTEngine -- GitLab