Unverified commit ceda0b9b, authored by Zhaolong Xing, committed by GitHub

[Fix BUG]: Core when multi thread + clone + paddle-trt (#22442)

* add mutex for trt engine
test=develop

* add the test for copy_to_cpu
test=develop
Parent 30320b33
@@ -138,7 +138,8 @@ void ZeroCopyTensor::copy_to_cpu(T *data) {
        static_cast<const platform::CUDADeviceContext *>(pool.Get(gpu_place));
    memory::Copy(platform::CPUPlace(), static_cast<void *>(data), gpu_place,
                 t_data, ele_num * sizeof(T), dev_ctx->stream());
    cudaDeviceSynchronize();
    cudaStreamSynchronize(dev_ctx->stream());
#else
    PADDLE_THROW("Not compile with CUDA, should not reach here.");
#endif
......
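The copy_to_cpu hunk above is about synchronization: memory::Copy only enqueues an asynchronous device-to-host transfer on dev_ctx->stream(), so the host must wait before reading the destination buffer, and waiting on that single stream is cheaper than stalling the whole device. Below is a minimal standalone sketch of the same pattern using raw CUDA runtime calls; it is not Paddle code, and the buffer names and sizes are illustrative only.

```cpp
// Sketch only: why an asynchronous device-to-host copy must be followed by a
// synchronization before the host reads the buffer. Compile as a .cu file.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  const int n = 1024;
  std::vector<float> host(n, 0.f);
  float *dev = nullptr;
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  cudaMalloc(&dev, n * sizeof(float));
  cudaMemsetAsync(dev, 0, n * sizeof(float), stream);

  // The copy is only *enqueued* here; it may still be in flight on return.
  cudaMemcpyAsync(host.data(), dev, n * sizeof(float), cudaMemcpyDeviceToHost,
                  stream);

  // cudaDeviceSynchronize() would also work, but it stalls every stream on the
  // device; waiting only on the stream that owns the copy is the narrower wait.
  cudaStreamSynchronize(stream);

  std::printf("first element: %f\n", host[0]);
  cudaFree(dev);
  cudaStreamDestroy(stream);
  return 0;
}
```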
@@ -38,13 +38,13 @@ void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
  const std::thread::id tid = std::this_thread::get_id();
  batch_size_ = batch_size;
  if (infer_context_.find(tid) == infer_context_.end()) {
    std::unique_lock<std::mutex> lock(mutex_);
    PADDLE_ENFORCE_NOT_NULL(
        infer_engine_,
        "You should build engine first and then set the context.");
    infer_context_[tid].reset(infer_engine_->createExecutionContext());
  }
  infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
  cudaStreamSynchronize(stream);
  SetRuntimeBatch(batch_size);
}
......
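The Execute hunk pairs a per-thread execution context (the infer_context_ map keyed by std::this_thread::get_id()) with the commit's new mutex, so predictors cloned into several threads cannot race while a context is being created. Below is a minimal sketch of that pattern in plain C++; Engine and FakeContext are made-up stand-ins for TensorRTEngine and nvinfer1::IExecutionContext, and the sketch holds the lock around both the lookup and the insertion, which is slightly more conservative than the diff (where the lock is taken inside the creation branch).

```cpp
// Sketch only: lazy, lock-guarded creation of one context per thread.
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

struct FakeContext {  // stand-in for nvinfer1::IExecutionContext
  int id;
};

class Engine {
 public:
  void Execute(int batch_size) {
    const std::thread::id tid = std::this_thread::get_id();
    FakeContext *ctx = nullptr;
    {
      // unordered_map insertion is not thread-safe, so any lookup that may
      // insert happens under the lock.
      std::unique_lock<std::mutex> lock(mutex_);
      auto it = contexts_.find(tid);
      if (it == contexts_.end()) {
        it = contexts_.emplace(tid, std::unique_ptr<FakeContext>(
                                        new FakeContext{next_id_++}))
                 .first;
      }
      ctx = it->second.get();
    }
    // Each thread drives only its own context, so the (potentially slow)
    // enqueue-like work runs outside the lock.
    std::cout << "batch " << batch_size << " ran on context " << ctx->id << "\n";
  }

 private:
  std::unordered_map<std::thread::id, std::unique_ptr<FakeContext>> contexts_;
  std::mutex mutex_;
  int next_id_ = 0;
};

int main() {
  Engine engine;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i)
    workers.emplace_back([&engine, i] { engine.Execute(i + 1); });
  for (auto &t : workers) t.join();
  return 0;
}
```

Keeping one context per thread also sidesteps sharing a single IExecutionContext across threads, which TensorRT does not allow; the lock is only contended on the first call from each thread.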
@@ -82,7 +82,7 @@ class TensorRTEngine {
  void Build(const DescType& paddle_model);
  void Execute(int batch_size, std::vector<void*>* buffers,
               cudaStream_t stream);            // before
               cudaStream_t stream = nullptr);  // after
  // Initialize the inference network, so that TensorRT layers can add to this
  // network.
@@ -216,6 +216,7 @@ class TensorRTEngine {
      infer_context_;
  infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
  std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
  std::mutex mutex_;
};  // class TensorRTEngine
#define IS_TRT_VERSION_GE(version) \
......
@@ -15,6 +15,7 @@ limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
@@ -44,6 +45,15 @@ TEST(quant_int8, resnet50) {
  input_t->copy_from_cpu(input);
  ASSERT_TRUE(predictor->ZeroCopyRun());
  std::vector<float> out_data;
  auto output_names = predictor->GetOutputNames();
  auto output_t = predictor->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  out_data.resize(out_num);
  output_t->copy_to_cpu(out_data.data());
}
}  // namespace inference
......
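The new test exercises the copy_to_cpu path end to end: run the predictor, size the output buffer from shape(), then copy the result to host memory. For context, here is a hedged sketch of the full ZeroCopy round trip with the C++ AnalysisConfig API of this Paddle generation; the model path, input shape, and GPU memory values are placeholders, while the tensor calls (copy_from_cpu, ZeroCopyRun, GetOutputNames, GetOutputTensor, shape, copy_to_cpu) mirror the test above.

```cpp
// Sketch only: end-to-end ZeroCopy inference. Paths, shapes, and config values
// are placeholders, not taken from the commit.
#include <functional>
#include <numeric>
#include <vector>
#include "paddle_inference_api.h"  // include path depends on the install layout

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("./resnet50_model");      // placeholder model directory
  config.EnableUseGpu(500 /* MB */, 0 /* device id */);
  config.SwitchUseFeedFetchOps(false);      // required for the ZeroCopy API
  auto predictor = paddle::CreatePaddlePredictor(config);

  // Feed one 1x3x224x224 image of zeros (placeholder data).
  std::vector<int> input_shape = {1, 3, 224, 224};
  int in_num = std::accumulate(input_shape.begin(), input_shape.end(), 1,
                               std::multiplies<int>());
  std::vector<float> input(in_num, 0.f);
  auto input_t = predictor->GetInputTensor(predictor->GetInputNames()[0]);
  input_t->Reshape(input_shape);
  input_t->copy_from_cpu(input.data());

  predictor->ZeroCopyRun();

  // Fetch the output the same way the new test does: size the host buffer
  // from shape(), then copy the data back to the CPU.
  auto output_t = predictor->GetOutputTensor(predictor->GetOutputNames()[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  output_t->copy_to_cpu(out_data.data());
  return 0;
}
```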