From a898bb29ae7b161d4b5f7638d336740972e3457c Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Fri, 21 May 2021 21:04:36 +0800 Subject: [PATCH] fix multi-process download --- ernie/file_utils.py | 53 ++++++++++++++++++++------------------------- requirements.txt | 1 + 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/ernie/file_utils.py b/ernie/file_utils.py index 89ed138..40b5f0b 100644 --- a/ernie/file_utils.py +++ b/ernie/file_utils.py @@ -21,7 +21,6 @@ import logging from tqdm import tqdm from pathlib import Path import six -import paddle as P import time if six.PY2: from pathlib2 import Path @@ -35,8 +34,6 @@ def _fetch_from_remote(url, force_download=False, cached_dir='~/.paddle-ernie-cache'): import hashlib, tempfile, requests, tarfile - env = P.distributed.ParallelEnv() - sig = hashlib.md5(url.encode('utf8')).hexdigest() cached_dir = Path(cached_dir).expanduser() try: @@ -44,33 +41,31 @@ def _fetch_from_remote(url, except OSError: pass cached_dir_model = cached_dir / sig - done_file = cached_dir_model / 'fetch_done' - if force_download or not done_file.exists(): - if env.dev_id == 0: - cached_dir_model.mkdir() - tmpfile = cached_dir_model / 'tmp' - with tmpfile.open('wb') as f: - r = requests.get(url, stream=True) - total_len = int(r.headers.get('content-length')) - for chunk in tqdm( - r.iter_content(chunk_size=1024), - total=total_len // 1024, - desc='downloading %s' % url, - unit='KB'): - if chunk: - f.write(chunk) - f.flush() - log.debug('extacting... 
to %s' % tmpfile) - with tarfile.open(tmpfile.as_posix()) as tf: - tf.extractall(path=cached_dir_model.as_posix()) - os.remove(tmpfile.as_posix()) - f = done_file.open('wb') - f.close() - else: - while not done_file.exists(): - time.sleep(1) + from filelock import FileLock + with FileLock(str(cached_dir_model) + '.lock'): + donefile = cached_dir_model / 'done' + if (not force_download) and donefile.exists(): + log.debug('%s cached in %s' % (url, cached_dir_model)) + return cached_dir_model + cached_dir_model.mkdir(exist_ok=True) + tmpfile = cached_dir_model / 'tmp' + with tmpfile.open('wb') as f: + r = requests.get(url, stream=True) + total_len = int(r.headers.get('content-length')) + for chunk in tqdm( + r.iter_content(chunk_size=1024), + total=total_len // 1024, + desc='downloading %s' % url, + unit='KB'): + if chunk: + f.write(chunk) + f.flush() + log.debug('extacting... to %s' % tmpfile) + with tarfile.open(tmpfile.as_posix()) as tf: + tf.extractall(path=str(cached_dir_model)) + donefile.touch() + os.remove(tmpfile.as_posix()) - log.debug('%s cached in %s' % (url, cached_dir)) return cached_dir_model diff --git a/requirements.txt b/requirements.txt index ab5fb9d..aed1dc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ sentencepiece==0.1.8 jieba==0.39 visualdl>=2.0.0b7 pathlib2>=2.3.2 +filelock>=3.0.0 tqdm>=4.32.2 -- GitLab