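"""Train the db_lstm sequence labeling model (CRF cost and decoding) with the paddle.v2 API."""
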
import numpy as np
import paddle.v2 as paddle
from model_v2 import db_lstm

UNK_IDX = 0

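# Load the word, label and predicate dictionaries: one token per line, with the
# line index used as the token's integer id.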
word_dict_file = './data/wordDict.txt'
label_dict_file = './data/targetDict.txt'
predicate_file = './data/verbDict.txt'

word_dict = dict()
label_dict = dict()
predicate_dict = dict()

with open(word_dict_file, 'r') as f_word, \
     open(label_dict_file, 'r') as f_label, \
     open(predicate_file, 'r') as f_pre:
    for i, line in enumerate(f_word):
        w = line.strip()
        word_dict[w] = i

    for i, line in enumerate(f_label):
        w = line.strip()
        label_dict[w] = i

    for i, line in enumerate(f_pre):
        w = line.strip()
        predicate_dict[w] = i

word_dict_len = len(word_dict)
label_dict_len = len(label_dict)
pred_len = len(predicate_dict)


def train_reader(file_name="data/feature"):
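    """Create a reader over the feature file, yielding one sample per line.

    Each line holds nine tab-separated fields: the sentence, the predicate, five
    context words (ctx_n2 .. ctx_p2), the mark sequence and the label sequence.
    Per-predicate features are repeated to the length of the sentence.
    """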
    def reader():
        with open(file_name, 'r') as fdata:
            for line in fdata:
                sentence, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, mark, label = \
                    line.strip().split('\t')

                words = sentence.split()
                sen_len = len(words)
                word_slot = [word_dict.get(w, UNK_IDX) for w in words]

                predicate_slot = [predicate_dict.get(predicate)] * sen_len
                ctx_n2_slot = [word_dict.get(ctx_n2, UNK_IDX)] * sen_len
                ctx_n1_slot = [word_dict.get(ctx_n1, UNK_IDX)] * sen_len
                ctx_0_slot = [word_dict.get(ctx_0, UNK_IDX)] * sen_len
                ctx_p1_slot = [word_dict.get(ctx_p1, UNK_IDX)] * sen_len
                ctx_p2_slot = [word_dict.get(ctx_p2, UNK_IDX)] * sen_len

                marks = mark.split()
                mark_slot = [int(w) for w in marks]

                label_list = label.split()
                label_slot = [label_dict.get(w) for w in label_list]
                yield word_slot, ctx_n2_slot, ctx_n1_slot, \
                  ctx_0_slot, ctx_p1_slot, ctx_p2_slot, predicate_slot, mark_slot, label_slot

    return reader


def load_parameter(file_name, h, w):
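    """Load a pretrained parameter matrix of shape (h, w) saved by PaddlePaddle."""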
    with open(file_name, 'rb') as f:
        f.read(16)  # skip the 16-byte header of the parameter file.
        return np.fromfile(f, dtype=np.float32).reshape(h, w)


def main():
    paddle.init(use_gpu=False, trainer_count=1)

    # define network topology
    crf_cost, crf_dec = db_lstm(word_dict_len, label_dict_len, pred_len)

    parameters = paddle.parameters.create([crf_cost, crf_dec])
    optimizer = paddle.optimizer.Momentum(momentum=0.01, learning_rate=2e-2)

    def event_handler(event):
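        """Print the training cost and metrics every 100 batches."""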
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 100 == 0:
                print "Pass %d, Batch %d, Cost %f, %s" % (
                    event.pass_id, event.batch_id, event.cost, event.metrics)
        else:
            pass

    trainer = paddle.trainer.SGD(cost=crf_cost,
                                 parameters=parameters,
                                 update_equation=optimizer)

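    # Initialize the 'emb' layer with a pretrained 44068 x 32 word embedding matrix.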
    parameters.set('emb', load_parameter("data/emb", 44068, 32))

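    # Map each data layer name in the network to the index of the corresponding
    # field in the tuple yielded by train_reader().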
    reader_dict = {
        'word_data': 0,
        'ctx_n2_data': 1,
        'ctx_n1_data': 2,
        'ctx_0_data': 3,
        'ctx_p1_data': 4,
        'ctx_p2_data': 5,
        'verb_data': 6,
        'mark_data': 7,
        'target': 8,
    }
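    # Shuffle the training data with a buffer of 8192 samples, then batch it
    # into mini-batches of 10.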
    trn_reader = paddle.reader.batched(
        paddle.reader.shuffle(
            train_reader(), buf_size=8192), batch_size=10)
    trainer.train(
        reader=trn_reader,
        event_handler=event_handler,
        num_passes=10000,
        reader_dict=reader_dict)


if __name__ == '__main__':
    main()