未验证 提交 e228e707 编写于 作者: 石晓伟 提交者: GitHub

fix ZeroCopyTensor::mutable_data(), test=release/1.6 (#21581)

上级 0a4002f5
......@@ -52,7 +52,7 @@ T *ZeroCopyTensor::mutable_data(PaddlePlace place) {
return tensor->mutable_data<T>(platform::CPUPlace());
}
case static_cast<int>(PaddlePlace::kGPU): {
return tensor->mutable_data<T>(platform::CUDAPlace());
return tensor->mutable_data<T>(platform::CUDAPlace(device_));
}
default:
PADDLE_THROW("Unsupported place: %d", static_cast<int>(place));
......
......@@ -51,6 +51,7 @@ TEST(ZeroCopyTensor, uint8) {
input_t->Reshape({batch_size, length});
input_t->copy_from_cpu(input);
input_t->type();
input_t->mutable_data<uint8_t>(PaddlePlace::kGPU);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册