提交 09ff2060 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #313 from ranqiu92/ctr

fix bugs of data reader in CTR demo.
......@@ -4,50 +4,48 @@ feeding_index = {'dnn_input': 0, 'lr_input': 1, 'click': 2}
class Dataset(object):
    '''
    Build data readers for the train, test and infer sets.

    Every public method returns a fresh, zero-argument reader (a generator
    function). The file path and task mode are captured in the reader's
    closure instead of being stored on the instance, so creating a second
    reader (e.g. calling test() after train()) can never change which file
    or mode an earlier reader uses.
    '''

    def train(self, path):
        '''
        Load trainset.

        Returns a reader yielding (dnn_input, lr_input, click) per record.
        '''
        logger.info("load trainset from %s", path)
        mode = TaskMode.create_train()
        return self._parse_creator(path, mode)

    def test(self, path):
        '''
        Load testset.

        Returns a reader yielding (dnn_input, lr_input, click) per record.
        '''
        logger.info("load testset from %s", path)
        mode = TaskMode.create_test()
        return self._parse_creator(path, mode)

    def infer(self, path):
        '''
        Load infer set.

        Returns a reader yielding (dnn_input, lr_input) per record; no click
        label is read in infer mode.
        '''
        logger.info("load inferset from %s", path)
        mode = TaskMode.create_infer()
        return self._parse_creator(path, mode)

    def _parse_creator(self, path, mode):
        '''
        Parse dataset.

        Creates a reader bound to `path` and `mode`. Each non-blank line of
        the file is split on tabs: field 0 goes to load_dnn_input_record,
        field 1 to load_lr_input_record, and, unless `mode.is_infer()`,
        field 2 is parsed as the integer click label.
        '''

        def _parse():
            # The file is opened on every call, so the reader can be
            # iterated repeatedly (e.g. once per training pass).
            with open(path) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (typically a trailing newline
                        # at end of file) instead of raising IndexError.
                        continue
                    fs = line.split('\t')
                    dnn_input = load_dnn_input_record(fs[0])
                    lr_input = load_lr_input_record(fs[1])
                    if not mode.is_infer():
                        click = [int(fs[2])]
                        yield dnn_input, lr_input, click
                    else:
                        yield dnn_input, lr_input

        return _parse
def load_data_meta(path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册