Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
5737d422
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5737d422
编写于
4月 11, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add extra info in module
上级
321b7083
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
49 addition
and
14 deletion
+49
-14
demo/image-classification/create_module.py
demo/image-classification/create_module.py
+8
-1
paddlehub/module/module.py
paddlehub/module/module.py
+41
-13
未找到文件。
demo/image-classification/create_module.py
浏览文件 @
5737d422
...
...
@@ -59,7 +59,14 @@ def create_module(args):
module_dir
=
args
.
model
+
".hub_module"
,
module_info
=
"resources/module_info.yml"
,
processor
=
processor
.
Processor
,
assets
=
assets
)
assets
=
assets
,
extra_info
=
{
'excepted_image_width'
:
224
,
'excepted_image_height'
:
224
,
'pretrained_images_mean'
:
[
0.485
,
0.456
,
0.406
],
'pretrained_images_std'
:
[
0.229
,
0.224
,
0.225
],
'image_channel_order'
:
'RGB'
})
def
main
():
...
...
paddlehub/module/module.py
浏览文件 @
5737d422
...
...
@@ -40,6 +40,15 @@ from paddlehub import version
__all__
=
[
'Module'
,
'create_module'
]
# paddle hub module dir name
ASSETS_DIRNAME
=
"assets"
MODEL_DIRNAME
=
"model"
MODULE_DESC_PBNAME
=
"module_desc.pb"
PYTHON_DIR
=
"python"
PROCESSOR_NAME
=
"processor"
# paddle hub var prefix
HUB_VAR_PREFIX
=
"@HUB_%s@"
def
set_max_seq_len
(
program
,
input_dict
):
""" Set """
...
...
@@ -51,26 +60,18 @@ def create_module(sign_arr,
processor
=
None
,
assets
=
None
,
module_info
=
None
,
exe
=
None
):
exe
=
None
,
extra_info
=
None
):
sign_arr
=
utils
.
to_list
(
sign_arr
)
module
=
Module
(
signatures
=
sign_arr
,
processor
=
processor
,
assets
=
assets
,
module_info
=
module_info
)
module_info
=
module_info
,
extra_info
=
extra_info
)
module
.
serialize_to_path
(
path
=
module_dir
,
exe
=
exe
)
# paddle hub module dir name
ASSETS_DIRNAME
=
"assets"
MODEL_DIRNAME
=
"model"
MODULE_DESC_PBNAME
=
"module_desc.pb"
PYTHON_DIR
=
"python"
PROCESSOR_NAME
=
"processor"
# paddle hub var prefix
HUB_VAR_PREFIX
=
"@HUB_%s@"
class
ModuleHelper
(
object
):
def
__init__
(
self
,
module_dir
):
self
.
module_dir
=
module_dir
...
...
@@ -99,7 +100,8 @@ class Module(object):
signatures
=
None
,
module_info
=
None
,
assets
=
None
,
processor
=
None
):
processor
=
None
,
extra_info
=
None
):
self
.
desc
=
module_desc_pb2
.
ModuleDesc
()
self
.
program
=
None
self
.
assets
=
[]
...
...
@@ -108,6 +110,10 @@ class Module(object):
self
.
default_signature
=
None
self
.
module_info
=
None
self
.
processor
=
None
self
.
extra_info
=
{}
if
extra_info
is
None
else
extra_info
if
not
isinstance
(
self
.
extra_info
,
dict
):
raise
TypeError
(
"The extra_info should be an instance of python dict"
)
# TODO(wuzewu): print more module loading info log
if
name
:
self
.
_init_with_name
(
name
=
name
)
...
...
@@ -204,6 +210,7 @@ class Module(object):
self
.
_load_assets
()
self
.
_recover_from_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
self
.
_restore_parameter
(
self
.
program
)
self
.
_recover_variable_info
(
self
.
program
)
...
...
@@ -213,6 +220,7 @@ class Module(object):
self
.
_check_signatures
()
self
.
_generate_desc
()
self
.
_generate_sign_attr
()
self
.
_generate_extra_info
()
def
_init_with_program
(
self
,
program
):
pass
...
...
@@ -261,6 +269,14 @@ class Module(object):
var
=
block
.
vars
[
var_name
]
var
.
stop_gradient
=
stop_gradient
def
get_extra_info
(
self
,
key
):
return
self
.
extra_info
.
get
(
key
,
None
)
def
_generate_extra_info
(
self
):
for
key
in
self
.
extra_info
:
self
.
__dict__
[
"get_%s"
%
key
]
=
functools
.
partial
(
self
.
get_extra_info
,
key
=
key
)
def
_generate_module_info
(
self
,
module_info
=
None
):
if
not
module_info
:
self
.
module_info
=
{}
...
...
@@ -332,6 +348,12 @@ class Module(object):
self
.
summary
=
utils
.
from_module_attr_to_pyobj
(
module_info
.
map
.
data
[
'summary'
])
# recover extra info
extra_info
=
self
.
desc
.
attr
.
map
.
data
[
'extra_info'
]
self
.
extra_info
=
{}
for
key
,
value
in
extra_info
.
map
.
data
.
items
():
self
.
extra_info
[
key
]
=
utils
.
from_module_attr_to_pyobj
(
value
)
# recover name prefix
self
.
name_prefix
=
utils
.
from_module_attr_to_pyobj
(
self
.
desc
.
attr
.
map
.
data
[
"name_prefix"
])
...
...
@@ -398,6 +420,12 @@ class Module(object):
utils
.
from_pyobj_to_module_attr
(
self
.
summary
,
module_info
.
map
.
data
[
'summary'
])
# save extra info
extra_info
=
attr
.
map
.
data
[
'extra_info'
]
extra_info
.
type
=
module_desc_pb2
.
MAP
for
key
,
value
in
self
.
extra_info
.
items
():
utils
.
from_pyobj_to_module_attr
(
value
,
extra_info
.
map
.
data
[
key
])
def
__call__
(
self
,
sign_name
,
data
,
**
kwargs
):
self
.
check_processor
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录