未验证 提交 a923a757 编写于 作者: D duanyanhui 提交者: GitHub

[Custom Device] Clear ProcessGroup Manually (#49182)

* clear ProcessGroupCustom manually

* fix bug

* fix bug

* move destroy ProcessGroup to ProcessGroupIdMap

* enable destroy to all device

* remove unused comments

* change to internal api

* Update process_group.cc

* Update process_group.cc
上级 bd165b94
...@@ -36,5 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() { ...@@ -36,5 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
return instance; return instance;
} }
void ProcessGroupIdMap::DestroyProcessGroup(int gid) {
int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count();
for (int i = 0; i < use_count; ++i) {
ProcessGroupIdMap::GetInstance()[gid].reset();
}
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -478,6 +478,7 @@ class ProcessGroupIdMap ...@@ -478,6 +478,7 @@ class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> { : public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public: public:
static ProcessGroupIdMap& GetInstance(); static ProcessGroupIdMap& GetInstance();
static void DestroyProcessGroup(int gid);
}; };
// TODO(dev): The following method will be removed soon. // TODO(dev): The following method will be removed soon.
......
...@@ -1308,6 +1308,14 @@ void BindDistributed(py::module *m) { ...@@ -1308,6 +1308,14 @@ void BindDistributed(py::module *m) {
}, },
py::arg("tensors"), py::arg("tensors"),
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
py::class_<distributed::ProcessGroupIdMap,
std::shared_ptr<distributed::ProcessGroupIdMap>>(
*m, "ProcessGroupIdMap")
.def_static("destroy",
distributed::ProcessGroupIdMap::DestroyProcessGroup,
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
} }
} // end namespace pybind } // end namespace pybind
......
...@@ -82,10 +82,11 @@ void Stream::Wait() const { ...@@ -82,10 +82,11 @@ void Stream::Wait() const {
void Stream::WaitCallback() const { callback_manager_->Wait(); } void Stream::WaitCallback() const { callback_manager_->Wait(); }
void Stream::Destroy() { void Stream::Destroy() {
if (own_data_) { if (own_data_ && stream_ != nullptr) {
phi::DeviceManager::SetDevice(place_); phi::DeviceManager::SetDevice(place_);
device_->DestroyStream(this); device_->DestroyStream(this);
own_data_ = false; own_data_ = false;
stream_ = nullptr;
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import atexit
from . import io from . import io
from .spawn import spawn # noqa: F401 from .spawn import spawn # noqa: F401
from .launch.main import launch # noqa: F401 from .launch.main import launch # noqa: F401
...@@ -29,8 +30,8 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401 ...@@ -29,8 +30,8 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import split # noqa: F401 from .collective import split # noqa: F401
from .collective import new_group # noqa: F401 from .collective import new_group # noqa: F401
from .collective import is_available # noqa: F401 from .collective import is_available
from .collective import _destroy_process_group_id_map
from .communication import ( from .communication import (
stream, stream,
ReduceOp, ReduceOp,
...@@ -123,3 +124,5 @@ __all__ = [ # noqa ...@@ -123,3 +124,5 @@ __all__ = [ # noqa
"is_available", "is_available",
"get_backend", "get_backend",
] ]
atexit.register(_destroy_process_group_id_map)
...@@ -172,6 +172,16 @@ def _set_custom_gid(gid): ...@@ -172,6 +172,16 @@ def _set_custom_gid(gid):
_custom_gid = gid _custom_gid = gid
def _destroy_process_group_id_map():
"""
Destroy the custom process group. Designed for CustomDevice.
"""
core.ProcessGroupIdMap.destroy()
def new_group(ranks=None, backend=None, timeout=_default_timeout): def new_group(ranks=None, backend=None, timeout=_default_timeout):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册