reader.py 1.2 KB
Newer Older
C
caoying03 已提交
1 2 3 4 5
#!/usr/bin/env python
#coding=utf-8
import os
import random
import json
C
caoying03 已提交
6 7 8 9
import logging

# Route this module's messages through the shared "paddle" logger; records
# below INFO are dropped by this logger.
logger = logging.getLogger(name="paddle")
logger.setLevel(level=logging.INFO)
C
caoying03 已提交
10 11


C
caoying03 已提交
12
def data_reader(data_list, is_train=True):
    """Build a reader over featurized JSON samples, one sample per file.

    Args:
        data_list: Paths of JSON files. Each file must contain the keys
            ``question``, ``context``, ``sent_lengths``,
            ``same_as_question_word``, ``ans_sentence``, ``ans_start`` and
            ``ans_end``.
        is_train: When True, the sample order is reshuffled on every pass.
            The caller's ``data_list`` is never reordered.

    Returns:
        A zero-argument generator function. Each call starts a new pass and
        yields, per sample file, the tuple
        ``(question, doc, same_as_question_word, ans_sentence, ans_start,
        ans_end - ans_start)``. In that tuple:

        * ``doc`` is ``context`` cut into consecutive slices whose lengths
          are given by ``sent_lengths``.
        * ``same_as_question_word`` is cut the same way, with every value
          ``x`` wrapped as ``[[x]]``.
    """

    def reader():
        samples = list(data_list)
        # Reshuffle on every pass. Shuffle a private copy so the caller's
        # list is not reordered as a side effect.
        if is_train:
            random.shuffle(samples)

        for sample_path in samples:
            # Close each file promptly instead of leaking one handle per sample.
            with open(sample_path, "r") as f:
                data = json.load(f)

            context = data["context"]
            # Wrap every flag once per sample, not once per sentence, so the
            # split stays linear in document length.
            wrapped_flags = [[[x]] for x in data["same_as_question_word"]]

            doc = []
            same_as_question_word = []
            start_pos = 0
            for sent_len in data["sent_lengths"]:
                end_pos = start_pos + sent_len
                doc.append(context[start_pos:end_pos])
                same_as_question_word.append(wrapped_flags[start_pos:end_pos])
                start_pos = end_pos

            yield (data["question"], doc, same_as_question_word,
                   data["ans_sentence"], data["ans_start"],
                   data["ans_end"] - data["ans_start"])

    return reader


if __name__ == "__main__":
    # Smoke test: print the first seven featurized training samples.
    from train import choose_samples

    train_list, _ = choose_samples("data/featurized")
    for count, sample in enumerate(data_reader(train_list)(), start=1):
        print(sample)
        if count >= 7:
            break