Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
2a93e018
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
284
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2a93e018
编写于
4月 10, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
replace asserts
上级
5b09fad3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
85 addition
and
43 deletion
+85
-43
paddlehub/commands/base_command.py
paddlehub/commands/base_command.py
+6
-5
paddlehub/commands/run.py
paddlehub/commands/run.py
+19
-5
paddlehub/common/paddle_helper.py
paddlehub/common/paddle_helper.py
+19
-13
paddlehub/module/module.py
paddlehub/module/module.py
+23
-10
paddlehub/module/signature.py
paddlehub/module/signature.py
+14
-8
paddlehub/reader/nlp_reader.py
paddlehub/reader/nlp_reader.py
+4
-2
未找到文件。
paddlehub/commands/base_command.py
浏览文件 @
2a93e018
...
...
@@ -30,8 +30,10 @@ class BaseCommand:
def
instance
(
cls
):
if
cls
.
name
in
BaseCommand
.
command_dict
:
command
=
BaseCommand
.
command_dict
[
cls
.
name
]
assert
command
.
__class__
.
__name__
==
cls
.
__name__
,
"already has a command %s with type %s"
%
(
cls
.
name
,
command
.
__class__
)
if
command
.
__class__
.
__name__
!=
cls
.
__name__
:
raise
KeyError
(
"Command dict already has a command %s with type %s"
%
(
cls
.
name
,
command
.
__class__
))
return
command
if
not
hasattr
(
cls
,
'_instance'
):
cls
.
_instance
=
cls
(
cls
.
name
)
...
...
@@ -39,9 +41,8 @@ class BaseCommand:
return
cls
.
_instance
def
__init__
(
self
,
name
):
assert
not
hasattr
(
self
.
__class__
,
'_instance'
),
'Please use `instance()` to get Command object!'
if
hasattr
(
self
.
__class__
,
'_instance'
):
raise
RuntimeError
(
"Please use `instance()` to get Command object!"
)
self
.
args
=
None
self
.
name
=
name
self
.
show_in_help
=
True
...
...
paddlehub/commands/run.py
浏览文件 @
2a93e018
...
...
@@ -122,7 +122,10 @@ class RunCommand(BaseCommand):
# data_format check
if
not
self
.
args
.
config
:
assert
len
(
expect_data_format
)
==
1
if
len
(
expect_data_format
)
!=
1
:
raise
RuntimeError
(
"Module requires %d inputs, please use config file to specify mappings for data and inputs."
%
len
(
expect_data_format
))
origin_data_key
=
list
(
origin_data
.
keys
())[
0
]
input_data_key
=
list
(
expect_data_format
.
keys
())[
0
]
input_data
=
{
input_data_key
:
origin_data
[
origin_data_key
]}
...
...
@@ -135,11 +138,22 @@ class RunCommand(BaseCommand):
input_data
=
{
input_data_key
:
origin_data
[
origin_data_key
]}
else
:
input_data_format
=
yaml_config
[
'input_data'
]
assert
len
(
input_data_format
)
==
len
(
expect_data_format
)
if
len
(
input_data_format
)
!=
len
(
expect_data_format
):
raise
ValueError
(
"Module requires %d inputs, but the input file gives %d."
%
(
len
(
expect_data_format
),
len
(
input_data_format
)))
for
key
,
value
in
expect_data_format
.
items
():
assert
key
in
input_data_format
assert
value
[
'type'
]
==
hub
.
DataType
.
type
(
input_data_format
[
key
][
'type'
])
if
key
not
in
input_data_format
:
raise
KeyError
(
"Input file gives an unexpected input %s"
%
key
)
if
value
[
'type'
]
!=
hub
.
DataType
.
type
(
input_data_format
[
key
][
'type'
]):
raise
TypeError
(
"Module expect Type %s for %s, but the input file gives %s"
%
(
value
[
'type'
],
key
,
hub
.
DataType
.
type
(
input_data_format
[
key
][
'type'
])))
input_data
=
{}
for
key
,
value
in
yaml_config
[
'input_data'
].
items
():
...
...
paddlehub/common/paddle_helper.py
浏览文件 @
2a93e018
...
...
@@ -26,9 +26,9 @@ from paddlehub.common.logger import logger
def
get_variable_info
(
var
):
assert
isinstance
(
var
,
fluid
.
framework
.
Variable
),
"var should be a fluid.framework.Variable"
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
raise
TypeError
(
"var shoule be an instance of fluid.framework.Variable"
)
var_info
=
{
'type'
:
var
.
type
,
'name'
:
var
.
name
,
...
...
@@ -148,20 +148,26 @@ def connect_program(pre_program, next_program, input_dict=None, inplace=True):
}
to_block
.
append_op
(
**
op_info
)
assert
isinstance
(
pre_program
,
fluid
.
Program
),
"pre_program should be fluid.Program"
assert
isinstance
(
next_program
,
fluid
.
Program
),
"next_program should be fluid.Program"
if
not
isinstance
(
pre_program
,
fluid
.
Program
):
raise
TypeError
(
"pre_program shoule be an instance of fluid.Program"
)
if
not
isinstance
(
next_program
,
fluid
.
Program
):
raise
TypeError
(
"next_program shoule be an instance of fluid.Program"
)
output_program
=
pre_program
if
inplace
else
pre_program
.
clone
(
for_test
=
False
)
if
input_dict
:
assert
isinstance
(
input_dict
,
dict
),
"the input_dict should be a dict with string-Variable pair"
if
not
isinstance
(
input_dict
,
dict
):
raise
TypeError
(
"input_dict shoule be a python dict like {str:fluid.framework.Variable}"
)
for
key
,
var
in
input_dict
.
items
():
assert
isinstance
(
var
,
fluid
.
framework
.
Variable
),
"the input_dict should be a dict with string-Variable pair"
if
not
isinstance
(
var
,
fluid
.
framework
.
Variable
):
raise
TypeError
(
"input_dict shoule be a python dict like {str:fluid.framework.Variable}"
)
var_info
=
copy
.
deepcopy
(
get_variable_info
(
var
))
input_var
=
output_program
.
global_block
().
create_var
(
**
var_info
)
output_var
=
next_program
.
global_block
().
var
(
key
)
...
...
paddlehub/module/module.py
浏览文件 @
2a93e018
...
...
@@ -117,9 +117,10 @@ class Module(object):
self
.
_init_with_module_file
(
module_dir
=
module_dir
)
elif
signatures
:
if
processor
:
assert
issubclass
(
processor
,
BaseProcessor
),
"processor should be sub class of hub.BaseProcessor"
if
not
issubclass
(
processor
,
BaseProcessor
):
raise
TypeError
(
"processor shoule be an instance of paddlehub.BaseProcessor"
)
if
assets
:
self
.
assets
=
utils
.
to_list
(
assets
)
# for asset in assets:
...
...
@@ -446,7 +447,8 @@ class Module(object):
return
result
def
check_processor
(
self
):
assert
self
.
processor
,
"this module couldn't be call"
if
not
self
.
processor
:
raise
ValueError
(
"This Module is not callable!"
)
def
context
(
self
,
sign_name
,
...
...
@@ -461,7 +463,9 @@ class Module(object):
available for BERT/ERNIE module
"""
assert
sign_name
in
self
.
signatures
,
"module did not have a signature with name %s"
%
sign_name
if
sign_name
not
in
self
.
signatures
:
raise
KeyError
(
"Module did not have a signature with name %s"
%
sign_name
)
signature
=
self
.
signatures
[
sign_name
]
program
=
self
.
program
.
clone
(
for_test
=
for_test
)
...
...
@@ -535,19 +539,28 @@ class Module(object):
return
self
.
get_name_prefix
()
+
var_name
def
_check_signatures
(
self
):
assert
self
.
signatures
,
"Signature array should not be None"
if
not
self
.
signatures
:
raise
ValueError
(
"Signatures should not be None"
)
for
key
,
sign
in
self
.
signatures
.
items
():
assert
isinstance
(
sign
,
Signature
),
"sign_arr should be list of Signature"
if
not
isinstance
(
sign
,
Signature
):
raise
TypeError
(
"Item in Signatures shoule be an instance of paddlehub.Signature"
)
for
input
in
sign
.
inputs
:
_tmp_program
=
input
.
block
.
program
assert
self
.
program
==
_tmp_program
,
"all the variable should come from the same program"
if
not
self
.
program
==
_tmp_program
:
raise
ValueError
(
"All input and outputs variables in signature should come from the same Program"
)
for
output
in
sign
.
outputs
:
_tmp_program
=
output
.
block
.
program
assert
self
.
program
==
_tmp_program
,
"all the variable should come from the same program"
if
not
self
.
program
==
_tmp_program
:
raise
ValueError
(
"All input and outputs variables in signature should come from the same Program"
)
def
serialize_to_path
(
self
,
path
=
None
,
exe
=
None
):
self
.
_check_signatures
()
...
...
paddlehub/module/signature.py
浏览文件 @
2a93e018
...
...
@@ -35,23 +35,29 @@ class Signature:
if
not
feed_names
:
feed_names
=
[
""
]
*
len
(
inputs
)
feed_names
=
to_list
(
feed_names
)
assert
len
(
inputs
)
==
len
(
feed_names
),
"the length of feed_names must be same with inputs"
if
len
(
inputs
)
!=
len
(
feed_names
):
raise
ValueError
(
"The length of feed_names must be same with inputs"
)
if
not
fetch_names
:
fetch_names
=
[
""
]
*
len
(
outputs
)
fetch_names
=
to_list
(
fetch_names
)
assert
len
(
outputs
)
==
len
(
fetch_names
),
"the length of fetch_names must be same with outputs"
if
len
(
outputs
)
!=
len
(
fetch_names
):
raise
ValueError
(
"the length of fetch_names must be same with outputs"
)
self
.
name
=
name
for
item
in
inputs
:
assert
isinstance
(
item
,
Variable
),
"the item of inputs list shoule be Variable"
if
not
isinstance
(
item
,
Variable
):
raise
TypeError
(
"Item in inputs list shoule be an instance of fluid.framework.Variable"
)
for
item
in
outputs
:
assert
isinstance
(
item
,
Variable
),
"the item of outputs list shoule be Variable"
if
not
isinstance
(
item
,
Variable
):
raise
TypeError
(
"Item in outputs list shoule be an instance of fluid.framework.Variable"
)
self
.
inputs
=
inputs
self
.
outputs
=
outputs
...
...
paddlehub/reader/nlp_reader.py
浏览文件 @
2a93e018
...
...
@@ -300,7 +300,8 @@ class SequenceLabelReader(BaseReader):
return
return_list
def
_reseg_token_label
(
self
,
tokens
,
labels
,
tokenizer
):
assert
len
(
tokens
)
==
len
(
labels
)
if
len
(
tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of tokens must be same with labels"
)
ret_tokens
=
[]
ret_labels
=
[]
for
token
,
label
in
zip
(
tokens
,
labels
):
...
...
@@ -316,7 +317,8 @@ class SequenceLabelReader(BaseReader):
sub_label
=
"I-"
+
label
[
2
:]
ret_labels
.
extend
([
sub_label
]
*
(
len
(
sub_token
)
-
1
))
assert
len
(
ret_tokens
)
==
len
(
ret_labels
)
if
len
(
ret_tokens
)
!=
len
(
labels
):
raise
ValueError
(
"The length of ret_tokens can't match with labels"
)
return
ret_tokens
,
ret_labels
def
_convert_example_to_record
(
self
,
example
,
max_seq_length
,
tokenizer
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录