Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
7db1f61a
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7db1f61a
编写于
3月 05, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update module func
上级
1cd11734
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
10 deletion
+49
-10
paddle_hub/data_process/__init__.py
paddle_hub/data_process/__init__.py
+0
-0
paddle_hub/module/module.py
paddle_hub/module/module.py
+49
-10
未找到文件。
paddle_hub/data_process/__init__.py
已删除
100644 → 0
浏览文件 @
1cd11734
paddle_hub/module/module.py
浏览文件 @
7db1f61a
...
...
@@ -23,6 +23,7 @@ from paddle_hub.module import module_desc_pb2
from
paddle_hub.module.signature
import
Signature
,
create_signature
from
paddle_hub
import
version
import
os
import
functools
import
paddle
import
paddle.fluid
as
fluid
...
...
@@ -39,6 +40,8 @@ def create_module(sign_arr, module_dir, exe=None):
ASSETS_DIRNAME
=
"assets"
MODEL_DIRNAME
=
"model"
MODULE_DESC_PBNAME
=
"module_desc.pb"
PYTHON_DIR
=
"python"
PROCESSOR_NAME
=
"processor"
# paddle hub var prefix
HUB_VAR_PREFIX
=
"@HUB@"
...
...
@@ -48,8 +51,8 @@ class ModuleWrapper:
self
.
module
=
module
self
.
name
=
name
def
__call__
(
self
,
trainable
=
False
):
return
self
.
module
(
self
.
name
,
trainable
)
def
__call__
(
self
,
data
,
config
):
return
self
.
module
(
self
.
name
,
data
,
config
)
class
ModuleHelper
:
...
...
@@ -62,6 +65,15 @@ class ModuleHelper:
def
model_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
MODEL_DIRNAME
)
def
processor_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
PYTHON_DIR
)
def
processor_name
(
self
):
return
PROCESSOR_NAME
def
assets_path
(
self
):
return
os
.
path
.
join
(
self
.
module_dir
,
ASSETS_DIRNAME
)
class
Module
:
def
__init__
(
self
,
url
=
None
,
module_dir
=
None
,
signatures
=
None
,
name
=
None
):
...
...
@@ -73,6 +85,7 @@ class Module:
self
.
assets
=
[]
self
.
helper
=
None
self
.
signatures
=
{}
self
.
default_signature
=
None
if
url
:
self
.
_init_with_url
(
url
=
url
)
elif
module_dir
:
...
...
@@ -87,6 +100,13 @@ class Module:
module_dir
=
downloader
.
download_and_uncompress
(
module_url
)
self
.
_init_with_module_file
(
module_dir
)
def
_load_processor
(
self
):
import
sys
processor_path
=
self
.
helper
.
processor_path
()
sys
.
path
.
append
(
processor_path
)
processor_name
=
self
.
helper
.
processor_name
()
self
.
processor
=
__import__
(
processor_name
).
Processor
(
module
=
self
)
def
_init_with_module_file
(
self
,
module_dir
):
self
.
helper
=
ModuleHelper
(
module_dir
)
with
open
(
self
.
helper
.
module_desc_path
(),
"rb"
)
as
fi
:
...
...
@@ -97,6 +117,7 @@ class Module:
self
.
helper
.
model_path
(),
executor
=
exe
)
self
.
_recovery_parameter
(
self
.
program
)
self
.
_recover_variable_info
(
self
.
program
)
self
.
_load_processor
()
inputs
=
[]
outputs
=
[]
...
...
@@ -174,7 +195,8 @@ class Module:
def
_generate_sign_attr
(
self
):
self
.
_check_signatures
()
for
sign
in
self
.
signatures
:
self
.
__dict__
[
sign
]
=
ModuleWrapper
(
self
,
sign
)
self
.
__dict__
[
sign
]
=
functools
.
partial
(
self
.
__call__
,
sign_name
=
sign
)
def
_generate_desc
(
self
):
# save fluid Parameter
...
...
@@ -215,7 +237,26 @@ class Module:
fetch_var
.
var_name
=
HUB_VAR_PREFIX
+
output
.
name
fetch_var
.
alias
=
fetch_names
[
index
]
def
__call__
(
self
,
sign_name
,
trainable
=
False
):
def
__call__
(
self
,
sign_name
,
data
,
config
=
None
):
feed_dict
,
fetch_dict
,
program
=
self
.
context
(
sign_name
)
#TODO(wuzewu): more option
program
=
program
.
clone
(
for_test
=
True
)
reader
=
self
.
processor
.
reader
(
sign_name
=
sign_name
,
data_dict
=
data
)
feed_name_list
=
list
(
set
([
value
.
name
for
key
,
value
in
feed_dict
.
items
()]))
fetch_list
=
list
(
set
([
value
for
key
,
value
in
fetch_dict
.
items
()]))
with
fluid
.
program_guard
(
program
):
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
=
place
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
feed_name_list
,
place
=
place
)
for
batch
in
reader
():
data_out
=
exe
.
run
(
feed
=
feeder
.
feed
(
batch
),
fetch_list
=
fetch_list
,
return_numpy
=
False
)
self
.
processor
.
postprocess
(
sign_name
,
data_out
,
config
)
def
context
(
self
,
sign_name
,
trainable
=
False
):
assert
sign_name
in
self
.
signatures
,
"module did not have a signature with name %s"
%
sign_name
signature
=
self
.
signatures
[
sign_name
]
...
...
@@ -227,6 +268,7 @@ class Module:
self
.
_recovery_parameter
(
program
)
self
.
_recover_variable_info
(
program
)
#TODO(wuzewu): return feed_list and fetch_list directly
feed_dict
=
{}
fetch_dict
=
{}
for
index
,
var
in
enumerate
(
signature
.
inputs
):
...
...
@@ -243,18 +285,15 @@ class Module:
return
feed_dict
,
fetch_dict
,
program
def
preprocess
(
self
):
pass
def
postprocess
(
self
):
pass
def
parameters
(
self
):
pass
def
parameter_attrs
(
self
):
pass
def
default_signature
(
self
):
return
self
.
default_signature
def
_check_signatures
(
self
):
assert
self
.
signatures
,
"signature array should not be None"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录