# Positions here match the order of the tuple yielded by the readers below:
# (dnn_input, lr_input, click).
feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}


class Dataset(object):
    '''
    Builds readers over tab-separated data files.

    Each public method logs the path and returns a zero-argument callable;
    calling that callable opens the file and yields one record per line.
    The path and task mode are captured by the returned callable instead of
    being stored on the instance, so several readers created from the same
    Dataset do not overwrite each other's settings.
    '''

    def train(self, path):
        '''
        Load trainset.

        Returns a reader yielding (dnn_input, lr_input, click) records.
        '''
        logger.info("load trainset from %s", path)
        mode = TaskMode.create_train()
        return self._parse_creator(path, mode)

    def test(self, path):
        '''
        Load testset.

        Returns a reader yielding (dnn_input, lr_input, click) records.
        '''
        logger.info("load testset from %s", path)
        mode = TaskMode.create_test()
        return self._parse_creator(path, mode)

    def infer(self, path):
        '''
        Load infer set.

        Returns a reader yielding (dnn_input, lr_input) records; the click
        column is not read in infer mode.
        '''
        logger.info("load inferset from %s", path)
        mode = TaskMode.create_infer()
        return self._parse_creator(path, mode)

    def _parse_creator(self, path, mode):
        '''
        Parse dataset.

        path: file to read; one record per line, fields separated by tabs
            (field 0: dnn input, field 1: lr input, field 2: click label).
        mode: task mode object; only its is_infer() method is used here.

        Returns a zero-argument generator function. The file is opened each
        time that function is called, not when the creator is built.
        '''

        def _parse():
            with open(path) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank line (e.g. a trailing newline at the end
                        # of the file) has no fields; skip it instead of
                        # failing with IndexError on fs[1].
                        continue
                    fs = line.split('\t')
                    dnn_input = load_dnn_input_record(fs[0])
                    lr_input = load_lr_input_record(fs[1])
                    if not mode.is_infer():
                        # The label is wrapped in a one-element list.
                        click = [int(fs[2])]
                        yield dnn_input, lr_input, click
                    else:
                        yield dnn_input, lr_input

        return _parse


# NOTE(review): `def load_data_meta(path):` follows here in the full file;
# only its header is visible in this excerpt, so it is not reproduced.