From 7f5d532a9c4d2b7355e848d4f7ce847782207eef Mon Sep 17 00:00:00 2001 From: rensilin <752318213@qq.com> Date: Tue, 10 Dec 2019 16:34:24 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20fail=20to=20call=20ZeroCopyTensor::mutab?= =?UTF-8?q?le=5Fdata()=20when=20device=5Fid=20is=20no=E2=80=A6=20(#21461)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ZeroCopyTensor::mutable_data in the right device, test=develop * add unittest for zerocopy, test=develop --- paddle/fluid/inference/api/details/zero_copy_tensor.cc | 2 +- paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 59ad2c09c0..271b0fcbb7 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { return tensor->mutable_data(platform::CPUPlace()); } case static_cast(PaddlePlace::kGPU): { - return tensor->mutable_data(platform::CUDAPlace()); + return tensor->mutable_data(platform::CUDAPlace(device_)); } default: PADDLE_THROW("Unsupported place: %d", static_cast(place)); diff --git a/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc index cb00c9c21c..37a443e0f6 100644 --- a/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc +++ b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc @@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) { input_t->Reshape({batch_size, length}); input_t->copy_from_cpu(input); input_t->type(); + input_t->mutable_data(PaddlePlace::kGPU); ASSERT_TRUE(predictor->ZeroCopyRun()); } -- GitLab