提交 790ed55d 编写于 作者: Q Qiao Longfei

Fix some problems in the CTR dataset reader (cast command-line arguments to int; yield named feature slots)

上级 3228073f
......@@ -26,6 +26,7 @@ logger.setLevel(logging.INFO)
class CriteoDataset(data_generator.MultiSlotDataGenerator):
def __init__(self, sparse_feature_dim, trainer_id, is_train, trainer_num):
super(CriteoDataset, self).__init__()
self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
self.cont_max_ = [20, 600, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
self.cont_diff_ = [20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50]
......@@ -41,7 +42,6 @@ class CriteoDataset(data_generator.MultiSlotDataGenerator):
def generate_sample(self, line):
def iter():
fs = line.strip().split('\t')
self.line_idx_ += 1
if self.is_train_ and self.line_idx_ > self.train_idx_:
return
......@@ -50,30 +50,29 @@ class CriteoDataset(data_generator.MultiSlotDataGenerator):
if self.line_idx_ % self.trainer_num_ != self.trainer_id_:
return
features = line.rstrip('\n').split('\t')
ret_result = []
dense_feature = []
sparse_feature = []
for idx in self.continuous_range_:
if features[idx] == '':
dense_feature.append(0.0)
else:
dense_feature.append((float(features[idx]) - self.cont_min_[idx - 1]) / self.cont_diff_[idx - 1])
ret_result.append(("dense_feature", dense_feature))
for idx in self.categorical_range_:
sparse_feature.append([hash(str(idx) + features[idx]) % self.hash_dim_])
ret_result.append((str(idx - 13), [hash(str(idx) + features[idx]) % self.hash_dim_]))
ret_result.append(("label", [int(features[0])]))
label = [int(features[0])]
yield [dense_feature] + sparse_feature + [label]
yield ("dnn_data", dnn_input), \
("lr_data", lr_input), \
("click", click)
yield tuple(ret_result)
return iter
if __name__ == "__main__":
    # Script entry point. Expected command line (all values arrive as strings):
    #   <sparse_feature_dim> <trainer_id> <is_train> <trainer_num>
    sparse_feature_dim = int(sys.argv[1])
    trainer_id = int(sys.argv[2])
    # bool() of any non-empty string is True (bool("0") is True), so the flag
    # has to be parsed explicitly; "0", "false" and "" mean "not training".
    is_train = sys.argv[3].strip().lower() not in ("0", "false", "")
    trainer_num = int(sys.argv[4])
    pairwise_reader = CriteoDataset(sparse_feature_dim, trainer_id, is_train, trainer_num)
    # Read raw lines from stdin and emit the generated samples.
    pairwise_reader.run_from_stdin()
......@@ -213,7 +213,7 @@ def train(args):
dataset.set_batch_size(128)
dataset.set_use_var(words)
# The format arguments must be a parenthesized tuple: `%` binds tighter than
# `,`, so without the parentheses only the first value reaches the format
# string and formatting raises TypeError.
# The third field is the reader's is_train flag; it is fixed to 1 here.
pipe_command = 'python ctr_dataset_reader.py %d %d %d %d' \
    % (args.sparse_feature_dim, args.trainer_id, 1, args.trainer_num)
dataset.set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册