未验证 提交 4df02fdf 编写于 作者: A Aganlengzi 提交者: GitHub

[CustomDevice] op_test supports custom device (#42227)

* [DO NOT MERGE] test op_test

* update with more related modifications

* split op_test.py to use test=allcases for testing

* split op_test.py to use test=allcases for testing
上级 2cebcf4a
...@@ -835,6 +835,16 @@ class AllocatorFacadePrivate { ...@@ -835,6 +835,16 @@ class AllocatorFacadePrivate {
platform::MLUPlace p(i); platform::MLUPlace p(i);
system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p); system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type); dev_id++) {
platform::CustomPlace p(dev_type, dev_id);
system_allocators_[p] = std::make_shared<NaiveBestFitAllocator>(p);
}
}
#endif #endif
} }
......
...@@ -2206,6 +2206,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2206,6 +2206,7 @@ All parameter, weight, gradient are variables in Paddle.
std::exit(-1); std::exit(-1);
#endif #endif
}) })
.def("_type", &PlaceIndex<platform::CustomPlace>)
.def("get_device_id", .def("get_device_id",
[](const platform::CustomPlace &self) { return self.GetDeviceId(); }) [](const platform::CustomPlace &self) { return self.GetDeviceId(); })
.def("get_device_type", .def("get_device_type",
......
...@@ -1386,7 +1386,8 @@ class Executor(object): ...@@ -1386,7 +1386,8 @@ class Executor(object):
def _can_use_interpreter_core(program, place): def _can_use_interpreter_core(program, place):
if core.is_compiled_with_npu() or core.is_compiled_with_xpu( if core.is_compiled_with_npu() or core.is_compiled_with_xpu(
) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu(): ) or core.is_compiled_with_mlu() or core.is_compiled_with_ipu(
) or isinstance(place, core.CustomPlace):
return False return False
compiled = isinstance(program, compiler.CompiledProgram) compiled = isinstance(program, compiler.CompiledProgram)
......
...@@ -341,6 +341,10 @@ class OpTest(unittest.TestCase): ...@@ -341,6 +341,10 @@ class OpTest(unittest.TestCase):
def is_mlu_op_test(): def is_mlu_op_test():
return hasattr(cls, "use_mlu") and cls.use_mlu == True return hasattr(cls, "use_mlu") and cls.use_mlu == True
def is_custom_device_op_test():
return hasattr(
cls, "use_custom_device") and cls.use_custom_device == True
if not hasattr(cls, "op_type"): if not hasattr(cls, "op_type"):
raise AssertionError( raise AssertionError(
"This test do not have op_type in class attrs, " "This test do not have op_type in class attrs, "
...@@ -364,7 +368,8 @@ class OpTest(unittest.TestCase): ...@@ -364,7 +368,8 @@ class OpTest(unittest.TestCase):
and not is_mkldnn_op_test() \ and not is_mkldnn_op_test() \
and not is_rocm_op_test() \ and not is_rocm_op_test() \
and not is_npu_op_test() \ and not is_npu_op_test() \
and not is_mlu_op_test(): and not is_mlu_op_test() \
and not is_custom_device_op_test():
raise AssertionError( raise AssertionError(
"This test of %s op needs check_grad with fp64 precision." % "This test of %s op needs check_grad with fp64 precision." %
cls.op_type) cls.op_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册