未验证 提交 7f22ef54 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomDevice]add custom place supports (#43813)

* [CustomDevice]add custom place supports

* sync format
上级 b848bd37
......@@ -52,7 +52,8 @@ class ArrayOp : public framework::OperatorBase {
size_t offset;
if (platform::is_gpu_place(i_tensor.place()) ||
platform::is_xpu_place(i_tensor.place()) ||
platform::is_npu_place(i_tensor.place())) {
platform::is_npu_place(i_tensor.place()) ||
platform::is_custom_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU
framework::Tensor t;
framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t);
......
......@@ -566,7 +566,8 @@ class ReduceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()),
platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU or NPU or MLU place"));
......
......@@ -2801,6 +2801,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("_equals", &IsSamePlace<platform::Place, platform::IPUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::MLUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CustomPlace>)
.def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); })
.def("is_cpu_place",
......
......@@ -349,6 +349,10 @@ def get_device():
elif isinstance(place, core.MLUPlace):
device_id = place.get_device_id()
device = 'mlu:' + str(device_id)
elif isinstance(place, core.CustomPlace):
device_id = place.get_device_id()
device_type = place.get_device_type()
device = device_type + ':' + str(device_id)
else:
raise ValueError("The device specification {} is invalid".format(place))
......
......@@ -1685,7 +1685,8 @@ class OpTest(unittest.TestCase):
# Currently not support ParallelExecutor on XPUPlace.
if not paddle.is_compiled_with_xpu(
) and not paddle.is_compiled_with_npu(
) and not paddle.is_compiled_with_mlu():
) and not paddle.is_compiled_with_mlu() and not isinstance(
place, core.CustomPlace):
self.check_inplace_output_with_place(place,
no_check_set=no_check_set,
inplace_atol=inplace_atol)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册