未验证 提交 3016a4ac 编写于 作者: X xujiaqi01 提交者: GitHub

add mock barrier all (#24786)

* add mock barrier all
test=develop

* fix
test=develop

* fix
test=develop

* fix
test=develop
上级 6e100227
......@@ -30,6 +30,41 @@ class Role:
SERVER = 2
class MockBarrier(object):
"""
MockBarrier is a empty impletation for barrier
mock as a real barrier for never-barrier in a specific scenario
"""
def barrier(self):
"""
dummy barrier, do nothing
"""
pass
def barrier_all(self):
"""
dummy all barrier, do nothing
"""
pass
def all_reduce(self, obj):
"""
dummy all reduce, do nothing
Args:
obj(any): obj to do all reduce
"""
return obj
def all_gather(self, obj):
"""
dummy all gather, do nothing
Args:
obj(any): obj to do all gather
"""
return [obj]
class RoleMakerBase(object):
"""
RoleMakerBase is a base class for assigning a role to current process
......@@ -587,7 +622,10 @@ class GeneralRoleMaker(RoleMakerBase):
trainers_num = len(worker_endpoints)
if training_role not in ["TRAINER", "PSERVER"]:
raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
self._is_barrier_all = 1
if "PADDLE_IS_BARRIER_ALL_ROLE" in os.environ:
self._is_barrier_all = int(os.environ[
"PADDLE_IS_BARRIER_ALL_ROLE"])
if training_role == "TRAINER":
role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"])
......@@ -608,6 +646,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._http_server.start()
self._node_type = 1
self._cur_endpoint = worker_endpoints[current_id]
if self._is_barrier_all:
gloo = fluid.core.Gloo()
gloo.set_rank(current_id)
gloo.set_size(len(worker_endpoints))
......@@ -617,12 +656,15 @@ class GeneralRoleMaker(RoleMakerBase):
self._run_timeout_seconds)
if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "trainer")
int(self._http_ip_port[1]),
"trainer")
else:
gloo.set_hdfs_store(self._hdfs_path + "/trainer",
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
role = Role.SERVER
if os.environ.get("PADDLE_PSERVER_ID") is not None:
......
......@@ -79,6 +79,21 @@ class TestCloudRoleMaker(unittest.TestCase):
print("do not support pslib test, skip")
return
from paddle.fluid.incubate.fleet.base.role_maker import MockBarrier
mb = MockBarrier()
mb.barrier()
mb.barrier_all()
mb.all_reduce(1)
mb.all_gather(1)
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36005"
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36005"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36006"
os.environ["PADDLE_IS_BARRIER_ALL_ROLE"] = "0"
role_maker = GeneralRoleMaker(path="test_mock1")
role_maker.generate_role()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册