From 2ac6c6c3af8bf8767770444ff5bc40d72f6e9aa9 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Thu, 15 Oct 2020 16:17:11 +0800 Subject: [PATCH] fix bug of tensor copy of CUDAPinnedPlace (#27966) --- paddle/fluid/framework/tensor_util.cc | 12 ++++++++++++ python/paddle/fluid/tests/unittests/test_var_base.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 1a6fe7ded30..4730f6a4ec8 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -84,6 +84,12 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, } #endif #ifdef PADDLE_WITH_CUDA + else if (platform::is_cuda_pinned_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr, + size); + } else if (platform::is_cuda_pinned_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, @@ -285,6 +291,12 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, } #endif #ifdef PADDLE_WITH_CUDA + else if (platform::is_cuda_pinned_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr, + size); + } else if (platform::is_cuda_pinned_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 4462bbf9316..6d4258a426d 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -141,7 +141,7 @@ class TestVarBase(unittest.TestCase): _test_place(core.CPUPlace()) if core.is_compiled_with_cuda(): - #_test_place(core.CUDAPinnedPlace()) + _test_place(core.CUDAPinnedPlace()) _test_place(core.CUDAPlace(0)) def test_to_variable(self): -- GitLab