未验证 提交 7f22ef54 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomDevice]add custom place supports (#43813)

* [CustomDevice]add custom place supports

* sync format
上级 b848bd37
...@@ -52,7 +52,8 @@ class ArrayOp : public framework::OperatorBase { ...@@ -52,7 +52,8 @@ class ArrayOp : public framework::OperatorBase {
size_t offset; size_t offset;
if (platform::is_gpu_place(i_tensor.place()) || if (platform::is_gpu_place(i_tensor.place()) ||
platform::is_xpu_place(i_tensor.place()) || platform::is_xpu_place(i_tensor.place()) ||
platform::is_npu_place(i_tensor.place())) { platform::is_npu_place(i_tensor.place()) ||
platform::is_custom_place(i_tensor.place())) {
// FIXME: Avoid copy from GPU to CPU // FIXME: Avoid copy from GPU to CPU
framework::Tensor t; framework::Tensor t;
framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t); framework::TensorCopy(i_tensor, platform::CPUPlace(), dev_ctx, &t);
......
...@@ -566,7 +566,8 @@ class ReduceOp : public framework::OperatorWithKernel { ...@@ -566,7 +566,8 @@ class ReduceOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()) || platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) || platform::is_npu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()), platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace()),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"float16 can only be used on GPU or NPU or MLU place")); "float16 can only be used on GPU or NPU or MLU place"));
......
...@@ -2801,6 +2801,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2801,6 +2801,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("_equals", &IsSamePlace<platform::Place, platform::IPUPlace>) .def("_equals", &IsSamePlace<platform::Place, platform::IPUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>) .def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::MLUPlace>) .def("_equals", &IsSamePlace<platform::Place, platform::MLUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CustomPlace>)
.def("is_gpu_place", .def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); }) [](platform::Place &self) { return platform::is_gpu_place(self); })
.def("is_cpu_place", .def("is_cpu_place",
......
...@@ -349,6 +349,10 @@ def get_device(): ...@@ -349,6 +349,10 @@ def get_device():
elif isinstance(place, core.MLUPlace): elif isinstance(place, core.MLUPlace):
device_id = place.get_device_id() device_id = place.get_device_id()
device = 'mlu:' + str(device_id) device = 'mlu:' + str(device_id)
elif isinstance(place, core.CustomPlace):
device_id = place.get_device_id()
device_type = place.get_device_type()
device = device_type + ':' + str(device_id)
else: else:
raise ValueError("The device specification {} is invalid".format(place)) raise ValueError("The device specification {} is invalid".format(place))
......
...@@ -1685,7 +1685,8 @@ class OpTest(unittest.TestCase): ...@@ -1685,7 +1685,8 @@ class OpTest(unittest.TestCase):
# Currently not support ParallelExecutor on XPUPlace. # Currently not support ParallelExecutor on XPUPlace.
if not paddle.is_compiled_with_xpu( if not paddle.is_compiled_with_xpu(
) and not paddle.is_compiled_with_npu( ) and not paddle.is_compiled_with_npu(
) and not paddle.is_compiled_with_mlu(): ) and not paddle.is_compiled_with_mlu() and not isinstance(
place, core.CustomPlace):
self.check_inplace_output_with_place(place, self.check_inplace_output_with_place(place,
no_check_set=no_check_set, no_check_set=no_check_set,
inplace_atol=inplace_atol) inplace_atol=inplace_atol)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册