From 2f55a940f6d4d0aa63b98ad29ec5419297224a5e Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 3 Oct 2019 20:19:28 +0800 Subject: [PATCH] fix checkpoint in multi-process (#3468) --- ppdet/utils/checkpoint.py | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index b1dfa864e..37eddac1c 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -47,6 +47,37 @@ def is_url(path): return path.startswith('http://') or path.startswith('https://') +def _get_weight_path(path): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + path = get_weights_path(path) + else: + from ppdet.utils.download import map_path, WEIGHTS_HOME + weight_path = map_path(path, WEIGHTS_HOME) + lock_path = weight_path + '.lock' + if not os.path.exists(weight_path): + try: + os.makedirs(os.path.dirname(weight_path)) + except OSError as e: + if e.errno != errno.EEXIST: + raise + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if trainer_id == 0: + get_weights_path(path) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + path = weight_path + else: + path = get_weights_path(path) + return path + + def load_params(exe, prog, path, ignore_params=[]): """ Load model from the given path. @@ -58,7 +89,7 @@ def load_params(exe, prog, path, ignore_params=[]): """ if is_url(path): - path = get_weights_path(path) + path = _get_weights_path(path) if not os.path.exists(path): raise ValueError("Model pretrain path {} does not " @@ -94,7 +125,7 @@ def load_checkpoint(exe, prog, path): path (string): URL string or loca model path. """ if is_url(path): - path = get_weights_path(path) + path = _get_weights_path(path) if not os.path.exists(path): raise ValueError("Model checkpoint path {} does not " @@ -147,7 +178,7 @@ def load_and_fusebn(exe, prog, path): logger.info('Load model and fuse batch norm from {}...'.format(path)) if is_url(path): - path = get_weights_path(path) + path = _get_weights_path(path) if not os.path.exists(path): raise ValueError("Model path {} does not exists.".format(path)) -- GitLab