未验证 提交 1ed8e9b8 编写于 作者: D duanyanhui 提交者: GitHub

make memcpy op to support custom_device (#45918)

* make memcpy op to support custom device

* fix bug
上级 a687b531
...@@ -113,7 +113,9 @@ class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -113,7 +113,9 @@ class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"1: dst is on CUDAPlace. " "1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. " "2: dst is on CUDAPinnedPlace. "
"3: dst is on XPUPlace. " "3: dst is on XPUPlace. "
"4: dst is on NPUPlace. "); "4: dst is on NPUPlace. "
"5: dst is on NPUPinnerPlace. "
"6: dst is on CustomDevicePlace");
AddComment(R"DOC( AddComment(R"DOC(
Memcpy Operator. Memcpy Operator.
By now, it ONLY supports the memcopy between CUDAPinnedPlace <-> CUDAPlace or By now, it ONLY supports the memcopy between CUDAPinnedPlace <-> CUDAPlace or
......
...@@ -41,6 +41,7 @@ class MemcpyFunctor { ...@@ -41,6 +41,7 @@ class MemcpyFunctor {
XPU = 3, XPU = 3,
NPU = 4, NPU = 4,
NPU_PINNED = 5, NPU_PINNED = 5,
CUSTOM_DEVICE = 6,
}; };
public: public:
...@@ -67,6 +68,11 @@ class MemcpyFunctor { ...@@ -67,6 +68,11 @@ class MemcpyFunctor {
} else if (dst_place_type_ == DeviceType::NPU_PINNED) { /* npu->npu_pin */ } else if (dst_place_type_ == DeviceType::NPU_PINNED) { /* npu->npu_pin */
framework::TensorCopy( framework::TensorCopy(
lod_tensor, platform::NPUPinnedPlace(), dev_ctx_, &out_tensor); lod_tensor, platform::NPUPinnedPlace(), dev_ctx_, &out_tensor);
#endif
#ifdef PADDLE_WTIH_CUSTOM_DEVICE
} else if (dst_place_type_ == DeviceType::CUSTOM_DEVICE) {
framework::TensorCopy(
lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
#endif #endif
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册