Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f89a7b55
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f89a7b55
编写于
6月 10, 2021
作者:
W
Wenyu
提交者:
GitHub
6月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wget option in download (#33379)
* add wget option in download
上级
945e0847
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
106 addition
and
34 deletion
+106
-34
python/paddle/hapi/hub.py
python/paddle/hapi/hub.py
+5
-1
python/paddle/tests/test_download.py
python/paddle/tests/test_download.py
+25
-0
python/paddle/utils/download.py
python/paddle/utils/download.py
+76
-33
未找到文件。
python/paddle/hapi/hub.py
浏览文件 @
f89a7b55
...
...
@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url
=
_git_archive_link
(
repo_owner
,
repo_name
,
branch
,
source
=
source
)
fpath
=
get_path_from_url
(
url
,
hub_dir
,
check_exist
=
not
force_reload
,
decompress
=
False
)
url
,
hub_dir
,
check_exist
=
not
force_reload
,
decompress
=
False
,
method
=
(
'wget'
if
source
==
'gitee'
else
'get'
))
shutil
.
move
(
fpath
,
cached_file
)
with
zipfile
.
ZipFile
(
cached_file
)
as
cached_zipfile
:
...
...
python/paddle/tests/test_download.py
浏览文件 @
f89a7b55
...
...
@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
'www.baidu.com'
,
'./test'
,
)
def test_wget_download_error(self):
    """Check that `_download` with method='wget' raises RuntimeError for an unreachable URL."""
    bad_url = 'www.baidu'
    target_dir = './test'
    with self.assertRaises(RuntimeError):
        # Imported lazily inside the assertion scope, as the test has always done.
        from paddle.utils.download import _download
        _download(bad_url, target_dir, method='wget')
def test_download_methods(self):
    """Download each sample archive once with every download method enabled for this platform."""
    import sys
    from paddle.utils.download import _download

    # 'wget' is only exercised when sys.platform is exactly 'linux'.
    methods = ['wget', 'get'] if sys.platform == 'linux' else ['get']
    urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
    ]

    for url in urls:
        for method in methods:
            _download(url, path='./test', method=method)
# Run the test cases through unittest when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
python/paddle/utils/download.py
浏览文件 @
f89a7b55
...
...
@@ -21,6 +21,7 @@ import sys
import
os.path
as
osp
import
shutil
import
requests
import
subprocess
import
hashlib
import
tarfile
import
zipfile
...
...
@@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir
,
md5sum
=
None
,
check_exist
=
True
,
decompress
=
True
):
decompress
=
True
,
method
=
'get'
):
""" Download from given url to root_dir.
if file or directory specified by url exists under
root_dir, return the path directly, otherwise download
...
...
@@ -132,7 +134,9 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
...
...
@@ -150,7 +154,7 @@ def get_path_from_url(url,
logger
.
info
(
"Found {}"
.
format
(
fullpath
))
else
:
if
ParallelEnv
().
current_endpoint
in
unique_endpoints
:
fullpath
=
_download
(
url
,
root_dir
,
md5sum
)
fullpath
=
_download
(
url
,
root_dir
,
md5sum
,
method
=
method
)
else
:
while
not
os
.
path
.
exists
(
fullpath
):
time
.
sleep
(
1
)
...
...
@@ -163,13 +167,79 @@ def get_path_from_url(url,
return
fullpath
def
_download
(
url
,
path
,
md5sum
=
None
):
def _get_download(url, fullname):
    """Fetch `url` with `requests.get` and store the payload at `fullname`.

    The body is streamed in 1024-byte chunks into `fullname + "_tmp"` and
    the temporary file is moved to `fullname` once the stream is exhausted,
    so an interrupted transfer does not leave a partial file at `fullname`.

    Args:
        url (str): address to download from.
        fullname (str): destination file path.

    Returns:
        str or bool: `fullname` on success; `False` when `requests.get`
        itself raised (the failure is logged, not propagated).

    Raises:
        RuntimeError: if the response status code is not 200.
    """
    fname = osp.basename(fullname)

    try:
        response = requests.get(url, stream=True)
    except Exception as err:
        # Typically requests.exceptions.ConnectionError; reported as a
        # soft failure through the False return value.
        logger.info("Downloading {} from {} failed with exception {}".format(
            fname, url, str(err)))
        return False

    if response.status_code != 200:
        raise RuntimeError("Downloading from {} failed with code "
                           "{}!".format(url, response.status_code))

    # Write to a temporary sibling first; it is renamed only after the
    # whole body has been written.
    partial_path = fullname + "_tmp"
    content_length = response.headers.get('content-length')
    with open(partial_path, 'wb') as out_file:
        if not content_length:
            # Size unknown: no progress bar, empty chunks are skipped.
            for block in response.iter_content(chunk_size=1024):
                if block:
                    out_file.write(block)
        else:
            # Progress is counted in 1 KiB chunks, rounded up.
            n_blocks = (int(content_length) + 1023) // 1024
            with tqdm(total=n_blocks) as progress:
                for block in response.iter_content(chunk_size=1024):
                    out_file.write(block)
                    progress.update(1)

    shutil.move(partial_path, fullname)
    return fullname
def _wget_download(url, fullname):
    """Download `url` to `fullname` by running the external `wget` program.

    `wget` writes to `fullname + "_tmp"`; the temporary file is moved to
    `fullname` only after `wget` exits with status 0.

    Args:
        url (str): address to download from.
        fullname (str): destination file path.

    Returns:
        str: `fullname`.

    Raises:
        RuntimeError: if `wget` cannot be started or exits with a non-zero
            status.
    """
    tmp_fullname = fullname + "_tmp"
    # Pass an argument list and bypass the shell: `url` and `fullname` are
    # caller-supplied, and formatting them into a `shell=True` command line
    # would let shell metacharacters (`;`, `&`, spaces, quotes) be executed
    # or split into extra arguments.
    command = [
        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
    ]
    try:
        subprc = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        # Drain both pipes so the child cannot block on a full buffer.
        subprc.communicate()
        returncode = subprc.returncode
    except OSError:
        # Without a shell, a missing or non-executable `wget` surfaces as
        # OSError instead of a non-zero exit status; report it through the
        # same RuntimeError callers already expect.
        returncode = -1

    if returncode != 0:
        raise RuntimeError(
            '{} failed. Please make sure `wget` is installed or {} exists'.
            format(' '.join(command), url))

    shutil.move(tmp_fullname, fullname)
    return fullname
# Dispatch table: download method name -> function(url, fullname) implementing it.
_download_methods = dict(
    get=_get_download,
    wget=_wget_download, )
def
_download
(
url
,
path
,
md5sum
=
None
,
method
=
'get'
):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert
method
in
_download_methods
,
'make sure `{}` implemented'
.
format
(
method
)
if
not
osp
.
exists
(
path
):
os
.
makedirs
(
path
)
...
...
@@ -177,6 +247,7 @@ def _download(url, path, md5sum=None):
fullname
=
osp
.
join
(
path
,
fname
)
retry_cnt
=
0
logger
.
info
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
while
not
(
osp
.
exists
(
fullname
)
and
_md5check
(
fullname
,
md5sum
)):
if
retry_cnt
<
DOWNLOAD_RETRY_LIMIT
:
retry_cnt
+=
1
...
...
@@ -184,38 +255,10 @@ def _download(url, path, md5sum=None):
raise
RuntimeError
(
"Download from {} failed. "
"Retry limit reached"
.
format
(
url
))
logger
.
info
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
try
:
req
=
requests
.
get
(
url
,
stream
=
True
)
except
Exception
as
e
:
# requests.exceptions.ConnectionError
logger
.
info
(
"Downloading {} from {} failed {} times with exception {}"
.
format
(
fname
,
url
,
retry_cnt
+
1
,
str
(
e
)))
if
not
_download_methods
[
method
](
url
,
fullname
):
time
.
sleep
(
1
)
continue
if
req
.
status_code
!=
200
:
raise
RuntimeError
(
"Downloading from {} failed with code "
"{}!"
.
format
(
url
,
req
.
status_code
))
# To guard against an interrupted download, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname
=
fullname
+
"_tmp"
total_size
=
req
.
headers
.
get
(
'content-length'
)
with
open
(
tmp_fullname
,
'wb'
)
as
f
:
if
total_size
:
with
tqdm
(
total
=
(
int
(
total_size
)
+
1023
)
//
1024
)
as
pbar
:
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
f
.
write
(
chunk
)
pbar
.
update
(
1
)
else
:
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
if
chunk
:
f
.
write
(
chunk
)
shutil
.
move
(
tmp_fullname
,
fullname
)
return
fullname
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录