From 1ed8e9b8249e2380fb12197de33367450a2fe6b9 Mon Sep 17 00:00:00 2001 From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com> Date: Fri, 9 Sep 2022 20:17:13 +0800 Subject: [PATCH] make memcpy op to support custom_device (#45918) * make memcpy op to support custom device * fix bug --- paddle/fluid/operators/memcpy_op.cc | 4 +++- paddle/fluid/operators/memcpy_op.h | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/memcpy_op.cc b/paddle/fluid/operators/memcpy_op.cc index ef430f8bfa..9fb06c5968 100644 --- a/paddle/fluid/operators/memcpy_op.cc +++ b/paddle/fluid/operators/memcpy_op.cc @@ -113,7 +113,9 @@ class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker { "1: dst is on CUDAPlace. " "2: dst is on CUDAPinnedPlace. " "3: dst is on XPUPlace. " - "4: dst is on NPUPlace. "); + "4: dst is on NPUPlace. " + "5: dst is on NPUPinnerPlace. " + "6: dst is on CustomDevicePlace"); AddComment(R"DOC( Memcpy Operator. By now, it ONLY supports the memcopy between CUDAPinnedPlace <-> CUDAPlace or diff --git a/paddle/fluid/operators/memcpy_op.h b/paddle/fluid/operators/memcpy_op.h index a35fefa53b..092d0cf368 100644 --- a/paddle/fluid/operators/memcpy_op.h +++ b/paddle/fluid/operators/memcpy_op.h @@ -41,6 +41,7 @@ class MemcpyFunctor { XPU = 3, NPU = 4, NPU_PINNED = 5, + CUSTOM_DEVICE = 6, }; public: @@ -67,6 +68,11 @@ class MemcpyFunctor { } else if (dst_place_type_ == DeviceType::NPU_PINNED) { /* npu->npu_pin */ framework::TensorCopy( lod_tensor, platform::NPUPinnedPlace(), dev_ctx_, &out_tensor); +#endif +#ifdef PADDLE_WTIH_CUSTOM_DEVICE + } else if (dst_place_type_ == DeviceType::CUSTOM_DEVICE) { + framework::TensorCopy( + lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor); #endif } else { PADDLE_THROW(platform::errors::Unimplemented( -- GitLab