Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
4c73240d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4c73240d
编写于
5月 17, 2017
作者:
C
caoying03
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments.
上级
32f7176d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
83 addition
and
51 deletion
+83
-51
demo/semantic_role_labeling/api_train_v2.py
demo/semantic_role_labeling/api_train_v2.py
+83
-51
未找到文件。
demo/semantic_role_labeling/api_train_v2.py
浏览文件 @
4c73240d
...
...
@@ -6,6 +6,8 @@ import paddle.v2.dataset.conll05 as conll05
import
paddle.v2.evaluator
as
evaluator
import
paddle.v2
as
paddle
logger
=
logging
.
getLogger
(
'paddle'
)
word_dict
,
verb_dict
,
label_dict
=
conll05
.
get_dict
()
word_dict_len
=
len
(
word_dict
)
label_dict_len
=
len
(
label_dict
)
...
...
@@ -120,19 +122,7 @@ def load_parameter(file_name, h, w):
return
np
.
fromfile
(
f
,
dtype
=
np
.
float32
).
reshape
(
h
,
w
)
def
test_a_batch
(
inferer
,
test_data
,
tag_dict
):
probs
=
inferer
.
infer
(
input
=
test_data
,
field
=
'id'
)
assert
len
(
probs
)
==
sum
(
len
(
x
[
0
])
for
x
in
test_data
)
for
test_sample
in
test_data
:
start_id
=
0
pre_lab
=
[
tag_dict
[
probs
[
start_id
+
i
]]
for
i
in
xrange
(
len
(
test_sample
[
0
]))
]
print
pre_lab
start_id
+=
len
(
test_sample
[
0
])
def
main
(
is_predict
=
False
):
def
train
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
...
...
@@ -189,12 +179,12 @@ def main(is_predict=False):
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
print
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
if
event
.
batch_id
%
1000
==
0
:
logger
.
info
(
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
)
if
event
.
batch_id
and
event
.
batch_id
%
1000
==
0
:
result
=
trainer
.
test
(
reader
=
reader
,
feeding
=
feeding
)
print
"
\n
Test with Pass %d, Batch %d, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
result
.
metrics
)
logger
.
info
(
"
\n
Test with Pass %d, Batch %d, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
result
.
metrics
)
)
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
# save parameters
...
...
@@ -202,44 +192,86 @@ def main(is_predict=False):
parameters
.
to_tar
(
f
)
result
=
trainer
.
test
(
reader
=
reader
,
feeding
=
feeding
)
print
"
\n
Test with Pass %d, %s"
%
(
event
.
pass_id
,
result
.
metrics
)
if
not
is_predict
:
trainer
.
train
(
reader
=
reader
,
event_handler
=
event_handler
,
num_passes
=
10
,
feeding
=
feeding
)
else
:
labels_reverse
=
{}
for
(
k
,
v
)
in
label_dict
.
items
():
labels_reverse
[
v
]
=
k
test_creator
=
paddle
.
dataset
.
conll05
.
test
()
logger
.
info
(
"
\n
Test with Pass %d, %s"
%
(
event
.
pass_id
,
result
.
metrics
))
trainer
.
train
(
reader
=
reader
,
event_handler
=
event_handler
,
num_passes
=
10
,
feeding
=
feeding
)
predict
=
paddle
.
layer
.
crf_decoding
(
size
=
label_dict_len
,
input
=
feature_out
,
param_attr
=
paddle
.
attr
.
Param
(
name
=
'crfw'
))
test_pass
=
0
with
gzip
.
open
(
'params_pass_%d.tar.gz'
%
(
test_pass
))
as
f
:
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
f
)
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
predict
,
parameters
=
parameters
)
def
infer_a_batch
(
inferer
,
test_data
,
word_dict
,
pred_dict
,
label_dict
):
probs
=
inferer
.
infer
(
input
=
test_data
,
field
=
'id'
)
assert
len
(
probs
)
==
sum
(
len
(
x
[
0
])
for
x
in
test_data
)
# prepare test data
test_data
=
[]
test_batch_size
=
50
for
idx
,
test_sample
in
enumerate
(
test_data
):
start_id
=
0
pred_str
=
"%s
\t
"
%
(
pred_dict
[
test_sample
[
6
][
0
]])
for
idx
,
item
in
enumerate
(
test_creator
()):
test_data
.
append
(
item
[
0
:
8
])
for
w
,
tag
in
zip
(
test_sample
[
0
],
probs
[
start_id
:
start_id
+
len
(
test_sample
[
0
])]):
pred_str
+=
"%s[%s] "
%
(
word_dict
[
w
],
label_dict
[
tag
])
print
(
pred_str
.
strip
())
start_id
+=
len
(
test_sample
[
0
])
if
idx
and
(
not
idx
%
test_batch_size
):
test_a_batch
(
inferer
,
test_data
,
labels_reverse
)
test_data
=
[]
test_a_batch
(
inferer
,
test_data
,
labels_reverse
)
test_data
=
[]
def
infer
():
label_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
label_dict
.
iteritems
())
word_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
word_dict
.
iteritems
())
pred_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
verb_dict
.
iteritems
())
test_creator
=
paddle
.
dataset
.
conll05
.
test
()
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
feature_out
=
db_lstm
()
predict
=
paddle
.
layer
.
crf_decoding
(
size
=
label_dict_len
,
input
=
feature_out
,
param_attr
=
paddle
.
attr
.
Param
(
name
=
'crfw'
))
test_pass
=
0
with
gzip
.
open
(
'params_pass_%d.tar.gz'
%
(
test_pass
))
as
f
:
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
f
)
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
predict
,
parameters
=
parameters
)
# prepare test data
test_data
=
[]
test_batch_size
=
50
for
idx
,
item
in
enumerate
(
test_creator
()):
test_data
.
append
(
item
[
0
:
8
])
if
idx
and
(
not
idx
%
test_batch_size
):
infer_a_batch
(
inferer
,
test_data
,
word_dict_reverse
,
pred_dict_reverse
,
label_dict_reverse
,
)
test_data
=
[]
infer_a_batch
(
inferer
,
test_data
,
word_dict_reverse
,
pred_dict_reverse
,
label_dict_reverse
,
)
test_data
=
[]
def
main
(
is_inferring
=
False
):
if
is_inferring
:
infer
()
else
:
train
()
if
__name__
==
'__main__'
:
main
(
is_
predict
=
False
)
main
(
is_
inferring
=
False
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录