Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
db5fd353
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db5fd353
编写于
1月 04, 2019
作者:
Z
Zeyu Chen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove useless files
上级
b43e0aa3
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
14 addition
and
80 deletion
+14
-80
paddle_hub/downloader.py
paddle_hub/downloader.py
+8
-65
paddle_hub/module.py
paddle_hub/module.py
+6
-15
未找到文件。
paddle_hub/downloader.py
浏览文件 @
db5fd353
...
...
@@ -64,9 +64,9 @@ def download_and_uncompress(url, save_name=None):
os
.
makedirs
(
dirname
)
#TODO(ZeyuChen) add download md5 file to verify file completeness
file_name
=
os
.
path
.
join
(
dirname
,
url
.
split
(
'/'
)[
-
1
]
if
save_name
is
None
else
save_name
)
file_name
=
os
.
path
.
join
(
dirname
,
url
.
split
(
'/'
)[
-
1
]
if
save_name
is
None
else
save_name
)
retry
=
0
retry_limit
=
3
...
...
@@ -76,8 +76,9 @@ def download_and_uncompress(url, save_name=None):
if
retry
<
retry_limit
:
retry
+=
1
else
:
raise
RuntimeError
(
"Cannot download {0} within retry limit {1}"
.
format
(
url
,
retry_limit
))
raise
RuntimeError
(
"Cannot download {0} within retry limit {1}"
.
format
(
url
,
retry_limit
))
print
(
"Cache file %s not found, downloading %s"
%
(
file_name
,
url
))
r
=
requests
.
get
(
url
,
stream
=
True
)
total_length
=
r
.
headers
.
get
(
'content-length'
)
...
...
@@ -94,8 +95,8 @@ def download_and_uncompress(url, save_name=None):
dl
+=
len
(
data
)
f
.
write
(
data
)
done
=
int
(
50
*
dl
/
total_length
)
sys
.
stdout
.
write
(
"
\r
[%s%s]"
%
(
'='
*
done
,
' '
*
(
50
-
done
)))
sys
.
stdout
.
write
(
"
\r
[%s%s]"
%
(
'='
*
done
,
' '
*
(
50
-
done
)))
sys
.
stdout
.
flush
()
print
(
"file download completed!"
,
file_name
)
...
...
@@ -111,64 +112,6 @@ def download_and_uncompress(url, save_name=None):
return
module_name
,
module_dir
class
TqdmProgress
(
tqdm
):
"""
tqdm prograss hook
"""
last_block
=
0
def
update_to
(
self
,
block_num
=
1
,
block_size
=
1
,
total_size
=
None
):
if
total_size
is
not
None
:
self
.
total
=
total_size
self
.
update
((
block_num
-
self
.
last_block
)
*
block_size
)
self
.
last_block
=
block_num
class
DownloadManager
(
object
):
def
__init__
(
self
):
self
.
dst_path
=
tempfile
.
mkstemp
()
def
download
(
self
,
link
,
dst_path
):
file_name
=
link
.
split
(
"/"
)[
-
1
]
if
dst_path
is
not
None
:
self
.
dst_path
=
dst_path
if
not
os
.
path
.
exists
(
self
.
dst_path
):
os
.
makedirs
(
self
.
dst_path
)
file_path
=
os
.
path
.
join
(
self
.
dst_path
,
file_name
)
print
(
"download filepath"
,
file_path
)
with
TqdmProgress
(
unit
=
'B'
,
unit_scale
=
True
,
unit_divisor
=
1024
,
miniters
=
1
,
desc
=
file_name
)
as
progress
:
path
,
header
=
urlretrieve
(
link
,
filename
=
file_path
,
reporthook
=
progress
.
update_to
,
data
=
None
)
return
path
def
_extract_file
(
self
,
tgz
,
tarinfo
,
dst_path
,
buffer_size
=
10
<<
20
):
"""Extracts 'tarinfo' from 'tgz' and writes to 'dst_path'."""
src
=
tgz
.
extractfile
(
tarinfo
)
dst
=
tf
.
gfile
.
GFile
(
dst_path
,
"wb"
)
while
1
:
buf
=
src
.
read
(
buffer_size
)
if
not
buf
:
break
dst
.
write
(
buf
)
self
.
_log_progress
(
len
(
buf
))
dst
.
close
()
src
.
close
()
def
download_and_uncompress
(
self
,
link
,
dst_path
):
file_name
=
self
.
download
(
link
,
dst_path
)
print
(
file_name
)
if
__name__
==
"__main__"
:
# TODO(ZeyuChen) add unit test
link
=
"http://paddlehub.bj.bcebos.com/word2vec/word2vec-dim16-simple-example-1.tar.gz"
...
...
paddle_hub/module.py
浏览文件 @
db5fd353
...
...
@@ -159,19 +159,6 @@ class Module(object):
word_dict
=
self
.
config
.
get_dict
()
return
list
(
map
(
lambda
x
:
word_dict
[
x
],
inputs
))
# # load assets folder
# def _load_assets(self, module_dir):
# assets_dir = os.path.join(module_dir, ASSETS_NAME)
# dict_path = os.path.join(assets_dir, DICT_NAME)
# word_id = 0
# with open(dict_path) as fi:
# words = fi.readlines()
# #TODO(ZeyuChen) check whether word id is duplicated and valid
# for line in fi:
# w, w_id = line.split()
# self.dict[w] = int(w_id)
def
add_module_feed_list
(
self
,
feed_list
):
self
.
feed_list
=
feed_list
...
...
@@ -240,9 +227,13 @@ class ModuleConfig(object):
w_id
=
self
.
dict
[
w
]
fo
.
write
(
"{}
\t
{}
\n
"
.
format
(
w
,
w_id
))
def
register_input_var
(
self
,
var
):
def
register_input_var
(
self
,
var
,
signature
=
"default"
):
var_name
=
var
.
name
()
self
.
desc
.
sign2input
[
signature
].
append
(
var_name
)
def
register_output_var
(
self
,
var
,
signature
=
"default"
):
var_name
=
var
.
name
()
self
.
feed_list
.
ad
d
(
var_name
)
self
.
desc
.
sign2output
[
signature
].
appen
d
(
var_name
)
def
save_dict
(
self
,
word_dict
,
dict_name
=
DICT_NAME
):
""" Save dictionary for NLP module
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录