未验证 提交 50609f0f 编写于 作者: T Thunderbrook 提交者: GitHub

fix gloo error; Cherry pick mock (#26185)

* add mock barrier all (#24786)

* add mock barrier all
test=develop

* fix
test=develop

* fix
test=develop

* fix
test=develop

* fix gloo error
test=develop
Co-authored-by: Nxujiaqi01 <173596896@qq.com>
上级 824c72f5
......@@ -54,7 +54,7 @@ void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
paddle::framework::fs_remove(tmp);
if (i == retry_times_) {
VLOG(0) << "fs_open_write failed, retry times reaches limit";
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"fs_open_write failed, retry times reaches"
" limit ",
retry_times_));
......@@ -143,7 +143,7 @@ void HdfsStore::wait(const std::vector<std::string>& keys,
break;
}
}
PADDLE_THROW(platform::errors::ExecutionTimeout(
PADDLE_THROW(paddle::platform::errors::ExecutionTimeout(
"TIMEOUT self_rank = %d pair_rank = %d", self_rank_,
last_check_rank));
}
......
......@@ -30,6 +30,41 @@ class Role:
SERVER = 2
class MockBarrier(object):
"""
MockBarrier is a empty impletation for barrier
mock as a real barrier for never-barrier in a specific scenario
"""
def barrier(self):
"""
dummy barrier, do nothing
"""
pass
def barrier_all(self):
"""
dummy all barrier, do nothing
"""
pass
def all_reduce(self, obj):
"""
dummy all reduce, do nothing
Args:
obj(any): obj to do all reduce
"""
return obj
def all_gather(self, obj):
"""
dummy all gather, do nothing
Args:
obj(any): obj to do all gather
"""
return [obj]
class RoleMakerBase(object):
"""
RoleMakerBase is a base class for assigning a role to current process
......@@ -587,7 +622,10 @@ class GeneralRoleMaker(RoleMakerBase):
trainers_num = len(worker_endpoints)
if training_role not in ["TRAINER", "PSERVER"]:
raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
self._is_barrier_all = 1
if "PADDLE_IS_BARRIER_ALL_ROLE" in os.environ:
self._is_barrier_all = int(os.environ[
"PADDLE_IS_BARRIER_ALL_ROLE"])
if training_role == "TRAINER":
role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"])
......@@ -608,6 +646,7 @@ class GeneralRoleMaker(RoleMakerBase):
self._http_server.start()
self._node_type = 1
self._cur_endpoint = worker_endpoints[current_id]
if self._is_barrier_all:
gloo = fluid.core.Gloo()
gloo.set_rank(current_id)
gloo.set_size(len(worker_endpoints))
......@@ -617,12 +656,15 @@ class GeneralRoleMaker(RoleMakerBase):
self._run_timeout_seconds)
if len(self._http_ip_port) != 0:
gloo.set_http_store(self._http_ip_port[0],
int(self._http_ip_port[1]), "trainer")
int(self._http_ip_port[1]),
"trainer")
else:
gloo.set_hdfs_store(self._hdfs_path + "/trainer",
self._hdfs_name, self._hdfs_ugi)
gloo.init()
self._node_type_comm = gloo
else:
self._all_comm = MockBarrier()
elif training_role == "PSERVER":
role = Role.SERVER
if os.environ.get("PADDLE_PSERVER_ID") is not None:
......
......@@ -79,6 +79,21 @@ class TestCloudRoleMaker(unittest.TestCase):
print("do not support pslib test, skip")
return
from paddle.fluid.incubate.fleet.base.role_maker import MockBarrier
mb = MockBarrier()
mb.barrier()
mb.barrier_all()
mb.all_reduce(1)
mb.all_gather(1)
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36005"
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36005"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36006"
os.environ["PADDLE_IS_BARRIER_ALL_ROLE"] = "0"
role_maker = GeneralRoleMaker(path="test_mock1")
role_maker.generate_role()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册