Unverified · Commit 04340a59 authored by W wangxinxin08, committed by GitHub

fix hang bugs when training on multiple machines (#2284)

Parent d82484d5
@@ -41,6 +41,21 @@ def is_url(path):
             or path.startswith('ppdet://')
 
 
+def _get_unique_endpoints(trainer_endpoints):
+    # Sorting is to avoid different environmental variables for each card
+    trainer_endpoints.sort()
+    ips = set()
+    unique_endpoints = set()
+    for endpoint in trainer_endpoints:
+        ip = endpoint.split(":")[0]
+        if ip in ips:
+            continue
+        ips.add(ip)
+        unique_endpoints.add(endpoint)
+    logger.info("unique_endpoints {}".format(unique_endpoints))
+    return unique_endpoints
+
+
 def get_weights_path_dist(path):
     env = os.environ
     if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
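The new `_get_unique_endpoints` helper keeps exactly one trainer endpoint per machine: the list is sorted so that every process deterministically picks the same winner per IP, then the first endpoint seen for each IP is kept. A standalone sketch of the same selection, with hypothetical two-machine endpoint values:

```python
# Standalone sketch of the selection logic added above; the endpoint
# values (two machines, two trainers each) are hypothetical.
def get_unique_endpoints(trainer_endpoints):
    # Sort first so every process picks the same endpoint per IP,
    # regardless of the order its environment lists endpoints in.
    trainer_endpoints = sorted(trainer_endpoints)
    ips, unique_endpoints = set(), set()
    for endpoint in trainer_endpoints:
        ip = endpoint.split(":")[0]
        if ip in ips:
            continue  # this machine already has a designated endpoint
        ips.add(ip)
        unique_endpoints.add(endpoint)
    return unique_endpoints

endpoints = ["10.0.0.2:6171", "10.0.0.1:6171",
             "10.0.0.1:6170", "10.0.0.2:6170"]
print(get_unique_endpoints(endpoints))
# -> {'10.0.0.1:6170', '10.0.0.2:6170'}: one endpoint per IP
```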
@@ -53,6 +68,9 @@ def get_weights_path_dist(path):
             weight_path = map_path(path, WEIGHTS_HOME)
             lock_path = weight_path + '.lock'
             if not os.path.exists(weight_path):
+                from paddle.distributed import ParallelEnv
+                unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                                         .trainer_endpoints[:])
                 try:
                     os.makedirs(os.path.dirname(weight_path))
                 except OSError as e:
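`ParallelEnv` here is `paddle.distributed.ParallelEnv`, which exposes the full trainer endpoint list and this process's own endpoint, populated from the distributed launcher's environment. The `[:]` copy matters because `_get_unique_endpoints` sorts its argument in place. A rough sketch of how this block becomes a per-machine decision, assuming the process runs under Paddle's distributed launcher; the endpoint values are made up:

```python
# Hypothetical launcher environment for one process on the second
# machine (values are illustrative, not from the repository):
#   PADDLE_TRAINER_ENDPOINTS=10.0.0.1:6170,10.0.0.1:6171,10.0.0.2:6170,10.0.0.2:6171
#   PADDLE_CURRENT_ENDPOINT=10.0.0.2:6171
from paddle.distributed import ParallelEnv

env = ParallelEnv()
# Pass a copy ([:]) because _get_unique_endpoints sorts in place.
unique_endpoints = _get_unique_endpoints(env.trainer_endpoints[:])
# True for exactly one trainer process on each machine:
is_downloader = env.current_endpoint in unique_endpoints  # False here
```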
@@ -60,7 +78,7 @@ def get_weights_path_dist(path):
                         raise
                 with open(lock_path, 'w'):  # touch
                     os.utime(lock_path, None)
-                if trainer_id == 0:
+                if ParallelEnv().current_endpoint in unique_endpoints:
                     get_weights_path(path)
                     os.remove(lock_path)
                 else:
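Why this fixes the hang: every trainer that finds the weights missing touches `lock_path`, but the old condition let only `trainer_id == 0` download and remove the lock. Trainer 0 runs on the first machine only, and the lock file lives on each machine's local filesystem, so trainers on every other machine waited on a lock that no local process would ever remove. Designating one downloader per IP guarantees each machine's lock is cleared. A condensed sketch of the resulting per-machine protocol; `download_weights` is a hypothetical stand-in for `get_weights_path`, and the waiting loop in the `else` branch (which falls outside the hunks shown above) is assumed:

```python
import os
import time

def fetch_weights_once_per_machine(weight_path, lock_path, is_downloader,
                                   download_weights):
    # Every local trainer process calls this with the same paths;
    # is_downloader is True for exactly one process per machine.
    if os.path.exists(weight_path):
        return  # weights already present, nothing to coordinate
    with open(lock_path, 'w'):  # touch the lock file
        os.utime(lock_path, None)
    if is_downloader:
        download_weights()    # only one download per machine
        os.remove(lock_path)  # releases the local waiters below
    else:
        while os.path.exists(lock_path):
            time.sleep(0.5)   # wait for the local downloader to finish
```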