提交 e6e8bfb4 编写于 作者: D dangqingqing

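update: build the CRF cost and decoding layers inside db_lstm (now returning crf_cost, crf_dec), fix the embeding -> embedding layer name, and adapt the training script (reader_dict order, batched reader)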

上级 d60116db
......@@ -2,6 +2,8 @@ import numpy
import paddle.v2 as paddle
from model_v2 import db_lstm
UNK_IDX = 0
word_dict_file = './data/wordDict.txt'
label_dict_file = './data/targetDict.txt'
predicate_file = './data/verbDict.txt'
......@@ -29,6 +31,10 @@ word_dict_len = len(word_dict)
label_dict_len = len(label_dict)
pred_len = len(predicate_dict)
print 'word_dict_len=%d' % word_dict_len
print 'label_dict_len=%d' % label_dict_len
print 'pred_len=%d' % pred_len
def train_reader(file_name="data/feature"):
def reader():
......@@ -63,31 +69,16 @@ def main():
paddle.init(use_gpu=False, trainer_count=1)
# define network topology
output = db_lstm(word_dict_len, label_dict_len, pred_len)
target = paddle.layer.data(name='target', size=label_dict_len)
crf_cost = paddle.layer.crf_layer(
size=500,
input=output,
label=target,
param_attr=paddle.attr.Param(
name='crfw', initial_std=default_std, learning_rate=mix_hidden_lr))
crf_dec = paddle.layer.crf_decoding_layer(
name='crf_dec_l',
size=label_dict_len,
input=output,
label=target,
param_attr=paddle.attr.Param(name='crfw'))
topo = [crf_cost, crf_dec]
parameters = paddle.parameters.create(topo)
crf_cost, crf_dec = db_lstm(word_dict_len, label_dict_len, pred_len)
#parameters = paddle.parameters.create([crf_cost, crf_dec])
parameters = paddle.parameters.create(crf_cost)
optimizer = paddle.optimizer.Momentum(momentum=0.01, learning_rate=2e-2)
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
para = parameters.get('___fc_2__.w0')
print "Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id,
event.cost, para.mean())
event.cost)
else:
pass
......@@ -96,23 +87,27 @@ def main():
reader_dict = {
'word_data': 0,
'verb_data': 1,
'ctx_n2_data': 2,
'ctx_n1_data': 3,
'ctx_0_data': 4,
'ctx_p1_data': 5,
'ctx_p2_data': 6,
'ctx_n2_data': 1,
'ctx_n1_data': 2,
'ctx_0_data': 3,
'ctx_p1_data': 4,
'ctx_p2_data': 5,
'verb_data': 6,
'mark_data': 7,
'target': 8
'target': 8,
}
#trn_reader = paddle.reader.batched(
# paddle.reader.shuffle(
# train_reader(), buf_size=8192), batch_size=2)
trn_reader = paddle.reader.batched(train_reader(), batch_size=1)
trainer.train(
train_data_reader=train_reader,
batch_size=32,
topology=topo,
reader=trn_reader,
cost=crf_cost,
parameters=parameters,
event_handler=event_handler,
num_passes=10000,
reader_dict=reader_dict)
#cost=[crf_cost, crf_dec],
if __name__ == '__main__':
......
......@@ -23,23 +23,25 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
ctx_p2 = paddle.layer.data(name='ctx_p2_data', type=d_type(word_dict_len))
mark = paddle.layer.data(name='mark_data', type=d_type(mark_dict_len))
target = paddle.layer.data(name='target', type=d_type(label_dict_len))
default_std = 1 / math.sqrt(hidden_dim) / 3.0
emb_para = paddle.attr.Param(name='emb', initial_std=0., learning_rate=0.)
std_0 = paddle.attr.Param(initial_std=0.)
std_default = paddle.attr.Param(initial_std=default_std)
predicate_embedding = paddle.layer.embeding(
predicate_embedding = paddle.layer.embedding(
size=word_dim,
input=predicate,
param_attr=paddle.attr.Param(
name='vemb', initial_std=default_std))
mark_embedding = paddle.layer.embeding(
mark_embedding = paddle.layer.embedding(
size=mark_dim, input=mark, param_attr=std_0)
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
emb_layers = [
paddle.layer.embeding(
paddle.layer.embedding(
size=word_dim, input=x, param_attr=emb_para) for x in word_input
]
emb_layers.append(predicate_embedding)
......@@ -101,4 +103,19 @@ def db_lstm(word_dict_len, label_dict_len, pred_len):
input=input_tmp[1], param_attr=lstm_para_attr)
], )
return feature_out
crf_cost = paddle.layer.crf(size=label_dict_len,
input=feature_out,
label=target,
param_attr=paddle.attr.Param(
name='crfw',
initial_std=default_std,
learning_rate=mix_hidden_lr))
crf_dec = paddle.layer.crf_decoding(
name='crf_dec_l',
size=label_dict_len,
input=feature_out,
label=target,
param_attr=paddle.attr.Param(name='crfw'))
return crf_cost, crf_dec
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册