From b5c35ae3e7df5d5d8bf37e445a6adc075cf7f4ec Mon Sep 17 00:00:00 2001 From: lilong12 Date: Tue, 11 Jun 2019 14:51:49 +0800 Subject: [PATCH] add UserDefinedCollectiveRoleMaker for collective mode (#17898) * add 'UserDefinedRoleMakerNCCL' for collective mode. * code style * add the name UserDefinedRoleMakerNCCL to __all__ * rename to UserDefinedRoleMakerCollective * rename to UserDefinedCollectiveRoleMaker --- .../fluid/incubate/fleet/base/role_maker.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index ae32fa039d1..af25f195cac 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -16,7 +16,8 @@ from __future__ import print_function from enum import Enum __all__ = [ - 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker' + 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker', + 'UserDefinedCollectiveRoleMaker' ] @@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase): def worker_num(self): return self._worker_num + + +class UserDefinedCollectiveRoleMaker(RoleMakerBase): + def __init__(self, current_id=0, worker_endpoints=None): + """ + UserDefinedCollectiveRoleMaker is designed for worker assignment + under manual for collective mode. + """ + super(UserDefinedCollectiveRoleMaker, self).__init__() + + if not isinstance(current_id, int): + raise TypeError("current_id must be as int") + else: + if current_id < 0: + raise ValueError("current_id must be greater or equal 0") + self._current_id = current_id + + if not isinstance(worker_endpoints, list): + raise TypeError("worker_endpoints must be as string list") + else: + self._worker_endpoints = worker_endpoints + self._worker_num = len(self._worker_endpoints) + + def is_worker(self): + return True + + def is_first_worker(self): + return self._current_id == 0 + + def worker_index(self): + return self._current_id + + def worker_num(self): + return self._worker_num -- GitLab