Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f89a7b55
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f89a7b55
编写于
6月 10, 2021
作者:
W
Wenyu
提交者:
GitHub
6月 10, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add wget option in download (#33379)
* add wget option in download
上级
945e0847
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
106 addition
and
34 deletion
+106
-34
python/paddle/hapi/hub.py
python/paddle/hapi/hub.py
+5
-1
python/paddle/tests/test_download.py
python/paddle/tests/test_download.py
+25
-0
python/paddle/utils/download.py
python/paddle/utils/download.py
+76
-33
未找到文件。
python/paddle/hapi/hub.py
浏览文件 @
f89a7b55
...
...
@@ -110,7 +110,11 @@ def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
url
=
_git_archive_link
(
repo_owner
,
repo_name
,
branch
,
source
=
source
)
fpath
=
get_path_from_url
(
url
,
hub_dir
,
check_exist
=
not
force_reload
,
decompress
=
False
)
url
,
hub_dir
,
check_exist
=
not
force_reload
,
decompress
=
False
,
method
=
(
'wget'
if
source
==
'gitee'
else
'get'
))
shutil
.
move
(
fpath
,
cached_file
)
with
zipfile
.
ZipFile
(
cached_file
)
as
cached_zipfile
:
...
...
python/paddle/tests/test_download.py
浏览文件 @
f89a7b55
...
...
@@ -77,6 +77,31 @@ class TestDownload(unittest.TestCase):
'www.baidu.com'
,
'./test'
,
)
def test_wget_download_error(self):
    """`_download` with method='wget' raises RuntimeError for a bad URL."""
    bad_url = 'www.baidu'
    target_dir = './test'
    with self.assertRaises(RuntimeError):
        # Imported inside the context, as in the original test body.
        from paddle.utils.download import _download
        _download(bad_url, target_dir, method='wget')
def test_download_methods(self):
    """Download each sample archive with every method enabled here.

    `wget` is only exercised when ``sys.platform == 'linux'``; other
    platforms use the `get` method alone.
    """
    import sys
    from paddle.utils.download import _download

    urls = [
        "https://paddle-hapi.bj.bcebos.com/unittest/files.tar",
        "https://paddle-hapi.bj.bcebos.com/unittest/files.zip",
    ]
    methods = ['wget', 'get'] if sys.platform == 'linux' else ['get']

    for archive_url in urls:
        for download_method in methods:
            _download(archive_url, path='./test', method=download_method)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/utils/download.py
浏览文件 @
f89a7b55
...
...
@@ -21,6 +21,7 @@ import sys
import
os.path
as
osp
import
shutil
import
requests
import
subprocess
import
hashlib
import
tarfile
import
zipfile
...
...
@@ -121,7 +122,8 @@ def get_path_from_url(url,
root_dir
,
md5sum
=
None
,
check_exist
=
True
,
decompress
=
True
):
decompress
=
True
,
method
=
'get'
):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
...
...
@@ -132,7 +134,9 @@ def get_path_from_url(url,
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
...
...
@@ -150,7 +154,7 @@ def get_path_from_url(url,
logger
.
info
(
"Found {}"
.
format
(
fullpath
))
else
:
if
ParallelEnv
().
current_endpoint
in
unique_endpoints
:
fullpath
=
_download
(
url
,
root_dir
,
md5sum
)
fullpath
=
_download
(
url
,
root_dir
,
md5sum
,
method
=
method
)
else
:
while
not
os
.
path
.
exists
(
fullpath
):
time
.
sleep
(
1
)
...
...
@@ -163,13 +167,79 @@ def get_path_from_url(url,
return
fullpath
def
_download
(
url
,
path
,
md5sum
=
None
):
def _get_download(url, fullname):
    """Download ``url`` to ``fullname`` using ``requests.get``.

    Args:
        url (str): download url
        fullname (str): destination file path

    Returns:
        str or bool: ``fullname`` on success, or ``False`` when issuing the
        request raised an exception (the failure is logged, not raised).

    Raises:
        RuntimeError: if the response status code is not 200.
    """
    fname = osp.basename(fullname)
    try:
        req = requests.get(url, stream=True)
    except Exception as e:  # requests.exceptions.ConnectionError
        logger.info("Downloading {} from {} failed with exception {}".format(
            fname, url, str(e)))
        return False

    # With stream=True the connection is held until the body is fully
    # consumed or the response is closed, so always close it, including on
    # the non-200 error path where the body is never read.
    try:
        if req.status_code != 200:
            raise RuntimeError("Downloading from {} failed with code "
                               "{}!".format(url, req.status_code))

        # To protect against an interrupted download, download to
        # tmp_fullname first and move tmp_fullname to fullname
        # after the download has finished.
        tmp_fullname = fullname + "_tmp"
        total_size = req.headers.get('content-length')
        with open(tmp_fullname, 'wb') as f:
            if total_size:
                # One progress-bar tick per 1024-byte chunk requested.
                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                    for chunk in req.iter_content(chunk_size=1024):
                        f.write(chunk)
                        pbar.update(1)
            else:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
    finally:
        req.close()

    shutil.move(tmp_fullname, fullname)
    return fullname
def _wget_download(url, fullname):
    """Download ``url`` to ``fullname`` by running the ``wget`` executable.

    Args:
        url (str): download url
        fullname (str): destination file path

    Returns:
        str: ``fullname`` once the downloaded file has been moved into place.

    Raises:
        RuntimeError: if ``wget`` cannot be started or exits with a
            non-zero return code.
    """
    # Download to a temporary name first and move it into place afterwards,
    # so a failed transfer never leaves a partial file at ``fullname``.
    tmp_fullname = fullname + "_tmp"
    # Pass the arguments as a list and do not go through a shell, so that
    # characters such as `;`, `&` or spaces in the url/path are never
    # interpreted as shell syntax.
    command = [
        'wget', '-O', tmp_fullname, '-t', str(DOWNLOAD_RETRY_LIMIT), url
    ]
    command_str = ' '.join(command)
    error_msg = '{} failed. Please make sure `wget` is installed or {} exists'.format(
        command_str, url)

    try:
        subprc = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        # Without a shell, a missing `wget` binary surfaces as OSError;
        # keep reporting it as RuntimeError like the other failures.
        raise RuntimeError(error_msg)
    _ = subprc.communicate()

    if subprc.returncode != 0:
        # `wget -O` creates the output file even when the download fails.
        if osp.exists(tmp_fullname):
            os.remove(tmp_fullname)
        raise RuntimeError(error_msg)

    shutil.move(tmp_fullname, fullname)
    return fullname
# Dispatch table: download method name -> function implementing it.
_download_methods = dict(
    get=_get_download,
    wget=_wget_download, )
def
_download
(
url
,
path
,
md5sum
=
None
,
method
=
'get'
):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert
method
in
_download_methods
,
'make sure `{}` implemented'
.
format
(
method
)
if
not
osp
.
exists
(
path
):
os
.
makedirs
(
path
)
...
...
@@ -177,6 +247,7 @@ def _download(url, path, md5sum=None):
fullname
=
osp
.
join
(
path
,
fname
)
retry_cnt
=
0
logger
.
info
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
while
not
(
osp
.
exists
(
fullname
)
and
_md5check
(
fullname
,
md5sum
)):
if
retry_cnt
<
DOWNLOAD_RETRY_LIMIT
:
retry_cnt
+=
1
...
...
@@ -184,38 +255,10 @@ def _download(url, path, md5sum=None):
raise
RuntimeError
(
"Download from {} failed. "
"Retry limit reached"
.
format
(
url
))
logger
.
info
(
"Downloading {} from {}"
.
format
(
fname
,
url
))
try
:
req
=
requests
.
get
(
url
,
stream
=
True
)
except
Exception
as
e
:
# requests.exceptions.ConnectionError
logger
.
info
(
"Downloading {} from {} failed {} times with exception {}"
.
format
(
fname
,
url
,
retry_cnt
+
1
,
str
(
e
)))
if
not
_download_methods
[
method
](
url
,
fullname
):
time
.
sleep
(
1
)
continue
if
req
.
status_code
!=
200
:
raise
RuntimeError
(
"Downloading from {} failed with code "
"{}!"
.
format
(
url
,
req
.
status_code
))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname
=
fullname
+
"_tmp"
total_size
=
req
.
headers
.
get
(
'content-length'
)
with
open
(
tmp_fullname
,
'wb'
)
as
f
:
if
total_size
:
with
tqdm
(
total
=
(
int
(
total_size
)
+
1023
)
//
1024
)
as
pbar
:
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
f
.
write
(
chunk
)
pbar
.
update
(
1
)
else
:
for
chunk
in
req
.
iter_content
(
chunk_size
=
1024
):
if
chunk
:
f
.
write
(
chunk
)
shutil
.
move
(
tmp_fullname
,
fullname
)
return
fullname
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录