From 542eb736ab66ca5f7f974fde8d6a91bbfa781f4b Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Wed, 1 Mar 2017 15:47:07 +0800 Subject: [PATCH] update --- demo/semantic_role_labeling/api_train_v2.py | 37 +++++++++++---------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/demo/semantic_role_labeling/api_train_v2.py b/demo/semantic_role_labeling/api_train_v2.py index 0317c818db..cfbd2a0224 100644 --- a/demo/semantic_role_labeling/api_train_v2.py +++ b/demo/semantic_role_labeling/api_train_v2.py @@ -1,4 +1,4 @@ -import numpy +import numpy as np import paddle.v2 as paddle from model_v2 import db_lstm @@ -31,10 +31,6 @@ word_dict_len = len(word_dict) label_dict_len = len(label_dict) pred_len = len(predicate_dict) -print 'word_dict_len=%d' % word_dict_len -print 'label_dict_len=%d' % label_dict_len -print 'pred_len=%d' % pred_len - def train_reader(file_name="data/feature"): def reader(): @@ -65,25 +61,34 @@ def train_reader(file_name="data/feature"): return reader +def load_parameter(file_name, h, w): + with open(file_name, 'rb') as f: + f.read(16) # skip header for float type. + return np.fromfile(f, dtype=np.float32).reshape(h, w) + + def main(): paddle.init(use_gpu=False, trainer_count=1) # define network topology crf_cost, crf_dec = db_lstm(word_dict_len, label_dict_len, pred_len) - #parameters = paddle.parameters.create([crf_cost, crf_dec]) - parameters = paddle.parameters.create(crf_cost) + parameters = paddle.parameters.create([crf_cost, crf_dec]) optimizer = paddle.optimizer.Momentum(momentum=0.01, learning_rate=2e-2) def event_handler(event): if isinstance(event, paddle.event.EndIteration): - print "Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id, - event.cost) - + if event.batch_id % 100 == 0: + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) else: pass - trainer = paddle.trainer.SGD(update_equation=optimizer) + trainer = paddle.trainer.SGD(cost=crf_cost, + parameters=parameters, + update_equation=optimizer) + + parameters.set('emb', load_parameter("data/emb", 44068, 32)) reader_dict = { 'word_data': 0, @@ -96,18 +101,14 @@ def main(): 'mark_data': 7, 'target': 8, } - #trn_reader = paddle.reader.batched( - # paddle.reader.shuffle( - # train_reader(), buf_size=8192), batch_size=2) - trn_reader = paddle.reader.batched(train_reader(), batch_size=1) + trn_reader = paddle.reader.batched( + paddle.reader.shuffle( + train_reader(), buf_size=8192), batch_size=10) trainer.train( reader=trn_reader, - cost=crf_cost, - parameters=parameters, event_handler=event_handler, num_passes=10000, reader_dict=reader_dict) - #cost=[crf_cost, crf_dec], if __name__ == '__main__': -- GitLab