diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 196cf2f89fbc4f5336fc6b6616a274a279a97883..fa8dc8979056d31996619c20b704914a1915e010 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
 void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
                              cudaStream_t stream) {
   freshDeviceId();
+  const std::thread::id tid = std::this_thread::get_id();
   batch_size_ = batch_size;
-  infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
+  if (infer_context_.find(tid) == infer_context_.end()) {
+    PADDLE_ENFORCE_NOT_NULL(
+        infer_engine_,
+        "You should build engine first and then set the context.");
+    infer_context_[tid].reset(infer_engine_->createExecutionContext());
+  }
+  infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
   cudaStreamSynchronize(stream);
   SetRuntimeBatch(batch_size);
 }
@@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() {
 
   infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
   PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
-
-  infer_context_.reset(infer_engine_->createExecutionContext());
 }
 
 nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 012c9fbb23e5b899df3d5bb63d1bcbac1fe6eae1..19ec11017a42bfe0b83c4122f0d152934c3cd913 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -128,7 +128,6 @@ class TensorRTEngine {
         &inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
     PADDLE_ENFORCE(infer_engine_ != nullptr,
                    "build cuda engine failed when deserialize engine info.!");
-    infer_context_.reset(infer_engine_->createExecutionContext());
   }
 
   void SetRuntimeBatch(size_t batch_size);
@@ -200,7 +199,8 @@ class TensorRTEngine {
   infer_ptr<nvinfer1::IBuilder> infer_builder_;
   infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
   infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
-  infer_ptr<nvinfer1::IExecutionContext> infer_context_;
+  std::unordered_map<std::thread::id, infer_ptr<nvinfer1::IExecutionContext>>
+      infer_context_;
   infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
   std::unordered_map<std::string, float> quant_dynamic_range_;
 };  // class TensorRTEngine
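
For context: the patch above replaces the single shared nvinfer1::IExecutionContext with one context per calling thread, keyed by std::this_thread::get_id(). A built nvinfer1::ICudaEngine can serve several threads concurrently, but an IExecutionContext carries per-invocation state and must not be shared across threads, so each thread lazily creates its own context on its first Execute() call. Below is a minimal standalone sketch of the same pattern, not Paddle's actual code: MyEngine, GetThreadLocalContext, and the map/mutex names are illustrative, and it uses the pre-TensorRT-8 enqueue()/destroy() API. Unlike the patch, the sketch guards map insertion with a mutex, since std::unordered_map is not safe to mutate concurrently.

// Standalone sketch of a per-thread TensorRT execution context.
// Assumptions: TensorRT < 8 headers available; MyEngine and its members
// are hypothetical names introduced for illustration.
#include <NvInfer.h>
#include <cuda_runtime_api.h>

#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

class MyEngine {
 public:
  // Takes a fully built engine; building/deserialization happens elsewhere.
  explicit MyEngine(nvinfer1::ICudaEngine *engine) : engine_(engine) {}

  // Each thread reuses its own IExecutionContext; the ICudaEngine is shared.
  void Execute(int batch_size, void **buffers, cudaStream_t stream) {
    nvinfer1::IExecutionContext *ctx = GetThreadLocalContext();
    ctx->enqueue(batch_size, buffers, stream, nullptr);
    cudaStreamSynchronize(stream);
  }

 private:
  // Pre-TensorRT-8 objects are released via destroy(), not delete.
  struct ContextDeleter {
    void operator()(nvinfer1::IExecutionContext *c) const {
      if (c) c->destroy();
    }
  };
  using ContextPtr =
      std::unique_ptr<nvinfer1::IExecutionContext, ContextDeleter>;

  // Lazily creates this thread's context on first use. The lock only
  // protects the map; enqueue() itself runs outside the critical section.
  nvinfer1::IExecutionContext *GetThreadLocalContext() {
    const std::thread::id tid = std::this_thread::get_id();
    std::lock_guard<std::mutex> lock(ctx_mutex_);
    auto it = ctx_map_.find(tid);
    if (it == ctx_map_.end()) {
      it = ctx_map_
               .emplace(tid, ContextPtr(engine_->createExecutionContext()))
               .first;
    }
    return it->second.get();
  }

  nvinfer1::ICudaEngine *engine_;  // shared, thread-safe for execution
  std::mutex ctx_mutex_;
  std::unordered_map<std::thread::id, ContextPtr> ctx_map_;
};

With this layout, each worker thread can call Execute() on its own CUDA stream without contending for a context; contexts accumulate one per thread and are released together when the map is destroyed. The trade-off mirrors the patch: extra device memory per thread in exchange for lock-free enqueue on the hot path.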