未验证 提交 fb6697b4 编写于 作者: S ShenLiang 提交者: GitHub

Fix the dowanload bug in the case of multiple machines (#29551)

* fix the dowanload bug
* add sort for ips
上级 1efef8ba
......@@ -140,6 +140,21 @@ def _map_path(url, root_dir):
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
......@@ -161,17 +176,20 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().local_rank == 0:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().local_rank == 0:
if ParallelEnv().current_endpoint in unique_endpoints:
if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册