未验证 提交 a923a757 编写于 作者: D duanyanhui 提交者: GitHub

[Custom Device] Clear ProcessGroup Manually (#49182)

* clear ProcessGroupCustom manually

* fix bug

* fix bug

* move destroy ProcessGroup to ProcessGroupIdMap

* enable destroy to all device

* remove unused comments

* change to internal api

* Update process_group.cc

* Update process_group.cc
上级 bd165b94
......@@ -36,5 +36,12 @@ ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
return instance;
}
void ProcessGroupIdMap::DestroyProcessGroup(int gid) {
int use_count = ProcessGroupIdMap::GetInstance()[gid].use_count();
for (int i = 0; i < use_count; ++i) {
ProcessGroupIdMap::GetInstance()[gid].reset();
}
}
} // namespace distributed
} // namespace paddle
......@@ -478,6 +478,7 @@ class ProcessGroupIdMap
: public std::unordered_map<int, std::shared_ptr<ProcessGroup>> {
public:
static ProcessGroupIdMap& GetInstance();
static void DestroyProcessGroup(int gid);
};
// TODO(dev): The following method will be removed soon.
......
......@@ -1308,6 +1308,14 @@ void BindDistributed(py::module *m) {
},
py::arg("tensors"),
py::call_guard<py::gil_scoped_release>());
py::class_<distributed::ProcessGroupIdMap,
std::shared_ptr<distributed::ProcessGroupIdMap>>(
*m, "ProcessGroupIdMap")
.def_static("destroy",
distributed::ProcessGroupIdMap::DestroyProcessGroup,
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
}
} // end namespace pybind
......
......@@ -82,10 +82,11 @@ void Stream::Wait() const {
void Stream::WaitCallback() const { callback_manager_->Wait(); }
void Stream::Destroy() {
if (own_data_) {
if (own_data_ && stream_ != nullptr) {
phi::DeviceManager::SetDevice(place_);
device_->DestroyStream(this);
own_data_ = false;
stream_ = nullptr;
}
}
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
from . import io
from .spawn import spawn # noqa: F401
from .launch.main import launch # noqa: F401
......@@ -29,8 +30,8 @@ from paddle.distributed.fleet.base.topology import ParallelMode # noqa: F401
from .collective import split # noqa: F401
from .collective import new_group # noqa: F401
from .collective import is_available # noqa: F401
from .collective import is_available
from .collective import _destroy_process_group_id_map
from .communication import (
stream,
ReduceOp,
......@@ -123,3 +124,5 @@ __all__ = [ # noqa
"is_available",
"get_backend",
]
atexit.register(_destroy_process_group_id_map)
......@@ -172,6 +172,16 @@ def _set_custom_gid(gid):
_custom_gid = gid
def _destroy_process_group_id_map():
"""
Destroy the custom process group. Designed for CustomDevice.
"""
core.ProcessGroupIdMap.destroy()
def new_group(ranks=None, backend=None, timeout=_default_timeout):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册