From af8199959c59925f86d44daff1d7c087dd2dbf96 Mon Sep 17 00:00:00 2001 From: zhanghan Date: Thu, 20 May 2021 16:02:00 +0800 Subject: [PATCH] fix multi process download model (#662) --- ernie/file_utils.py | 53 +++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/ernie/file_utils.py b/ernie/file_utils.py index e55a5bc..4715094 100644 --- a/ernie/file_utils.py +++ b/ernie/file_utils.py @@ -21,6 +21,8 @@ import logging from tqdm import tqdm from pathlib import Path import six +import paddle as P +import time if six.PY2: from pathlib2 import Path else: @@ -33,6 +35,8 @@ def _fetch_from_remote(url, force_download=False, cached_dir='~/.paddle-ernie-cache'): import hashlib, tempfile, requests, tarfile + env = P.distributed.ParallelEnv() + sig = hashlib.md5(url.encode('utf8')).hexdigest() cached_dir = Path(cached_dir).expanduser() try: @@ -40,25 +44,36 @@ def _fetch_from_remote(url, except OSError: pass cached_dir_model = cached_dir / sig - if force_download or not cached_dir_model.exists(): - cached_dir_model.mkdir() - tmpfile = cached_dir_model / 'tmp' - with tmpfile.open('wb') as f: - #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz' - r = requests.get(url, stream=True) - total_len = int(r.headers.get('content-length')) - for chunk in tqdm( - r.iter_content(chunk_size=1024), - total=total_len // 1024, - desc='downloading %s' % url, - unit='KB'): - if chunk: - f.write(chunk) - f.flush() - log.debug('extacting... to %s' % tmpfile) - with tarfile.open(tmpfile.as_posix()) as tf: - tf.extractall(path=cached_dir_model.as_posix()) - os.remove(tmpfile.as_posix()) + done_file = cached_dir_model / 'fetch_done' + if force_download or not done_file.exists(): + if env.dev_id == 0: + cached_dir_model.mkdir() + tmpfile = cached_dir_model / 'tmp' + with tmpfile.open('wb') as f: + #url = 'https://ernie.bj.bcebos.com/ERNIE_stable.tgz' + r = requests.get(url, stream=True) + total_len = int(r.headers.get('content-length')) + for chunk in tqdm( + r.iter_content(chunk_size=1024), + total=total_len // 1024, + desc='downloading %s' % url, + unit='KB'): + if chunk: + f.write(chunk) + f.flush() + log.debug('extacting... to %s' % tmpfile) + with tarfile.open(tmpfile.as_posix()) as tf: + tf.extractall(path=cached_dir_model.as_posix()) + os.remove(tmpfile.as_posix()) + f = done_file.open('wb') + f.close() + else: + while True: + if done_file.exists(): + break + else: + time.sleep(1) + log.debug('%s cached in %s' % (url, cached_dir)) return cached_dir_model -- GitLab