From bcd67bdd71b887cb55bacd0a76ffc136fb0416e4 Mon Sep 17 00:00:00 2001
From: nhzlx
Date: Tue, 24 Jul 2018 08:19:37 +0000
Subject: [PATCH] add assert for GetOutput

---
 paddle/fluid/inference/tensorrt/convert/ut_helper.h |  4 ++--
 paddle/fluid/inference/tensorrt/engine.cc           | 12 ++++++++----
 paddle/fluid/inference/tensorrt/engine.h            |  4 ++--
 paddle/fluid/inference/tensorrt/test_engine.cc      |  6 +++---
 paddle/fluid/operators/tensorrt_engine_op.h         |  7 +++++--
 5 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
index 2e6c895b2e..f14885b238 100644
--- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h
+++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -139,11 +139,11 @@ class TRTConvertValidation {
     cudaStreamSynchronize(*engine_->stream());
 
     ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
-    const size_t output_space_size = 200;
+    const size_t output_space_size = 2000;
     for (const auto& output : op_desc_->OutputArgumentNames()) {
       std::vector<float> fluid_out;
       std::vector<float> trt_out(output_space_size);
-      engine_->GetOutputInCPU(output, &trt_out[0]);
+      engine_->GetOutputInCPU(output, &trt_out[0], output_space_size);
       cudaStreamSynchronize(*engine_->stream());
 
       auto* var = scope_.FindVar(output);
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 5bef393b20..b821c3d0bf 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -1,7 +1,7 @@
 /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use
+this file except in compliance with the License.
 You may obtain a copy of the License at
 
 http://www.apache.org/licenses/LICENSE-2.0
@@ -149,7 +149,8 @@ void *TensorRTEngine::GetOutputInGPU(const std::string &name) {
   return buffer(name).buffer;
 }
 
-void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst) {
+void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst,
+                                    size_t max_size) {
   // determine data size
   auto *output = TensorRTEngine::GetITensor(name);
   nvinfer1::Dims dims = output->getDimensions();
@@ -161,6 +162,7 @@ void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst) {
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
   PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
   PADDLE_ENFORCE_EQ(cudaMemcpyAsync(dst, buf.buffer, dst_size,
@@ -168,7 +170,8 @@ void TensorRTEngine::GetOutputInGPU(const std::string &name, void *dst) {
                     0);
 }
 
-void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst) {
+void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst,
+                                    size_t max_size) {
   // determine data size
 
   auto *output = TensorRTEngine::GetITensor(name);
@@ -180,6 +183,7 @@ void TensorRTEngine::GetOutputInCPU(const std::string &name, void *dst) {
   PADDLE_ENFORCE(it != buffer_sizes_.end());
   PADDLE_ENFORCE_GT(it->second, 0);
   PADDLE_ENFORCE_LE(dst_size, it->second);
+  PADDLE_ENFORCE_GE(max_size, dst_size);
   auto &buf = buffer(name);
   PADDLE_ENFORCE_NOT_NULL(buf.buffer, "buffer should be allocated before");
   PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buf.buffer, dst_size,
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index b2b714d0c9..694468c419 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -106,10 +106,10 @@ class TensorRTEngine : public EngineBase {
   // Return the output's GPU memory address without copy.
   void* GetOutputInGPU(const std::string& name);
   // Copy data into dst inside the GPU device.
-  void GetOutputInGPU(const std::string& name, void* dst);
+  void GetOutputInGPU(const std::string& name, void* dst, size_t max_size);
   // LOW EFFICENCY! Get output to CPU, this will trigger a memory copy from GPU
   // to CPU.
-  void GetOutputInCPU(const std::string& name, void* dst);
+  void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
   // Fill an ITensor into map itensor_map_.
   void SetITensor(const std::string& name, nvinfer1::ITensor* tensor);
   // Get an ITensor called name.
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
index 0e2b00911f..f8732e51b6 100644
--- a/paddle/fluid/inference/tensorrt/test_engine.cc
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -71,7 +71,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
 
   LOG(INFO) << "to get output";
   float y_cpu;
-  engine_->GetOutputInCPU("y", &y_cpu);
+  engine_->GetOutputInCPU("y", &y_cpu, 1 * sizeof(float));
 
   LOG(INFO) << "to checkout output";
   ASSERT_EQ(y_cpu, x_v * 2 + 3);
@@ -108,7 +108,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
   ASSERT_EQ(dims.nbDims, 3);
   ASSERT_EQ(dims.d[0], 2);
   ASSERT_EQ(dims.d[1], 1);
-  engine_->GetOutputInCPU("y", &y_cpu[0]);
+  engine_->GetOutputInCPU("y", &y_cpu[0], 2 * sizeof(float));
   ASSERT_EQ(y_cpu[0], 4.5);
   ASSERT_EQ(y_cpu[1], 14.5);
 }
@@ -141,7 +141,7 @@ TEST_F(TensorRTEngineTest, test_conv2d_temp) {
 
   LOG(INFO) << "to get output";
   float* y_cpu = new float[18];
-  engine_->GetOutputInCPU("y", &y_cpu[0]);
+  engine_->GetOutputInCPU("y", &y_cpu[0], 18 * sizeof(float));
   ASSERT_EQ(y_cpu[0], 4.0);
   ASSERT_EQ(y_cpu[1], 6.0);
 }
diff --git a/paddle/fluid/operators/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt_engine_op.h
index 816665cf17..32d10fd8a5 100644
--- a/paddle/fluid/operators/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt_engine_op.h
@@ -100,8 +100,11 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
     // tensor.
     // if (platform::is_cpu_place(fluid_t->place())) {
     // TODO(Superjomn) change this float to dtype size.
-    engine->GetOutputInCPU(
-        y, fluid_t->mutable_data<float>(platform::CPUPlace()));
+    auto size = inference::analysis::AccuDims(dims.d, dims.nbDims) *
+                FLAGS_tensorrt_engine_batch_size;
+    engine->GetOutputInCPU(y,
+                           fluid_t->mutable_data<float>(platform::CPUPlace()),
+                           size * sizeof(float));
     //} else {
     //  engine->GetOutputInGPU(
     //      y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
-- 
GitLab
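
For context, a minimal caller-side sketch of how the new three-argument GetOutputInCPU is meant to be used: size the destination buffer from the output tensor's dimensions and pass that capacity in bytes, so the new PADDLE_ENFORCE_GE(max_size, dst_size) check can reject undersized buffers at the copy site instead of silently overrunning them. The helper name FetchToCPU and its batch-size handling are illustrative assumptions, not code from this patch; only GetITensor and GetOutputInCPU come from the engine's API above.

#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/engine.h"

// Hypothetical helper (not part of the patch): copy one named output to host
// memory with an explicit capacity check.
std::vector<float> FetchToCPU(paddle::inference::tensorrt::TensorRTEngine* engine,
                              const std::string& name, int batch_size) {
  // Multiply out the per-sample dims and scale by the batch size, mirroring
  // the AccuDims-based size computation in tensorrt_engine_op.h.
  nvinfer1::Dims dims = engine->GetITensor(name)->getDimensions();
  size_t count = static_cast<size_t>(batch_size);
  for (int i = 0; i < dims.nbDims; ++i) count *= dims.d[i];

  std::vector<float> out(count);
  // The third argument is the destination capacity in bytes; the engine now
  // asserts max_size >= dst_size before issuing the cudaMemcpyAsync.
  engine->GetOutputInCPU(name, out.data(), count * sizeof(float));
  return out;
}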