Commit da754d85 authored by dangqingqing

srl api training

Parent c3fe50bc
Training script:

 import numpy
 import paddle.v2 as paddle
-from paddle.trainer_config_helpers.atts import ParamAttr
-from mode_v2 import db_lstm
+from model_v2 import db_lstm

 word_dict_file = './data/wordDict.txt'
 label_dict_file = './data/targetDict.txt'
@@ -64,9 +62,8 @@ def train_reader(file_name="data/feature"):
 def main():
     paddle.init(use_gpu=False, trainer_count=1)
-    label_dict_len = 500

     # define network topology
-    output = db_lstm()
+    output = db_lstm(word_dict_len, label_dict_len, pred_len)
     target = paddle.layer.data(name='target', size=label_dict_len)
     crf_cost = paddle.layer.crf_layer(
         size=500,
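With the hard-coded label_dict_len = 500 removed, the network sizes now flow in through db_lstm's arguments. A minimal sketch of how those lengths could be derived from the dictionary files declared at the top of the script; the load_dict helper and the verb-dictionary path are assumptions, not part of this commit:

def load_dict(dict_file):
    # hypothetical helper: one token per line, id = line number
    with open(dict_file) as f:
        return dict((line.strip(), i) for i, line in enumerate(f))

word_dict_len = len(load_dict(word_dict_file))
label_dict_len = len(load_dict(label_dict_file))
pred_len = len(load_dict('./data/verbDict.txt'))  # assumed file name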
@@ -97,6 +94,17 @@ def main():
     trainer = paddle.trainer.SGD(update_equation=optimizer)

+    reader_dict = {
+        'word_data': 0,
+        'verb_data': 1,
+        'ctx_n2_data': 2,
+        'ctx_n1_data': 3,
+        'ctx_0_data': 4,
+        'ctx_p1_data': 5,
+        'ctx_p2_data': 6,
+        'mark_data': 7,
+        'target': 8
+    }
     trainer.train(
         train_data_reader=train_reader,
         batch_size=32,
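The new reader_dict maps each data layer name to a column index in the samples the reader yields, so train_reader must produce 9-tuples in exactly this order. A toy sketch of a conforming sample (all ids are made-up placeholders):

def toy_reader():
    # hypothetical stand-in for train_reader: eight integer-sequence
    # features plus the target label sequence, in reader_dict column order
    yield ([2, 15, 7], [1, 1, 1], [4, 2, 15], [0, 4, 2], [2, 15, 7],
           [15, 7, 3], [7, 3, 0], [0, 1, 0], [5, 5, 9])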
@@ -104,8 +112,7 @@ def main():
         parameters=parameters,
         event_handler=event_handler,
         num_passes=10000,
-        data_types=[],
-        reader_dict={})
+        reader_dict=reader_dict)

 if __name__ == '__main__':
......
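trainer.train also references an event_handler whose definition this view truncates; for context, a typical paddle.v2 handler looks like the sketch below (not the commit's own code):

def event_handler(event):
    # sketch: report the cost every 100 batches
    if isinstance(event, paddle.event.EndIteration):
        if event.batch_id % 100 == 0:
            print "Pass %d, Batch %d, Cost %f" % (
                event.pass_id, event.batch_id, event.cost)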
model_v2.py:

+import math
 import paddle.v2 as paddle
@@ -9,15 +10,18 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
     depth = 8

     #8 features
-    word = paddle.layer.data(name='word_data', size=word_dict_len)
-    predicate = paddle.layer.data(name='verb_data', size=pred_len)
-    ctx_n2 = paddle.layer.data(name='ctx_n2_data', size=word_dict_len)
-    ctx_n1 = paddle.layer.data(name='ctx_n1_data', size=word_dict_len)
-    ctx_0 = paddle.layer.data(name='ctx_0_data', size=word_dict_len)
-    ctx_p1 = paddle.layer.data(name='ctx_p1_data', size=word_dict_len)
-    ctx_p2 = paddle.layer.data(name='ctx_p2_data', size=word_dict_len)
-    mark = paddle.layer.data(name='mark_data', size=mark_dict_len)
+    def d_type(size):
+        return paddle.data_type.integer_value_sequence(size)
+
+    word = paddle.layer.data(name='word_data', type=d_type(word_dict_len))
+    predicate = paddle.layer.data(name='verb_data', type=d_type(pred_len))
+
+    ctx_n2 = paddle.layer.data(name='ctx_n2_data', type=d_type(word_dict_len))
+    ctx_n1 = paddle.layer.data(name='ctx_n1_data', type=d_type(word_dict_len))
+    ctx_0 = paddle.layer.data(name='ctx_0_data', type=d_type(word_dict_len))
+    ctx_p1 = paddle.layer.data(name='ctx_p1_data', type=d_type(word_dict_len))
+    ctx_p2 = paddle.layer.data(name='ctx_p2_data', type=d_type(word_dict_len))
+    mark = paddle.layer.data(name='mark_data', type=d_type(mark_dict_len))

     default_std = 1 / math.sqrt(hidden_dim) / 3.0
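The d_type helper merely shortens eight identical declarations: each feature is a variable-length sequence of integer ids in [0, size). Inlined, the first declaration is equivalent to:

word = paddle.layer.data(
    name='word_data',
    type=paddle.data_type.integer_value_sequence(word_dict_len))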
@@ -31,10 +35,7 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
         param_attr=paddle.attr.Param(
             name='vemb', initial_std=default_std))
     mark_embedding = paddle.layer.embeding(
-        name='word_ctx-in_embedding',
-        size=mark_dim,
-        input=mark,
-        param_attr=std_0)
+        size=mark_dim, input=mark, param_attr=std_0)

     word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
     emb_layers = [
......