# reader.py
from utils import logger, TaskMode, load_dnn_input_record, load_lr_input_record

# Position of each named input slot. The indices match the order in which the
# readers built by Dataset yield their fields: (dnn_input, lr_input, click).
# NOTE(review): presumably handed to the trainer as its feeding map -- confirm
# at the call site.
feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}


class Dataset(object):
    '''
    Builds record readers for the train, test and infer datasets.

    Each public method logs the path, picks the matching TaskMode and returns
    a zero-argument callable. Calling that callable opens the file and yields
    one parsed record per non-blank line.
    '''

    def train(self, path):
        '''
        Load trainset.

        Returns a reader creator that yields (dnn_input, lr_input, click).
        '''
        logger.info("load trainset from %s" % path)
        mode = TaskMode.create_train()
        return self._parse_creator(path, mode)

    def test(self, path):
        '''
        Load testset.

        Returns a reader creator that yields (dnn_input, lr_input, click).
        '''
        logger.info("load testset from %s" % path)
        mode = TaskMode.create_test()
        return self._parse_creator(path, mode)

    def infer(self, path):
        '''
        Load infer set.

        Returns a reader creator that yields (dnn_input, lr_input); no click
        label is read in infer mode.
        '''
        logger.info("load inferset from %s" % path)
        mode = TaskMode.create_infer()
        return self._parse_creator(path, mode)

    def _parse_creator(self, path, mode):
        '''
        Parse dataset.

        Builds a generator function over the tab-separated file at `path`.
        Column 0 is decoded with load_dnn_input_record and column 1 with
        load_lr_input_record. Unless `mode.is_infer()` is true, column 2 is
        converted with int() and yielded as a one-element list.

        The file is opened only when the returned callable is invoked, so the
        reader can be restarted by calling it again.
        '''

        def _parse():
            with open(path) as f:
                for line in f:
                    record = line.strip()
                    if not record:
                        # A blank line (typically a trailing newline at EOF)
                        # carries no fields and would fail on fs[1] below.
                        continue
                    fs = record.split('\t')
                    dnn_input = load_dnn_input_record(fs[0])
                    lr_input = load_lr_input_record(fs[1])
                    if not mode.is_infer():
                        click = [int(fs[2])]
                        yield dnn_input, lr_input, click
                    else:
                        yield dnn_input, lr_input

        return _parse
S
Superjom 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64


def load_data_meta(path):
    '''
    Load data meta info from path.

    The file must contain exactly two non-blank lines, in this order:
    one containing 'dnn_input_dim:' and one containing 'lr_input_dim:',
    each followed by an integer after the colon.

    Returns a list [dnn_input_dim, lr_input_dim].

    Raises AssertionError("wrong meta format") when the layout does not
    match, and ValueError when a dimension is not an integer.
    '''
    err_info = "wrong meta format"
    with open(path) as f:
        # Ignore blank lines so a trailing newline at EOF is not counted as a
        # third (empty) record.
        lines = [line for line in f.read().splitlines() if line.strip()]

    well_formed = (len(lines) == 2 and 'dnn_input_dim:' in lines[0] and
                   'lr_input_dim:' in lines[1])
    if not well_formed:
        # Raised explicitly (instead of via `assert`) so the check still runs
        # under `python -O`, while keeping the same exception type.
        raise AssertionError(err_info)

    # Build a real list: on Python 3 `map` is lazy and cannot be indexed.
    res = [int(line.split(':')[1]) for line in lines]
    logger.info('dnn input dim: %d' % res[0])
    logger.info('lr input dim: %d' % res[1])
    return res