未验证 提交 e228e707 编写于 作者: 石晓伟 提交者: GitHub

fix ZeroCopyTensor::mutable_data(), test=release/1.6 (#21581)

上级 0a4002f5
...@@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) { ...@@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
return tensor->mutable_data<T>(platform::CPUPlace()); return tensor->mutable_data<T>(platform::CPUPlace());
} }
case static_cast<int>(PaddlePlace::kGPU): { case static_cast<int>(PaddlePlace::kGPU): {
return tensor->mutable_data<T>(platform::CUDAPlace()); return tensor->mutable_data<T>(platform::CUDAPlace(device_));
} }
default: default:
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place)); PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));
......
...@@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) { ...@@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) {
input_t->Reshape({batch_size, length}); input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input); input_t->copy_from_cpu(input);
input_t->type(); input_t->type();
input_t->mutable_data<uint8_t>(PaddlePlace::kGPU);
ASSERT_TRUE(predictor->ZeroCopyRun()); ASSERT_TRUE(predictor->ZeroCopyRun());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册