reader.py 2.9 KB
Newer Older
S
Superjom 已提交
1 2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
S
Superjom 已提交
3
from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType
S
Superjom 已提交
4 5 6 7 8 9 10 11


class Dataset(object):
    '''
    Line-oriented reader for a train file and a test file.

    Every line of a data file is one tab-separated record. Depending on
    `model_type` a line is parsed either as a classification record or as a
    rank record; see `_read_classification_record` and `_read_rank_record`
    for the two layouts.
    '''

    def __init__(self,
                 train_path,
                 test_path,
                 source_dic_path,
                 target_dic_path,
                 model_type=ModelType.RANK):
        '''
        @train_path: str
            path of the file read by `train()`.
        @test_path: str
            path of the file read by `test()`.
        @source_dic_path: str
            path handed to `load_dic` to build `self.source_dic`.
        @target_dic_path: str
            path handed to `load_dic` to build `self.target_dic`.
        @model_type: ModelType
            ModelType.CLASSIFICATION selects the classification record
            parser; any other value selects the rank record parser.
        '''
        self.train_path = train_path
        self.test_path = test_path
        self.source_dic_path = source_dic_path
        self.target_dic_path = target_dic_path
        self.model_type = model_type

        self.source_dic = load_dic(self.source_dic_path)
        self.target_dic = load_dic(self.target_dic_path)

        # Pick the per-line parser once so the read loop stays branch-free.
        if self.model_type == ModelType.CLASSIFICATION:
            self.record_reader = self._read_classification_record
        else:
            self.record_reader = self._read_rank_record

    def train(self):
        '''
        Lazily yield one parsed record per line of `self.train_path`.
        '''
        logger.info("[reader] load trainset from %s", self.train_path)
        for record in self._iter_records(self.train_path):
            yield record

    def test(self):
        '''
        Lazily yield one parsed record per line of `self.test_path`.
        '''
        logger.info("[reader] load testset from %s", self.test_path)
        for record in self._iter_records(self.test_path):
            yield record

    def _iter_records(self, path):
        '''
        Shared read loop of `train()` and `test()`.

        @path: str
            file to read; each line is passed to `self.record_reader`.
        '''
        with open(path) as f:
            for line in f:
                yield self.record_reader(line)

    def _read_classification_record(self, line):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>

        @line: str
            a string line which represent a record.

        Returns the tuple (source, target, label): source and target are
        the results of `sent2ids` with the source / target dictionary and
        label is an int. Raises AssertionError when the line does not have
        exactly three tab-separated fields.
        '''
        fs = line.strip().split('\t')
        assert len(fs) == 3, (
            "wrong format for classification\n"
            "the format should be "
            "<source words> [TAB] <target words> [TAB] <label>")
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        label = int(fs[2])
        return (source, target, label)

    def _read_rank_record(self, line):
        '''
        data format:
            <source words> [TAB] <left_target words> [TAB] <right_target words> [TAB] <label>

        @line: str
            a string line which represent a record.

        Returns the tuple (source, left_target, right_target, label): the
        first three are the results of `sent2ids` (source dictionary for
        the first field, target dictionary for both targets) and label is
        an int. Raises AssertionError when the line does not have exactly
        four tab-separated fields.
        '''
        fs = line.strip().split('\t')
        assert len(fs) == 4, (
            "wrong format for rank\n"
            "the format should be "
            "<source words> [TAB] <left_target words> [TAB] "
            "<right_target words> [TAB] <label>")

        source = sent2ids(fs[0], self.source_dic)
        left_target = sent2ids(fs[1], self.target_dic)
        right_target = sent2ids(fs[2], self.target_dic)
        label = int(fs[3])

        return (source, left_target, right_target, label)


if __name__ == '__main__':
    # Manual smoke test: parse the classification training set and print
    # every record it yields.
    path = './data/classification/train.txt'
    test_path = './data/classification/test.txt'
    source_dic = './data/vocab.txt'
    # The same vocabulary file is used as both source and target dictionary.
    dataset = Dataset(path, test_path, source_dic, source_dic,
                      ModelType.CLASSIFICATION)

    for rcd in dataset.train():
        # The parenthesised single-argument form prints the same text under
        # the Python 2 print statement and the Python 3 print function.
        print(rcd)