未验证 提交 2ac6c6c3 编写于 作者: Z Zhou Wei 提交者: GitHub

fix bug of tensor copy of CUDAPinnedPlace (#27966)

上级 f58434ef
......@@ -84,6 +84,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
......@@ -285,6 +291,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
}
#endif
#ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
......
......@@ -141,7 +141,7 @@ class TestVarBase(unittest.TestCase):
_test_place(core.CPUPlace())
if core.is_compiled_with_cuda():
#_test_place(core.CUDAPinnedPlace())
_test_place(core.CUDAPinnedPlace())
_test_place(core.CUDAPlace(0))
def test_to_variable(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册