Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
542eb736
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
542eb736
编写于
3月 01, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
8925ccf6
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
19 addition
and
18 deletion
+19
-18
demo/semantic_role_labeling/api_train_v2.py
demo/semantic_role_labeling/api_train_v2.py
+19
-18
未找到文件。
demo/semantic_role_labeling/api_train_v2.py
浏览文件 @
542eb736
import
numpy
import
numpy
as
np
import
paddle.v2
as
paddle
from
model_v2
import
db_lstm
...
...
@@ -31,10 +31,6 @@ word_dict_len = len(word_dict)
label_dict_len
=
len
(
label_dict
)
pred_len
=
len
(
predicate_dict
)
print
'word_dict_len=%d'
%
word_dict_len
print
'label_dict_len=%d'
%
label_dict_len
print
'pred_len=%d'
%
pred_len
def
train_reader
(
file_name
=
"data/feature"
):
def
reader
():
...
...
@@ -65,25 +61,34 @@ def train_reader(file_name="data/feature"):
return
reader
def
load_parameter
(
file_name
,
h
,
w
):
with
open
(
file_name
,
'rb'
)
as
f
:
f
.
read
(
16
)
# skip header for float type.
return
np
.
fromfile
(
f
,
dtype
=
np
.
float32
).
reshape
(
h
,
w
)
def
main
():
paddle
.
init
(
use_gpu
=
False
,
trainer_count
=
1
)
# define network topology
crf_cost
,
crf_dec
=
db_lstm
(
word_dict_len
,
label_dict_len
,
pred_len
)
#parameters = paddle.parameters.create([crf_cost, crf_dec])
parameters
=
paddle
.
parameters
.
create
(
crf_cost
)
parameters
=
paddle
.
parameters
.
create
([
crf_cost
,
crf_dec
])
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0.01
,
learning_rate
=
2e-2
)
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
print
"Pass %d, Batch %d, Cost %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
)
if
event
.
batch_id
%
100
==
0
:
print
"Pass %d, Batch %d, Cost %f, %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
else
:
pass
trainer
=
paddle
.
trainer
.
SGD
(
update_equation
=
optimizer
)
trainer
=
paddle
.
trainer
.
SGD
(
cost
=
crf_cost
,
parameters
=
parameters
,
update_equation
=
optimizer
)
parameters
.
set
(
'emb'
,
load_parameter
(
"data/emb"
,
44068
,
32
))
reader_dict
=
{
'word_data'
:
0
,
...
...
@@ -96,18 +101,14 @@ def main():
'mark_data'
:
7
,
'target'
:
8
,
}
#trn_reader = paddle.reader.batched(
# paddle.reader.shuffle(
# train_reader(), buf_size=8192), batch_size=2)
trn_reader
=
paddle
.
reader
.
batched
(
train_reader
(),
batch_size
=
1
)
trn_reader
=
paddle
.
reader
.
batched
(
paddle
.
reader
.
shuffle
(
train_reader
(),
buf_size
=
8192
),
batch_size
=
10
)
trainer
.
train
(
reader
=
trn_reader
,
cost
=
crf_cost
,
parameters
=
parameters
,
event_handler
=
event_handler
,
num_passes
=
10000
,
reader_dict
=
reader_dict
)
#cost=[crf_cost, crf_dec],
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录