未验证 提交 b5c35ae3 编写于 作者: L lilong12 提交者: GitHub

add UserDefinedCollectiveRoleMaker for collective mode (#17898)

* add 'UserDefinedRoleMakerNCCL' for collective mode.

* code style

* add the name UserDefinedRoleMakerNCCL to __all__

* rename to UserDefinedRoleMakerCollective

* rename to UserDefinedCollectiveRoleMaker
上级 84bb45c0
...@@ -16,7 +16,8 @@ from __future__ import print_function ...@@ -16,7 +16,8 @@ from __future__ import print_function
from enum import Enum from enum import Enum
__all__ = [ __all__ = [
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker' 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
'UserDefinedCollectiveRoleMaker'
] ]
...@@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase): ...@@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase):
def worker_num(self): def worker_num(self):
return self._worker_num return self._worker_num
class UserDefinedCollectiveRoleMaker(RoleMakerBase):
def __init__(self, current_id=0, worker_endpoints=None):
"""
UserDefinedCollectiveRoleMaker is designed for worker assignment
under manual for collective mode.
"""
super(UserDefinedCollectiveRoleMaker, self).__init__()
if not isinstance(current_id, int):
raise TypeError("current_id must be as int")
else:
if current_id < 0:
raise ValueError("current_id must be greater or equal 0")
self._current_id = current_id
if not isinstance(worker_endpoints, list):
raise TypeError("worker_endpoints must be as string list")
else:
self._worker_endpoints = worker_endpoints
self._worker_num = len(self._worker_endpoints)
def is_worker(self):
return True
def is_first_worker(self):
return self._current_id == 0
def worker_index(self):
return self._current_id
def worker_num(self):
return self._worker_num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册