未验证 提交 dd3c2422 编写于 作者: 石晓伟 提交者: GitHub

fix multi-thread exec of trt, test=develop (#19379)

上级 9048229b
......@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
cudaStream_t stream) {
freshDeviceId();
const std::thread::id tid = std::this_thread::get_id();
batch_size_ = batch_size;
infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
if (infer_context_.find(tid) == infer_context_.end()) {
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
"You should build engine first and then set the context.");
infer_context_[tid].reset(infer_engine_->createExecutionContext());
}
infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
cudaStreamSynchronize(stream);
SetRuntimeBatch(batch_size);
}
......@@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() {
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,
......
......@@ -128,7 +128,6 @@ class TensorRTEngine {
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
PADDLE_ENFORCE(infer_engine_ != nullptr,
"build cuda engine failed when deserialize engine info.!");
infer_context_.reset(infer_engine_->createExecutionContext());
}
void SetRuntimeBatch(size_t batch_size);
......@@ -200,7 +199,8 @@ class TensorRTEngine {
infer_ptr<nvinfer1::IBuilder> infer_builder_;
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
infer_ptr<nvinfer1::IExecutionContext> infer_context_;
std::unordered_map<std::thread::id, infer_ptr<nvinfer1::IExecutionContext>>
infer_context_;
infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
}; // class TensorRTEngine
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册