From e228e7079fe1b42857b025afdc50e8f4aad63941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9F=B3=E6=99=93=E4=BC=9F?= <39303645+Shixiaowei02@users.noreply.github.com> Date: Fri, 6 Dec 2019 12:14:14 +0800 Subject: [PATCH] fix ZeroCopyTensor::mutable_data(), test=release/1.6 (#21581) --- paddle/fluid/inference/api/details/zero_copy_tensor.cc | 2 +- paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 59ad2c09c0f..271b0fcbb72 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { return tensor->mutable_data(platform::CPUPlace()); } case static_cast(PaddlePlace::kGPU): { - return tensor->mutable_data(platform::CUDAPlace()); + return tensor->mutable_data(platform::CUDAPlace(device_)); } default: PADDLE_THROW("Unsupported place: %d", static_cast(place)); diff --git a/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc index cb00c9c21c8..37a443e0f69 100644 --- a/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc +++ b/paddle/fluid/inference/tests/api/trt_fc_prelu_test.cc @@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) { input_t->Reshape({batch_size, length}); input_t->copy_from_cpu(input); input_t->type(); + input_t->mutable_data(PaddlePlace::kGPU); ASSERT_TRUE(predictor->ZeroCopyRun()); } -- GitLab