diff --git a/paddle/fluid/distributed/collective/process_group.cc b/paddle/fluid/distributed/collective/process_group.cc index d670477f2d41e5f4d5c8bc31b90be692814d8119..2722edf8deac66772ff44f023359cd6375b454cb 100644 --- a/paddle/fluid/distributed/collective/process_group.cc +++ b/paddle/fluid/distributed/collective/process_group.cc @@ -36,5 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() { return instance; } +void ProcessGroupIdMap::DestroyProcessGroup(int gid) { + int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count(); + for (int i = 0; i < use_count; ++i) { + ProcessGroupIdMap::GetInstance()[gid].reset(); + } +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group.h b/paddle/fluid/distributed/collective/process_group.h index 4980dfc307565cc82d051be2b0de7fc1033d6c12..ad8ba19f8bae17152eb7082f86d308f20529dfb3 100644 --- a/paddle/fluid/distributed/collective/process_group.h +++ b/paddle/fluid/distributed/collective/process_group.h @@ -478,6 +478,7 @@ class ProcessGroupIdMap : public std::unordered_map> { public: static ProcessGroupIdMap& GetInstance(); + static void DestroyProcessGroup(int gid); }; // TODO(dev): The following method will be removed soon. diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 6bf409d527cca935f2d3e24b17c9a0f7147d61c1..cc06861d8d052d04c4b744d618fb12c5359a11f2 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -1308,6 +1308,14 @@ void BindDistributed(py::module *m) { }, py::arg("tensors"), py::call_guard()); + + py::class_>( + *m, "ProcessGroupIdMap") + .def_static("destroy", + distributed::ProcessGroupIdMap::DestroyProcessGroup, + py::arg("group_id") = 0, + py::call_guard()); } } // end namespace pybind diff --git a/paddle/phi/backends/stream.cc b/paddle/phi/backends/stream.cc index bad57c5238ec834364a9e4a16f1398fc51cdf905..729f5717a5fd7367c4c608aadebc44a5713a5b18 100644 --- a/paddle/phi/backends/stream.cc +++ b/paddle/phi/backends/stream.cc @@ -82,10 +82,11 @@ void Stream::Wait() const { void Stream::WaitCallback() const { callback_manager_->Wait(); } void Stream::Destroy() { - if (own_data_) { + if (own_data_ && stream_ != nullptr) { phi::DeviceManager::SetDevice(place_); device_->DestroyStream(this); own_data_ = false; + stream_ = nullptr; } } diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 22a7bed2fa587e799e93ce08797d9422bedaeec5..900cdacb1154662ecd61ce6e0754cd8f6f2d70a9 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import atexit from . import io from .spawn import spawn # noqa: F401 from .launch.main import launch # noqa: F401 @@ -29,8 +30,8 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401 from .collective import split # noqa: F401 from .collective import new_group # noqa: F401 -from .collective import is_available # noqa: F401 - +from .collective import is_available +from .collective import _destroy_process_group_id_map from .communication import ( stream, ReduceOp, @@ -123,3 +124,5 @@ __all__ = [ # noqa "is_available", "get_backend", ] + +atexit.register(_destroy_process_group_id_map) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 9ffb03760f1ed3fc26d4ae727a3822a8c90e1713..586fe76c971f1c01df215c971d4dde0ae8f96b81 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -172,6 +172,16 @@ def _set_custom_gid(gid): _custom_gid = gid +def _destroy_process_group_id_map(): + """ + + Destroy the custom process group. Designed for CustomDevice. + + + """ + core.ProcessGroupIdMap.destroy() + + def new_group(ranks=None, backend=None, timeout=_default_timeout): """