提交 7f5d532a 编写于 作者: R rensilin 提交者: 石晓伟

fix: fail to call ZeroCopyTensor::mutable_data() when device_id is no… (#21461)

* ZeroCopyTensor::mutable_data in the right device, test=develop

* add unittest for zerocopy, test=develop
上级 1d6f0b40
......@@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
return tensor->mutable_data<T>(platform::CPUPlace());
}
case static_cast<int>(PaddlePlace::kGPU): {
return tensor->mutable_data<T>(platform::CUDAPlace());
return tensor->mutable_data<T>(platform::CUDAPlace(device_));
}
default:
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));
......
......@@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) {
input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input);
input_t->type();
input_t->mutable_data<uint8_t>(PaddlePlace::kGPU);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册