diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index ae32fa039d1531c1af6f1174a8cc87ce720302c1..af25f195cacd2d0b38bd7e701f32ee6dea298641 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -16,7 +16,8 @@ from __future__ import print_function from enum import Enum __all__ = [ - 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker' + 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker', + 'UserDefinedCollectiveRoleMaker' ] @@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase): def worker_num(self): return self._worker_num + + +class UserDefinedCollectiveRoleMaker(RoleMakerBase): + def __init__(self, current_id=0, worker_endpoints=None): + """ + UserDefinedCollectiveRoleMaker is designed for worker assignment + under manual for collective mode. + """ + super(UserDefinedCollectiveRoleMaker, self).__init__() + + if not isinstance(current_id, int): + raise TypeError("current_id must be as int") + else: + if current_id < 0: + raise ValueError("current_id must be greater or equal 0") + self._current_id = current_id + + if not isinstance(worker_endpoints, list): + raise TypeError("worker_endpoints must be as string list") + else: + self._worker_endpoints = worker_endpoints + self._worker_num = len(self._worker_endpoints) + + def is_worker(self): + return True + + def is_first_worker(self): + return self._current_id == 0 + + def worker_index(self): + return self._current_id + + def worker_num(self): + return self._worker_num