未验证 提交 f8d09011 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add model parallel support for custom device (#52872)

上级 6b756e8c
......@@ -36,11 +36,15 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
return instance;
}
void ProcessGroupIdMap::DestroyProcessGroup(int gid) {
int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count();
void ProcessGroupIdMap::DestroyProcessGroup() {
auto& id_map = ProcessGroupIdMap::GetInstance();
for (auto iter = id_map.begin(); iter != id_map.end(); ++iter) {
auto use_count = iter->second.use_count();
for (int i = 0; i < use_count; ++i) {
ProcessGroupIdMap::GetInstance()[gid].reset();
iter->second.reset();
}
}
id_map.clear();
}
} // namespace distributed
......
......@@ -502,7 +502,7 @@ class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public:
static ProcessGroupIdMap& GetInstance();
static void DestroyProcessGroup(int gid);
static void DestroyProcessGroup();
};
// TODO(dev): The following method will be removed soon.
......
......@@ -215,7 +215,7 @@ endif()
copy_if_different(${pybind_file} ${pybind_file_final})
if (WITH_CUSTOM_DEVICE)
cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator)
cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator phi_api)
endif()
if(NOT "${OP_LIST}" STREQUAL "")
......
......@@ -1357,7 +1357,6 @@ void BindDistributed(py::module *m) {
*m, "ProcessGroupIdMap")
.def_static("destroy",
distributed::ProcessGroupIdMap::DestroyProcessGroup,
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
}
......
......@@ -32,7 +32,6 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
from .collective import is_available
from .collective import _destroy_process_group_id_map
from .communication import (
stream,
ReduceOp,
......@@ -122,5 +121,3 @@ __all__ = [ # noqa
"is_available",
"get_backend",
]
atexit.register(_destroy_process_group_id_map)
......@@ -172,16 +172,6 @@ def _set_custom_gid(gid):
_custom_gid = gid
def _destroy_process_group_id_map():
"""
Destroy the custom process group. Designed for CustomDevice.
"""
core.ProcessGroupIdMap.destroy()
def new_group(ranks=None, backend=None, timeout=_default_timeout):
"""
......
......@@ -223,3 +223,4 @@ atexit.register(core.clear_executor_cache)
# Keep clear_kernel_factory running before clear_device_manager
atexit.register(core.clear_device_manager)
atexit.register(core.clear_kernel_factory)
atexit.register(core.ProcessGroupIdMap.destroy)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册