未验证 提交 70d00f6f 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix bug in _download_dist (#6419)

上级 6c59641e
......@@ -393,7 +393,12 @@ def _download(url, path, md5sum=None):
def _download_dist(url, path, md5sum=None):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
# Mainly used to solve the problem of downloading data from
# different machines in the case of multiple machines.
# Different nodes will download data, and the same node
# will only download data once.
# Reference https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/utils/download.py#L108
rank_id_curr_node = int(os.environ.get("PADDLE_RANK_IN_NODE", 0))
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
return _download(url, path, md5sum)
......@@ -406,12 +411,9 @@ def _download_dist(url, path, md5sum=None):
os.makedirs(path)
if not osp.exists(fullname):
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
if rank_id_curr_node == 0:
_download(url, path, md5sum)
os.remove(lock_path)
else:
......
......@@ -262,7 +262,7 @@ else
continue
fi
if [ ${autocast} = "amp" ]; then
if [ ${autocast} = "amp" ] || [ ${autocast} = "fp16" ]; then
set_autocast="--amp"
else
set_autocast=" "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册