未验证 提交 1ed8e9b8 编写于 作者: D duanyanhui 提交者: GitHub

make memcpy op to support custom_device (#45918)

* make memcpy op to support custom device

* fix bug
上级 a687b531
......@@ -113,7 +113,9 @@ class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
"3: dst is on XPUPlace. "
"4: dst is on NPUPlace. ");
"4: dst is on NPUPlace. "
"5: dst is on NPUPinnerPlace. "
"6: dst is on CustomDevicePlace");
AddComment(R"DOC(
Memcpy Operator.
By now, it ONLY supports the memcopy between CUDAPinnedPlace <-> CUDAPlace or
......
......@@ -41,6 +41,7 @@ class MemcpyFunctor {
XPU = 3,
NPU = 4,
NPU_PINNED = 5,
CUSTOM_DEVICE = 6,
};
public:
......@@ -67,6 +68,11 @@ class MemcpyFunctor {
} else if (dst_place_type_ == DeviceType::NPU_PINNED) { /* npu->npu_pin */
framework::TensorCopy(
lod_tensor, platform::NPUPinnedPlace(), dev_ctx_, &out_tensor);
#endif
#ifdef PADDLE_WTIH_CUSTOM_DEVICE
} else if (dst_place_type_ == DeviceType::CUSTOM_DEVICE) {
framework::TensorCopy(
lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, &out_tensor);
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册