提交 da754d85 编写于 作者: D dangqingqing

srl api training

上级 c3fe50bc
import numpy
import paddle.v2 as paddle
# Fix: the module is `attrs`, not `atts` — the original line raised
# ImportError before anything else in the script could run.
from paddle.trainer_config_helpers.attrs import ParamAttr
# Fix: the network-definition module is `model_v2` (the duplicate
# `from mode_v2 import db_lstm` was a typo'd leftover that would fail
# at import time).
from model_v2 import db_lstm

# Vocabulary and SRL-label dictionary files (word/label -> integer id).
word_dict_file = './data/wordDict.txt'
label_dict_file = './data/targetDict.txt'
......@@ -64,9 +62,8 @@ def train_reader(file_name="data/feature"):
def main():
paddle.init(use_gpu=False, trainer_count=1)
label_dict_len = 500
# define network topology
output = db_lstm()
output = db_lstm(word_dict_len, label_dict_len, pred_len)
target = paddle.layer.data(name='target', size=label_dict_len)
crf_cost = paddle.layer.crf_layer(
size=500,
......@@ -97,6 +94,17 @@ def main():
trainer = paddle.trainer.SGD(update_equation=optimizer)
reader_dict = {
'word_data': 0,
'verb_data': 1,
'ctx_n2_data': 2,
'ctx_n1_data': 3,
'ctx_0_data': 4,
'ctx_p1_data': 5,
'ctx_p2_data': 6,
'mark_data': 7,
'target': 8
}
trainer.train(
train_data_reader=train_reader,
batch_size=32,
......@@ -104,8 +112,7 @@ def main():
parameters=parameters,
event_handler=event_handler,
num_passes=10000,
data_types=[],
reader_dict={})
reader_dict=reader_dict)
if __name__ == '__main__':
......
import math
import paddle.v2 as paddle
......@@ -9,15 +10,18 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
depth = 8
#8 features
word = paddle.layer.data(name='word_data', size=word_dict_len)
predicate = paddle.layer.data(name='verb_data', size=pred_len)
def d_type(size):
return paddle.data_type.integer_value_sequence(size)
ctx_n2 = paddle.layer.data(name='ctx_n2_data', size=word_dict_len)
ctx_n1 = paddle.layer.data(name='ctx_n1_data', size=word_dict_len)
ctx_0 = paddle.layer.data(name='ctx_0_data', size=word_dict_len)
ctx_p1 = paddle.layer.data(name='ctx_p1_data', size=word_dict_len)
ctx_p2 = paddle.layer.data(name='ctx_p2_data', size=word_dict_len)
mark = paddle.layer.data(name='mark_data', size=mark_dict_len)
word = paddle.layer.data(name='word_data', type=d_type(word_dict_len))
predicate = paddle.layer.data(name='verb_data', type=d_type(pred_len))
ctx_n2 = paddle.layer.data(name='ctx_n2_data', type=d_type(word_dict_len))
ctx_n1 = paddle.layer.data(name='ctx_n1_data', type=d_type(word_dict_len))
ctx_0 = paddle.layer.data(name='ctx_0_data', type=d_type(word_dict_len))
ctx_p1 = paddle.layer.data(name='ctx_p1_data', type=d_type(word_dict_len))
ctx_p2 = paddle.layer.data(name='ctx_p2_data', type=d_type(word_dict_len))
mark = paddle.layer.data(name='mark_data', type=d_type(mark_dict_len))
default_std = 1 / math.sqrt(hidden_dim) / 3.0
......@@ -31,10 +35,7 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
param_attr=paddle.attr.Param(
name='vemb', initial_std=default_std))
mark_embedding = paddle.layer.embeding(
name='word_ctx-in_embedding',
size=mark_dim,
input=mark,
param_attr=std_0)
size=mark_dim, input=mark, param_attr=std_0)
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
emb_layers = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册