Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
83c24ff2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
83c24ff2
编写于
5月 17, 2017
作者:
C
Cao Ying
提交者:
GitHub
5月 17, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2178 from lcy-seso/update_srl_demo
update the SRL demo.
上级
39b91123
4c73240d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
137 addition
and
50 deletion
+137
-50
demo/semantic_role_labeling/api_train_v2.py
demo/semantic_role_labeling/api_train_v2.py
+137
-50
未找到文件。
demo/semantic_role_labeling/api_train_v2.py
浏览文件 @
83c24ff2
import
sys
import
math
import
numpy
as
np
import
paddle.v2
as
paddle
import
gzip
import
logging
import
paddle.v2.dataset.conll05
as
conll05
import
paddle.v2.evaluator
as
evaluator
import
paddle.v2
as
paddle
logger
=
logging
.
getLogger
(
'paddle'
)
def
db_lstm
():
word_dict
,
verb_dict
,
label_dict
=
conll05
.
get_dict
()
word_dict_len
=
len
(
word_dict
)
label_dict_len
=
len
(
label_dict
)
pred_len
=
len
(
verb_dict
)
word_dict
,
verb_dict
,
label_dict
=
conll05
.
get_dict
()
word_dict_len
=
len
(
word_dict
)
label_dict_len
=
len
(
label_dict
)
pred_len
=
len
(
verb_dict
)
mark_dict_len
=
2
word_dim
=
32
mark_dim
=
5
hidden_dim
=
512
depth
=
8
mark_dict_len
=
2
word_dim
=
32
mark_dim
=
5
hidden_dim
=
512
depth
=
8
default_std
=
1
/
math
.
sqrt
(
hidden_dim
)
/
3.0
mix_hidden_lr
=
1e-3
#8 features
def
d_type
(
size
):
return
paddle
.
data_type
.
integer_value_sequence
(
size
)
def
d_type
(
size
):
return
paddle
.
data_type
.
integer_value_sequence
(
size
)
def
db_lstm
():
#8 features
word
=
paddle
.
layer
.
data
(
name
=
'word_data'
,
type
=
d_type
(
word_dict_len
))
predicate
=
paddle
.
layer
.
data
(
name
=
'verb_data'
,
type
=
d_type
(
pred_len
))
...
...
@@ -31,11 +38,7 @@ def db_lstm():
ctx_p2
=
paddle
.
layer
.
data
(
name
=
'ctx_p2_data'
,
type
=
d_type
(
word_dict_len
))
mark
=
paddle
.
layer
.
data
(
name
=
'mark_data'
,
type
=
d_type
(
mark_dict_len
))
target
=
paddle
.
layer
.
data
(
name
=
'target'
,
type
=
d_type
(
label_dict_len
))
default_std
=
1
/
math
.
sqrt
(
hidden_dim
)
/
3.0
emb_para
=
paddle
.
attr
.
Param
(
name
=
'emb'
,
initial_std
=
0.
,
learning_rate
=
0.
)
emb_para
=
paddle
.
attr
.
Param
(
name
=
'emb'
,
initial_std
=
0.
,
is_static
=
True
)
std_0
=
paddle
.
attr
.
Param
(
initial_std
=
0.
)
std_default
=
paddle
.
attr
.
Param
(
initial_std
=
default_std
)
...
...
@@ -63,7 +66,6 @@ def db_lstm():
input
=
emb
,
param_attr
=
std_default
)
for
emb
in
emb_layers
])
mix_hidden_lr
=
1e-3
lstm_para_attr
=
paddle
.
attr
.
Param
(
initial_std
=
0.0
,
learning_rate
=
1.0
)
hidden_para_attr
=
paddle
.
attr
.
Param
(
initial_std
=
default_std
,
learning_rate
=
mix_hidden_lr
)
...
...
@@ -111,6 +113,21 @@ def db_lstm():
input
=
input_tmp
[
1
],
param_attr
=
lstm_para_attr
)
],
)
return
feature_out
def
load_parameter
(
file_name
,
h
,
w
):
with
open
(
file_name
,
'rb'
)
as
f
:
f
.
read
(
16
)
# skip header.
return
np
.
fromfile
(
f
,
dtype
=
np
.
float32
).
reshape
(
h
,
w
)
def
train
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
feature_out
=
db_lstm
()
target
=
paddle
.
layer
.
data
(
name
=
'target'
,
type
=
d_type
(
label_dict_len
))
crf_cost
=
paddle
.
layer
.
crf
(
size
=
label_dict_len
,
input
=
feature_out
,
label
=
target
,
...
...
@@ -120,29 +137,15 @@ def db_lstm():
learning_rate
=
mix_hidden_lr
))
crf_dec
=
paddle
.
layer
.
crf_decoding
(
name
=
'crf_dec_l'
,
size
=
label_dict_len
,
input
=
feature_out
,
label
=
target
,
param_attr
=
paddle
.
attr
.
Param
(
name
=
'crfw'
))
return
crf_cost
,
crf_dec
def
load_parameter
(
file_name
,
h
,
w
):
with
open
(
file_name
,
'rb'
)
as
f
:
f
.
read
(
16
)
# skip header.
return
np
.
fromfile
(
f
,
dtype
=
np
.
float32
).
reshape
(
h
,
w
)
def
main
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
crf_cost
,
crf_dec
=
db_lstm
()
evaluator
.
sum
(
input
=
crf_dec
)
# create parameters
parameters
=
paddle
.
parameters
.
create
([
crf_cost
,
crf_dec
])
parameters
=
paddle
.
parameters
.
create
(
crf_cost
)
parameters
.
set
(
'emb'
,
load_parameter
(
conll05
.
get_embedding
(),
44068
,
32
))
# create optimizer
optimizer
=
paddle
.
optimizer
.
Momentum
(
...
...
@@ -152,18 +155,12 @@ def main():
model_average
=
paddle
.
optimizer
.
ModelAverage
(
average_window
=
0.5
,
max_average_window
=
10000
),
)
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
print
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
crf_cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
parameters
.
set
(
'emb'
,
load_parameter
(
conll05
.
get_embedding
(),
44068
,
32
)
)
update_equation
=
optimizer
,
extra_layers
=
crf_dec
)
trn_
reader
=
paddle
.
batch
(
reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
conll05
.
test
(),
buf_size
=
8192
),
batch_size
=
10
)
...
...
@@ -179,12 +176,102 @@ def main():
'target'
:
8
}
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
logger
.
info
(
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
))
if
event
.
batch_id
and
event
.
batch_id
%
1000
==
0
:
result
=
trainer
.
test
(
reader
=
reader
,
feeding
=
feeding
)
logger
.
info
(
"
\n
Test with Pass %d, Batch %d, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
result
.
metrics
))
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
# save parameters
with
gzip
.
open
(
'params_pass_%d.tar.gz'
%
event
.
pass_id
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
result
=
trainer
.
test
(
reader
=
reader
,
feeding
=
feeding
)
logger
.
info
(
"
\n
Test with Pass %d, %s"
%
(
event
.
pass_id
,
result
.
metrics
))
trainer
.
train
(
reader
=
trn_
reader
,
reader
=
reader
,
event_handler
=
event_handler
,
num_passes
=
10
000
,
num_passes
=
10
,
feeding
=
feeding
)
def
infer_a_batch
(
inferer
,
test_data
,
word_dict
,
pred_dict
,
label_dict
):
probs
=
inferer
.
infer
(
input
=
test_data
,
field
=
'id'
)
assert
len
(
probs
)
==
sum
(
len
(
x
[
0
])
for
x
in
test_data
)
for
idx
,
test_sample
in
enumerate
(
test_data
):
start_id
=
0
pred_str
=
"%s
\t
"
%
(
pred_dict
[
test_sample
[
6
][
0
]])
for
w
,
tag
in
zip
(
test_sample
[
0
],
probs
[
start_id
:
start_id
+
len
(
test_sample
[
0
])]):
pred_str
+=
"%s[%s] "
%
(
word_dict
[
w
],
label_dict
[
tag
])
print
(
pred_str
.
strip
())
start_id
+=
len
(
test_sample
[
0
])
def
infer
():
label_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
label_dict
.
iteritems
())
word_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
word_dict
.
iteritems
())
pred_dict_reverse
=
dict
((
value
,
key
)
for
key
,
value
in
verb_dict
.
iteritems
())
test_creator
=
paddle
.
dataset
.
conll05
.
test
()
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
feature_out
=
db_lstm
()
predict
=
paddle
.
layer
.
crf_decoding
(
size
=
label_dict_len
,
input
=
feature_out
,
param_attr
=
paddle
.
attr
.
Param
(
name
=
'crfw'
))
test_pass
=
0
with
gzip
.
open
(
'params_pass_%d.tar.gz'
%
(
test_pass
))
as
f
:
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
f
)
inferer
=
paddle
.
inference
.
Inference
(
output_layer
=
predict
,
parameters
=
parameters
)
# prepare test data
test_data
=
[]
test_batch_size
=
50
for
idx
,
item
in
enumerate
(
test_creator
()):
test_data
.
append
(
item
[
0
:
8
])
if
idx
and
(
not
idx
%
test_batch_size
):
infer_a_batch
(
inferer
,
test_data
,
word_dict_reverse
,
pred_dict_reverse
,
label_dict_reverse
,
)
test_data
=
[]
infer_a_batch
(
inferer
,
test_data
,
word_dict_reverse
,
pred_dict_reverse
,
label_dict_reverse
,
)
test_data
=
[]
def
main
(
is_inferring
=
False
):
if
is_inferring
:
infer
()
else
:
train
()
if
__name__
==
'__main__'
:
main
()
main
(
is_inferring
=
False
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录