提交 2f55a940 编写于 作者: W wangguanzhong 提交者: GitHub

fix checkpoint in multi-process (#3468)

上级 30489edb
...@@ -47,6 +47,37 @@ def is_url(path): ...@@ -47,6 +47,37 @@ def is_url(path):
return path.startswith('http://') or path.startswith('https://') return path.startswith('http://') or path.startswith('https://')
def _get_weight_path(path):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
path = get_weights_path(path)
else:
from ppdet.utils.download import map_path, WEIGHTS_HOME
weight_path = map_path(path, WEIGHTS_HOME)
lock_path = weight_path + '.lock'
if not os.path.exists(weight_path):
try:
os.makedirs(os.path.dirname(weight_path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if trainer_id == 0:
get_weights_path(path)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
path = weight_path
else:
path = get_weights_path(path)
return path
def load_params(exe, prog, path, ignore_params=[]): def load_params(exe, prog, path, ignore_params=[]):
""" """
Load model from the given path. Load model from the given path.
...@@ -58,7 +89,7 @@ def load_params(exe, prog, path, ignore_params=[]): ...@@ -58,7 +89,7 @@ def load_params(exe, prog, path, ignore_params=[]):
""" """
if is_url(path): if is_url(path):
path = get_weights_path(path) path = _get_weights_path(path)
if not os.path.exists(path): if not os.path.exists(path):
raise ValueError("Model pretrain path {} does not " raise ValueError("Model pretrain path {} does not "
...@@ -94,7 +125,7 @@ def load_checkpoint(exe, prog, path): ...@@ -94,7 +125,7 @@ def load_checkpoint(exe, prog, path):
path (string): URL string or loca model path. path (string): URL string or loca model path.
""" """
if is_url(path): if is_url(path):
path = get_weights_path(path) path = _get_weights_path(path)
if not os.path.exists(path): if not os.path.exists(path):
raise ValueError("Model checkpoint path {} does not " raise ValueError("Model checkpoint path {} does not "
...@@ -147,7 +178,7 @@ def load_and_fusebn(exe, prog, path): ...@@ -147,7 +178,7 @@ def load_and_fusebn(exe, prog, path):
logger.info('Load model and fuse batch norm from {}...'.format(path)) logger.info('Load model and fuse batch norm from {}...'.format(path))
if is_url(path): if is_url(path):
path = get_weights_path(path) path = _get_weights_path(path)
if not os.path.exists(path): if not os.path.exists(path):
raise ValueError("Model path {} does not exists.".format(path)) raise ValueError("Model path {} does not exists.".format(path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册