未验证 提交 f8d09011 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add model parallel support for custom device (#52872)

上级 6b756e8c
...@@ -36,11 +36,15 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() { ...@@ -36,11 +36,15 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
return instance; return instance;
} }
void ProcessGroupIdMap::DestroyProcessGroup(int gid) { void ProcessGroupIdMap::DestroyProcessGroup() {
int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count(); auto& id_map = ProcessGroupIdMap::GetInstance();
for (int i = 0; i < use_count; ++i) { for (auto iter = id_map.begin(); iter != id_map.end(); ++iter) {
ProcessGroupIdMap::GetInstance()[gid].reset(); auto use_count = iter->second.use_count();
for (int i = 0; i < use_count; ++i) {
iter->second.reset();
}
} }
id_map.clear();
} }
} // namespace distributed } // namespace distributed
......
...@@ -502,7 +502,7 @@ class ProcessGroupIdMap ...@@ -502,7 +502,7 @@ class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> { : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public: public:
static ProcessGroupIdMap& GetInstance(); static ProcessGroupIdMap& GetInstance();
static void DestroyProcessGroup(int gid); static void DestroyProcessGroup();
}; };
// TODO(dev): The following method will be removed soon. // TODO(dev): The following method will be removed soon.
......
...@@ -215,7 +215,7 @@ endif() ...@@ -215,7 +215,7 @@ endif()
copy_if_different(${pybind_file} ${pybind_file_final}) copy_if_different(${pybind_file} ${pybind_file_final})
if (WITH_CUSTOM_DEVICE) if (WITH_CUSTOM_DEVICE)
cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator) cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator phi_api)
endif() endif()
if(NOT "${OP_LIST}" STREQUAL "") if(NOT "${OP_LIST}" STREQUAL "")
......
...@@ -1357,7 +1357,6 @@ void BindDistributed(py::module *m) { ...@@ -1357,7 +1357,6 @@ void BindDistributed(py::module *m) {
*m, "ProcessGroupIdMap") *m, "ProcessGroupIdMap")
.def_static("destroy", .def_static("destroy",
distributed::ProcessGroupIdMap::DestroyProcessGroup, distributed::ProcessGroupIdMap::DestroyProcessGroup,
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -32,7 +32,6 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401 ...@@ -32,7 +32,6 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import split # noqa: F401 from .collective import split # noqa: F401
from .collective import new_group # noqa: F401 from .collective import new_group # noqa: F401
from .collective import is_available from .collective import is_available
from .collective import _destroy_process_group_id_map
from .communication import ( from .communication import (
stream, stream,
ReduceOp, ReduceOp,
...@@ -122,5 +121,3 @@ __all__ = [ # noqa ...@@ -122,5 +121,3 @@ __all__ = [ # noqa
"is_available", "is_available",
"get_backend", "get_backend",
] ]
atexit.register(_destroy_process_group_id_map)
...@@ -172,16 +172,6 @@ def _set_custom_gid(gid): ...@@ -172,16 +172,6 @@ def _set_custom_gid(gid):
_custom_gid = gid _custom_gid = gid
def _destroy_process_group_id_map():
"""
Destroy the custom process group. Designed for CustomDevice.
"""
core.ProcessGroupIdMap.destroy()
def new_group(ranks=None, backend=None, timeout=_default_timeout): def new_group(ranks=None, backend=None, timeout=_default_timeout):
""" """
......
...@@ -223,3 +223,4 @@ atexit.register(core.clear_executor_cache) ...@@ -223,3 +223,4 @@ atexit.register(core.clear_executor_cache)
# Keep clear_kernel_factory running before clear_device_manager # Keep clear_kernel_factory running before clear_device_manager
atexit.register(core.clear_device_manager) atexit.register(core.clear_device_manager)
atexit.register(core.clear_kernel_factory) atexit.register(core.clear_kernel_factory)
atexit.register(core.ProcessGroupIdMap.destroy)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册