reader.py 4.0 KB
Newer Older
S
Superjom 已提交
1 2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
S
Superjom 已提交
3
from utils import UNK, ModelType, TaskType, load_dic, sent2ids, logger, ModelType
S
Superjom 已提交
4 5 6


class Dataset(object):
    '''
    Line-oriented reader for tab-separated text-matching data.

    Every non-blank line of a data file is parsed into one record whose
    layout depends on the model type (classification, regression or rank).
    Words are mapped to ids with `sent2ids` using the source/target
    dictionaries loaded by `load_dic`.
    '''

    # Human-readable layouts, used in format error messages.
    _PAIR_LAYOUT = "<source words> [TAB] <target words> [TAB] <label>"
    _RANK_LAYOUT = ("<source words> [TAB] <left_target words> [TAB] "
                    "<right_target words> [TAB] <label>")

    def __init__(self, train_path, test_path, source_dic_path, target_dic_path,
                 model_type):
        '''
        @train_path: str
            path of the training file; `infer` also reads from this path.
        @test_path: str
            path of the test file.
        @source_dic_path: str
            source-side dictionary file, loaded with `load_dic`.
        @target_dic_path: str
            target-side dictionary file, loaded with `load_dic`.
        @model_type: ModelType, or any value `ModelType(...)` accepts
            selects which record parser is used.

        Raises ValueError if the model type's mode has no record parser.
        '''
        self.train_path = train_path
        self.test_path = test_path
        self.source_dic_path = source_dic_path
        self.target_dic_path = target_dic_path
        # Normalize into a ModelType. Dispatch below uses this normalized
        # object, not the raw argument, so the normalization is not wasted.
        self.model_type = ModelType(model_type)

        self.source_dic = load_dic(self.source_dic_path)
        self.target_dic = load_dic(self.target_dic_path)

        _record_reader = {
            ModelType.CLASSIFICATION_MODE: self._read_classification_record,
            ModelType.REGRESSION_MODE: self._read_regression_record,
            ModelType.RANK_MODE: self._read_rank_record,
        }
        mode = self.model_type.mode
        if mode not in _record_reader:
            raise ValueError("unsupported model type mode: %r" % (mode, ))
        self.record_reader = _record_reader[mode]

        # When True, the record parsers omit the label from the records.
        self.is_infer = False

    def train(self):
        '''
        Load trainset: yield one labelled record per non-blank line of
        `train_path`.
        '''
        # Reset explicitly: a previous `infer()` leaves is_infer=True, which
        # would otherwise make the parsers silently drop labels here.
        self.is_infer = False
        logger.info("[reader] load trainset from %s" % self.train_path)
        for record in self._read_records(self.train_path):
            yield record

    def test(self):
        '''
        Load testset: yield one labelled record per non-blank line of
        `test_path`.
        '''
        # Same reset as in `train()`: labels must not depend on call order.
        self.is_infer = False
        for record in self._read_records(self.test_path):
            yield record

    def infer(self):
        '''
        Load the inference set: yield label-free records parsed from
        `train_path`. This reader has no separate inference path.
        '''
        self.is_infer = True
        for record in self._read_records(self.train_path):
            yield record

    def _read_records(self, path):
        '''
        Open `path` and yield `self.record_reader(line)` for each line.

        Blank or whitespace-only lines (e.g. a trailing newline at EOF) are
        skipped. Otherwise they would fail the field-count check.
        '''
        with open(path) as f:
            for line in f:
                if not line.strip():
                    continue
                yield self.record_reader(line)

    def _split_fields(self, line, num_fields, task, layout):
        '''
        Strip `line`, split it on tabs and check the field count.

        @line: str
            a string line which represents a record.
        @num_fields: int
            number of tab-separated fields the line must contain.
        @task: str
            task name used in the error message.
        @layout: str
            description of the expected layout, used in the error message.

        Returns the list of fields. Raises ValueError on a wrong field count.
        '''
        fields = line.strip().split('\t')
        if len(fields) != num_fields:
            # Raise instead of assert so the check survives `python -O`.
            raise ValueError("wrong format for %s\nthe format should be %s\n"
                             "got: %r" % (task, layout, line))
        return fields

    def _read_classification_record(self, line):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>

        @line: str
            a string line which represents a record.

        Returns (source_ids, target_ids, int_label), or
        (source_ids, target_ids) when is_infer is set.
        '''
        fs = self._split_fields(line, 3, "classification", self._PAIR_LAYOUT)
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if self.is_infer:
            return source, target
        return (source, target, int(fs[2]), )

    def _read_regression_record(self, line):
        '''
        data format:
            <source words> [TAB] <target words> [TAB] <label>

        @line: str
            a string line which represents a record.

        Returns (source_ids, target_ids, [float_label]), or
        (source_ids, target_ids) when is_infer is set.
        '''
        fs = self._split_fields(line, 3, "regression", self._PAIR_LAYOUT)
        source = sent2ids(fs[0], self.source_dic)
        target = sent2ids(fs[1], self.target_dic)
        if self.is_infer:
            return source, target
        # The label is wrapped in a one-element list, as in the original
        # reader's output.
        return (source, target, [float(fs[2])], )

    def _read_rank_record(self, line):
        '''
        data format:
            <source words> [TAB] <left_target words> [TAB] <right_target words> [TAB] <label>

        @line: str
            a string line which represents a record.

        Returns (source_ids, left_ids, right_ids, int_label), or
        (source_ids, left_ids, right_ids) when is_infer is set.
        '''
        fs = self._split_fields(line, 4, "rank", self._RANK_LAYOUT)
        source = sent2ids(fs[0], self.source_dic)
        left_target = sent2ids(fs[1], self.target_dic)
        right_target = sent2ids(fs[2], self.target_dic)
        if self.is_infer:
            return source, left_target, right_target
        return (source, left_target, right_target, int(fs[3]))
S
Superjom 已提交
107 108 109 110 111 112 113


if __name__ == '__main__':
    # Smoke run: parse the classification trainset and print every record.
    path = './data/classification/train.txt'
    test_path = './data/classification/test.txt'
    source_dic = './data/vocab.txt'
    # Build a ModelType from CLASSIFICATION_MODE, the constant the reader
    # dispatches on. A bare constant is not a ModelType instance.
    dataset = Dataset(path, test_path, source_dic, source_dic,
                      ModelType(ModelType.CLASSIFICATION_MODE))

    for rcd in dataset.train():
        # print(...) of a single expression prints identically on Python 2
        # and also runs on Python 3.
        print(rcd)