未验证 提交 2ac6c6c3 作者: Zhou Wei 提交者: GitHub

Fix tensor copy bug when both source and destination are CUDAPinnedPlace (#27966)

上级 f58434ef
...@@ -84,6 +84,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, ...@@ -84,6 +84,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
} }
#endif #endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
...@@ -285,6 +291,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -285,6 +291,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
} }
#endif #endif
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cuda_pinned_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr,
BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr,
size);
}
else if (platform::is_cuda_pinned_place(src_place) && // NOLINT else if (platform::is_cuda_pinned_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr,
......
...@@ -141,7 +141,7 @@ class TestVarBase(unittest.TestCase): ...@@ -141,7 +141,7 @@ class TestVarBase(unittest.TestCase):
_test_place(core.CPUPlace()) _test_place(core.CPUPlace())
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
#_test_place(core.CUDAPinnedPlace()) _test_place(core.CUDAPinnedPlace())
_test_place(core.CUDAPlace(0)) _test_place(core.CUDAPlace(0))
def test_to_variable(self): def test_to_variable(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册