From cd644993512a2ea0f645c04988f879abab8d15f6 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Tue, 12 Jul 2022 11:18:00 +0800 Subject: [PATCH] fix download logic for multi-machine (#2140) --- ppcls/utils/download.py | 27 ++++++--------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/ppcls/utils/download.py b/ppcls/utils/download.py index 51d45438..3aeda0ce 100644 --- a/ppcls/utils/download.py +++ b/ppcls/utils/download.py @@ -77,21 +77,6 @@ def _map_path(url, root_dir): return osp.join(root_dir, fpath) -def _get_unique_endpoints(trainer_endpoints): - # Sorting is to avoid different environmental variables for each card - trainer_endpoints.sort() - ips = set() - unique_endpoints = set() - for endpoint in trainer_endpoints: - ip = endpoint.split(":")[0] - if ip in ips: - continue - ips.add(ip) - unique_endpoints.add(endpoint) - logger.info("unique_endpoints {}".format(unique_endpoints)) - return unique_endpoints - - def get_path_from_url(url, root_dir, md5sum=None, @@ -118,20 +103,20 @@ def get_path_from_url(url, # parse path after download to decompress under root_dir fullpath = _map_path(url, root_dir) # Mainly used to solve the problem of downloading data from different - # machines in the case of multiple machines. Different ips will download - # data, and the same ip will only download data once. - unique_endpoints = _get_unique_endpoints(ParallelEnv() - .trainer_endpoints[:]) + # machines in the case of multiple machines. Each node downloads the data + # itself, but only its local rank 0 process performs the download. 
+ rank_id_curr_node = int(os.environ.get("PADDLE_RANK_IN_NODE", 0)) + if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): logger.info("Found {}".format(fullpath)) else: - if ParallelEnv().current_endpoint in unique_endpoints: + if rank_id_curr_node == 0: fullpath = _download(url, root_dir, md5sum) else: while not os.path.exists(fullpath): time.sleep(1) - if ParallelEnv().current_endpoint in unique_endpoints: + if rank_id_curr_node == 0: if decompress and (tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath)): fullpath = _decompress(fullpath) -- GitLab