未验证 提交 3016a4ac 编写于 作者: X xujiaqi01 提交者: GitHub

add mock barrier all (#24786)

* add mock barrier all
test=develop

* fix
test=develop

* fix
test=develop

* fix
test=develop
上级 6e100227
...@@ -30,6 +30,41 @@ class Role: ...@@ -30,6 +30,41 @@ class Role:
SERVER = 2 SERVER = 2
class MockBarrier(object):
"""
MockBarrier is a empty impletation for barrier
mock as a real barrier for never-barrier in a specific scenario
"""
def barrier(self):
"""
dummy barrier, do nothing
"""
pass
def barrier_all(self):
"""
dummy all barrier, do nothing
"""
pass
def all_reduce(self, obj):
"""
dummy all reduce, do nothing
Args:
obj(any): obj to do all reduce
"""
return obj
def all_gather(self, obj):
"""
dummy all gather, do nothing
Args:
obj(any): obj to do all gather
"""
return [obj]
class RoleMakerBase(object): class RoleMakerBase(object):
""" """
RoleMakerBase is a base class for assigning a role to current process RoleMakerBase is a base class for assigning a role to current process
...@@ -587,7 +622,10 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -587,7 +622,10 @@ class GeneralRoleMaker(RoleMakerBase):
trainers_num = len(worker_endpoints) trainers_num = len(worker_endpoints)
if training_role not in ["TRAINER", "PSERVER"]: if training_role not in ["TRAINER", "PSERVER"]:
raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER") raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
self._is_barrier_all = 1
if "PADDLE_IS_BARRIER_ALL_ROLE" in os.environ:
self._is_barrier_all = int(os.environ[
"PADDLE_IS_BARRIER_ALL_ROLE"])
if training_role == "TRAINER": if training_role == "TRAINER":
role = Role.WORKER role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"]) current_id = int(os.environ["PADDLE_TRAINER_ID"])
...@@ -608,6 +646,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -608,6 +646,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._http_server.start() self._http_server.start()
self._node_type = 1 self._node_type = 1
self._cur_endpoint = worker_endpoints[current_id] self._cur_endpoint = worker_endpoints[current_id]
if self._is_barrier_all:
gloo = fluid.core.Gloo() gloo = fluid.core.Gloo()
gloo.set_rank(current_id) gloo.set_rank(current_id)
gloo.set_size(len(worker_endpoints)) gloo.set_size(len(worker_endpoints))
...@@ -617,12 +656,15 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -617,12 +656,15 @@ class GeneralRoleMaker(RoleMakerBase):
self._run_timeout_seconds) self._run_timeout_seconds)
if len(self._http_ip_port) != 0: if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0], gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "trainer") int(self._http_ip_port[1]),
"trainer")
else: else:
gloo.set_hdfs_store(self._hdfs_path + "/trainer", gloo.set_hdfs_store(self._hdfs_path + "/trainer",
self._hdfs_name, self._hdfs_ugi) self._hdfs_name, self._hdfs_ugi)
gloo.init() gloo.init()
self._node_type_comm = gloo self._node_type_comm = gloo
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER": elif training_role == "PSERVER":
role = Role.SERVER role = Role.SERVER
if os.environ.get("PADDLE_PSERVER_ID") is not None: if os.environ.get("PADDLE_PSERVER_ID") is not None:
......
...@@ -79,6 +79,21 @@ class TestCloudRoleMaker(unittest.TestCase): ...@@ -79,6 +79,21 @@ class TestCloudRoleMaker(unittest.TestCase):
print("do not support pslib test, skip") print("do not support pslib test, skip")
return return
from paddle.fluid.incubate.fleet.base.role_maker import MockBarrier
mb = MockBarrier()
mb.barrier()
mb.barrier_all()
mb.all_reduce(1)
mb.all_gather(1)
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36005"
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36005"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36006"
os.environ["PADDLE_IS_BARRIER_ALL_ROLE"] = "0"
role_maker = GeneralRoleMaker(path="test_mock1")
role_maker.generate_role()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册