From 1810bfb4157ecd62e24c44215b90caf1af9c93b0 Mon Sep 17 00:00:00 2001 From: lilong12 Date: Mon, 17 Jun 2019 14:31:03 +0800 Subject: [PATCH] UserDefinedCollectiveRoleMaker for collective mode (#17898) (#17987) * add UserDefinedCollectiveRoleMaker for collective mode (#17898) --- .../fluid/incubate/fleet/base/role_maker.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index ae32fa039d..af25f195ca 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -16,7 +16,8 @@ from __future__ import print_function from enum import Enum __all__ = [ - 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker' + 'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker', + 'UserDefinedCollectiveRoleMaker' ] @@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase): def worker_num(self): return self._worker_num + + +class UserDefinedCollectiveRoleMaker(RoleMakerBase): + def __init__(self, current_id=0, worker_endpoints=None): + """ + UserDefinedCollectiveRoleMaker is designed for worker assignment + under manual for collective mode. + """ + super(UserDefinedCollectiveRoleMaker, self).__init__() + + if not isinstance(current_id, int): + raise TypeError("current_id must be as int") + else: + if current_id < 0: + raise ValueError("current_id must be greater or equal 0") + self._current_id = current_id + + if not isinstance(worker_endpoints, list): + raise TypeError("worker_endpoints must be as string list") + else: + self._worker_endpoints = worker_endpoints + self._worker_num = len(self._worker_endpoints) + + def is_worker(self): + return True + + def is_first_worker(self): + return self._current_id == 0 + + def worker_index(self): + return self._current_id + + def worker_num(self): + return self._worker_num -- GitLab