未验证 提交 3b0aa75e 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomDevice] register Copy for custom device (#44200)

* [CustomDevice] register Copy for custom device

* [CustomDevice] register Copy for custom device

* [CustomDevice] register Copy for custom device

* merge and add uts

* merge and add uts

* fix for blocking and unittests coverage
上级 db864f0b
......@@ -200,10 +200,9 @@ void Copy(const Context& dev_ctx,
paddle::memory::Copy(
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
#endif
}
#ifdef PADDLE_WITH_XPU
else if (paddle::platform::is_xpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
} else if (paddle::platform::is_xpu_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} else if (paddle::platform::is_cpu_place(src_place) &&
paddle::platform::is_xpu_place(dst_place)) {
......@@ -216,11 +215,40 @@ void Copy(const Context& dev_ctx,
return;
}
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
} else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_cpu_place(dst_place)) {
auto stream =
blocking
? nullptr
: reinterpret_cast<const paddle::platform::CustomDeviceContext&>(
dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
} else if (paddle::platform::is_cpu_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
auto stream =
blocking
? nullptr
: reinterpret_cast<const paddle::platform::CustomDeviceContext&>(
dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
} else if (paddle::platform::is_custom_place(src_place) && // NOLINT
paddle::platform::is_custom_place(dst_place)) {
auto stream =
blocking
? nullptr
: reinterpret_cast<const paddle::platform::CustomDeviceContext&>(
dev_ctx)
.stream();
paddle::memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
#endif
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Copy from %s to %s is not supported.", src_place, dst_place));
}
#endif
}
template <typename Context>
......@@ -363,4 +391,11 @@ template void Copy(const XPUContext& dev_ctx,
DenseTensor* dst);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Explicit instantiation of the Copy kernel for plugin ("custom") devices,
// mirroring the CPU/GPU/XPU instantiations above so that DenseTensor copies
// involving a CustomPlace link against this translation unit.
template void Copy(const CustomContext& dev_ctx,
                   const DenseTensor& src,
                   Place dst_place,
                   bool blocking,
                   DenseTensor* dst);
#endif
} // namespace phi
......@@ -39,6 +39,7 @@ class TestCustomCPUPlugin(unittest.TestCase):
self._test_custom_device_dataloader()
self._test_custom_device_mnist()
self._test_eager_backward_api()
self._test_eager_copy_to()
self._test_custom_device_dataloader()
self._test_custom_device_mnist()
......@@ -133,6 +134,32 @@ class TestCustomCPUPlugin(unittest.TestCase):
self.assertTrue(x_tensor.grad.place.is_custom_place())
def _test_eager_copy_to(self):
    """Exercise Tensor._copy_to across CPUPlace and the 'custom_cpu'
    CustomPlace in every direction, verifying both the data and the
    resulting tensor placement after each blocking copy."""
    import paddle

    src = np.random.random([2, 2]).astype("float32")
    # cpu -> custom
    host_tensor = paddle.to_tensor(src,
                                   dtype='float32',
                                   place=paddle.CPUPlace())
    dev_tensor = host_tensor._copy_to(
        paddle.CustomPlace('custom_cpu', 0), True)
    self.assertTrue(np.array_equal(dev_tensor, src))
    self.assertTrue(dev_tensor.place.is_custom_place())
    # custom -> custom
    dev_tensor_copy = dev_tensor._copy_to(
        paddle.CustomPlace('custom_cpu', 0), True)
    self.assertTrue(np.array_equal(dev_tensor_copy, src))
    self.assertTrue(dev_tensor_copy.place.is_custom_place())
    # custom -> cpu
    round_trip = dev_tensor._copy_to(paddle.CPUPlace(), True)
    self.assertTrue(np.array_equal(round_trip, src))
    self.assertTrue(round_trip.place.is_cpu_place())
    # custom -> custom self (copy onto the same place)
    dev_tensor_copy = dev_tensor_copy._copy_to(
        paddle.CustomPlace('custom_cpu', 0), True)
    self.assertTrue(np.array_equal(dev_tensor_copy, src))
    self.assertTrue(dev_tensor_copy.place.is_custom_place())
def tearDown(self):
del os.environ['CUSTOM_DEVICE_ROOT']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册