提交 81e9797a 编写于 作者: D dongdaxiang

fix line idx bug, inherit Dataset

上级 ab9f01b1
class Dataset:
def _reader_creator(self, file_list, is_infer):
def reader():
for file in file_list:
with open(file, 'r') as f:
for line in f:
features = line.rstrip('\n').split('\t')
dense_feature = map(float, features[0].split(','))
sparse_feature = map(lambda x: [int(x)], features[1].split(','))
if not is_infer:
label = [float(features[2])]
yield [dense_feature
] + sparse_feature + [label]
else:
yield [dense_feature] + sparse_feature
return reader
def train(self, file_list):
return self._reader_creator(file_list, False)
def test(self, file_list):
return self._reader_creator(file_list, False)
def infer(self, file_list):
return self._reader_creator(file_list, True)
def __init__(self):
pass
class CriteoDataset:
class CriteoDataset(Dataset):
def __init__(self, sparse_feature_dim):
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.hash_dim_ = sparse_feature_dim
# here, training data are lines with line_index < train_idx_
self.train_idx_ = 41256555
self.continuous_range_ = range(1, 14)
self.categorical_range_ = range(14, 40)
def _reader_creator(self, file_list, is_train, trainer_id):
def _reader_creator(self, file_list, is_train, trainer_num, trainer_id):
def reader():
for file in file_list:
with open(file, 'r') as f:
......@@ -46,7 +24,7 @@ class CriteoDataset:
continue
elif not is_train and line_idx <= self.train_idx_:
continue
if trainer_id > 0 and line_idx % trainer_id != 0:
if trainer_id > 0 and line_idx % trainer_num != trainer_id:
continue
features = line.rstrip('\n').split('\t')
dense_feature = []
......@@ -64,8 +42,8 @@ class CriteoDataset:
return reader
def train(self, file_list, trainer_id):
return self._reader_creator(file_list, True, trainer_id)
def train(self, file_list, trainer_num, trainer_id):
return self._reader_creator(file_list, True, trainer_num, trainer_id)
def test(self, file_list):
return self._reader_creator(file_list, False, -1)
......
......@@ -92,11 +92,12 @@ def parse_args():
return parser.parse_args()
def train_loop(args, train_program, data_list, loss, auc_var, batch_auc_var, trainer_id):
def train_loop(args, train_program, data_list, loss, auc_var, batch_auc_var,
trainer_num, trainer_id):
dataset = reader.CriteoDataset(args.sparse_feature_dim)
train_reader = paddle.batch(
paddle.reader.shuffle(
dataset.train([args.train_data_path], trainer_id),
dataset.train([args.train_data_path], trainer_num, trainer_id),
buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
......@@ -137,7 +138,7 @@ def train():
if args.is_local:
logger.info("run local training")
main_program = fluid.default_main_program()
train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var, -1)
train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var, 1, -1)
else:
logger.info("run dist training")
t = fluid.DistributeTranspiler()
......@@ -153,7 +154,7 @@ def train():
logger.info("run trainer")
train_prog = t.get_trainer_program()
train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var,
args.trainer_id + 1)
args.trainers, args.trainer_id + 1)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册