From a923a757a50b5be56c35c688103f8f74aa40e1d8 Mon Sep 17 00:00:00 2001 From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com> Date: Fri, 13 Jan 2023 16:56:17 +0800 Subject: [PATCH] [Custom Device] Clear ProcessGroup Manually (#49182) * clear ProcessGroupCustom manually * fix bug * fix bug * move destroy ProcessGroup to ProcessGroupIdMap * enable destroy to all device * remove unused comments * change to internal api * Update process_group.cc * Update process_group.cc --- paddle/fluid/distributed/collective/process_group.cc | 7 +++++++ paddle/fluid/distributed/collective/process_group.h | 1 + paddle/fluid/pybind/distributed_py.cc | 8 ++++++++ paddle/phi/backends/stream.cc | 3 ++- python/paddle/distributed/__init__.py | 7 +++++-- python/paddle/distributed/collective.py | 10 ++++++++++ 6 files changed, 33 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc index d670477f2d4..2722edf8dea 100644 --- a/paddle/fluid/distributed/collective/process_group.cc +++ b/paddle/fluid/distributed/collective/process_group.cc @@ -36,5 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() { return instance; } +void ProcessGroupIdMap::DestroyProcessGroup(int gid) { + int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count(); + for (int i = 0; i < use_count; ++i) { + ProcessGroupIdMap::GetInstance()[gid].reset(); + } +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index 4980dfc3075..ad8ba19f8ba 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -478,6 +478,7 @@ class ProcessGroupIdMap : public std::unordered_map> { public: static ProcessGroupIdMap& GetInstance(); + static void DestroyProcessGroup(int gid); }; // TODO(dev): The following method will be removed soon. 
# NOTE(review) process_group.cc -- ProcessGroupIdMap::DestroyProcessGroup(int gid):
#   Reads use_count() of the singleton map's entry for `gid` once, then calls
#   reset() on that same entry that many times.
#   * ProcessGroupIdMap derives from std::unordered_map (see the process_group.h
#     hunk above), so GetInstance()[gid] goes through operator[]. For a gid that
#     was never registered, operator[] default-inserts an entry instead of
#     failing, and nothing here erases the entry afterwards.
#   * Presumably the mapped type is std::shared_ptr (the template arguments are
#     missing from this copy of the patch -- TODO confirm). If so, the first
#     reset() drops only the map's own reference and leaves the entry empty;
#     every later iteration calls reset() on an empty pointer and does nothing.
#     References held outside the map are not released by this loop -- confirm
#     that this matches the intent of "clear manually".
#   * use_count is stored in an int; std::shared_ptr::use_count() returns long
#     (relevant only if the mapped type is indeed std::shared_ptr).
#   Suggested follow-up (not applied here): look the gid up with find(), reset
#   the entry once if present, then erase(gid).
# NOTE(review) process_group.h: the hunk only adds the static declaration
#   `static void DestroyProcessGroup(int gid);` to class ProcessGroupIdMap.
diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 6bf409d527c..cc06861d8d0 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -1308,6 +1308,14 @@ void BindDistributed(py::module *m) { }, py::arg("tensors"), py::call_guard()); + + py::class_>( + *m, "ProcessGroupIdMap") + .def_static("destroy", + distributed::ProcessGroupIdMap::DestroyProcessGroup, + py::arg("group_id") = 0, + py::call_guard()); } } // end namespace pybind diff --git a/paddle/phi/backends/stream.cc b/paddle/phi/backends/stream.cc index bad57c5238e..729f5717a5f 100644 --- a/paddle/phi/backends/stream.cc +++ b/paddle/phi/backends/stream.cc @@ -82,10 +82,11 @@ void Stream::Wait() const { void Stream::WaitCallback() const { callback_manager_->Wait(); } void Stream::Destroy() { - if (own_data_) { + if (own_data_ && stream_ != nullptr) { phi::DeviceManager::SetDevice(place_); device_->DestroyStream(this); own_data_ = false; + stream_ = nullptr; } } diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 22a7bed2fa5..900cdacb115 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit from . 
import io from .spawn import spawn # noqa: F401 from .launch.main import launch # noqa: F401 @@ -29,8 +30,8 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401 from .collective import split # noqa: F401 from .collective import new_group # noqa: F401 -from .collective import is_available # noqa: F401 - +from .collective import is_available +from .collective import _destroy_process_group_id_map from .communication import ( stream, ReduceOp, @@ -123,3 +124,5 @@ __all__ = [ # noqa "is_available", "get_backend", ] + +atexit.register(_destroy_process_group_id_map) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 9ffb03760f1..586fe76c971 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -172,6 +172,16 @@ def _set_custom_gid(gid): _custom_gid = gid +def _destroy_process_group_id_map(): + """ + + Destroy the custom process group. Designed for CustomDevice. + + + """ + core.ProcessGroupIdMap.destroy() + + def new_group(ranks=None, backend=None, timeout=_default_timeout): """ -- GitLab
# NOTE(review) distributed_py.cc: registers a Python class named
#   "ProcessGroupIdMap" with one static method, "destroy", bound to
#   distributed::ProcessGroupIdMap::DestroyProcessGroup. Its single argument is
#   exposed as `group_id` with a default of 0. The template arguments of
#   py::class_ and py::call_guard are missing from this copy of the patch.
# NOTE(review) stream.cc -- Stream::Destroy(): the guard now also requires
#   stream_ != nullptr, and stream_ is set to nullptr after
#   device_->DestroyStream(this). own_data_ was already being cleared in the
#   same branch before this change.
# NOTE(review) python/paddle/distributed/__init__.py: importing the package now
#   calls atexit.register(_destroy_process_group_id_map). The "# noqa: F401"
#   marker was removed from the is_available import and is not present on the
#   newly added import.
# NOTE(review) collective.py -- _destroy_process_group_id_map(): calls
#   core.ProcessGroupIdMap.destroy() with no argument, so only the binding's
#   default group_id (0) is passed. Groups registered under other ids are not
#   passed to destroy() by this function, even though its name refers to the
#   whole id map -- TODO confirm this is intended.