未验证 提交 dd3c2422 编写于 作者: 石晓伟 提交者: GitHub

fix multi-thread exec of trt, test=develop (#19379)

上级 9048229b
...@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) { ...@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers, void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
cudaStream_t stream) { cudaStream_t stream) {
freshDeviceId(); freshDeviceId();
const std::thread::id tid = std::this_thread::get_id();
batch_size_ = batch_size; batch_size_ = batch_size;
infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr); if (infer_context_.find(tid) == infer_context_.end()) {
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
"You should build engine first and then set the context.");
infer_context_[tid].reset(infer_engine_->createExecutionContext());
}
infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
cudaStreamSynchronize(stream); cudaStreamSynchronize(stream);
SetRuntimeBatch(batch_size); SetRuntimeBatch(batch_size);
} }
...@@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() {
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
infer_context_.reset(infer_engine_->createExecutionContext());
} }
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
......
...@@ -128,7 +128,6 @@ class TensorRTEngine { ...@@ -128,7 +128,6 @@ class TensorRTEngine {
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global())); &inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
PADDLE_ENFORCE(infer_engine_ != nullptr, PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!"); "build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
} }
void SetRuntimeBatch(size_t batch_size); void SetRuntimeBatch(size_t batch_size);
...@@ -200,7 +199,8 @@ class TensorRTEngine { ...@@ -200,7 +199,8 @@ class TensorRTEngine {
infer_ptr<nvinfer1::IBuilder> infer_builder_; infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_; infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_; infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_; std::unordered_map<std::thread::id, infer_ptr<nvinfer1::IExecutionContext>>
infer_context_;
infer_ptr<nvinfer1::IHostMemory> ihost_memory_; infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_; std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
}; // class TensorRTEngine }; // class TensorRTEngine
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册