Unverified · Commit 6290b4dd authored by guru4elephant, committed by GitHub

Merge pull request #1451 from jacquesqiao/fix-ctr

fix ctr reader
@@ -15,8 +15,12 @@ def ctr_dnn_model(embedding_size, sparse_feature_dim):
     def embedding_layer(input):
         return fluid.layers.embedding(
             input=input,
+            is_sparse=True,
+            # you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
+            # if you want to set is_distributed to True
+            is_distributed=False,
             size=[sparse_feature_dim, embedding_size],
-            param_attr=fluid.ParamAttr(name="SparseFeatFactors", initializer=fluid.initializer.Normal(scale=1/math.sqrt(sparse_feature_dim))))
+            param_attr=fluid.ParamAttr(name="SparseFeatFactors", initializer=fluid.initializer.Uniform()))
     sparse_embed_seq = map(embedding_layer, sparse_input_ids)
     concated = fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)
...
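For context, here is a minimal sketch of how the patched embedding_layer helper sits inside ctr_dnn_model. Only the embedding call itself comes from the hunk above; the 13 dense / 26 sparse Criteo-style input definitions and the list() wrapper are assumptions about the surrounding, elided code, not part of this change.

```python
import paddle.fluid as fluid

def ctr_dnn_model(embedding_size, sparse_feature_dim):
    # Assumed inputs: 13 continuous features and 26 categorical slots,
    # as in the standard Criteo CTR example (not shown in the diff).
    dense_input = fluid.layers.data(
        name="dense_input", shape=[13], dtype="float32")
    sparse_input_ids = [
        fluid.layers.data(
            name="C" + str(i), shape=[1], lod_level=1, dtype="int64")
        for i in range(1, 27)
    ]

    def embedding_layer(input):
        return fluid.layers.embedding(
            input=input,
            is_sparse=True,
            # you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
            # if you want to set is_distributed to True
            is_distributed=False,
            size=[sparse_feature_dim, embedding_size],
            param_attr=fluid.ParamAttr(
                name="SparseFeatFactors",
                initializer=fluid.initializer.Uniform()))

    # list() added so the sketch also runs under Python 3; the original
    # code relies on Python 2's map() returning a list.
    sparse_embed_seq = list(map(embedding_layer, sparse_input_ids))
    return fluid.layers.concat(sparse_embed_seq + [dense_input], axis=1)
```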
@@ -21,10 +21,10 @@ class CriteoDataset(Dataset):
             for line in f:
                 line_idx += 1
                 if is_train and line_idx > self.train_idx_:
-                    continue
+                    break
                 elif not is_train and line_idx <= self.train_idx_:
                     continue
-                if trainer_id > 0 and line_idx % trainer_num != trainer_id:
+                if line_idx % trainer_num != trainer_id:
                     continue
                 features = line.rstrip('\n').split('\t')
                 dense_feature = []
...
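The reader change above shards input lines across trainers with a plain modulo test. Under the old guard (trainer_id > 0 and ...), trainer 0 never skipped anything and therefore read every line. Below is a standalone sketch of the fixed sharding; the helper name shard_lines and the toy data are invented for illustration, and only the modulo test mirrors the diff.

```python
def shard_lines(lines, trainer_id, trainer_num):
    """Yield only the lines this trainer owns, mirroring the fixed reader:
    a line is kept when line_idx % trainer_num == trainer_id, with line_idx
    1-based to match the `line_idx += 1` before the check."""
    line_idx = 0
    for line in lines:
        line_idx += 1
        if line_idx % trainer_num != trainer_id:
            continue
        yield line

lines = ["l%d" % i for i in range(6)]   # line_idx runs 1..6
print(list(shard_lines(lines, 0, 2)))   # ['l1', 'l3', 'l5']  (even line_idx)
print(list(shard_lines(lines, 1, 2)))   # ['l0', 'l2', 'l4']  (odd line_idx)
```

Together the shards cover every line exactly once, with no line read by two trainers.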
@@ -138,7 +138,7 @@ def train():
     if args.is_local:
         logger.info("run local training")
         main_program = fluid.default_main_program()
-        train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var, 1, -1)
+        train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var, 1, 0)
     else:
         logger.info("run dist training")
         t = fluid.DistributeTranspiler()
@@ -154,7 +154,7 @@ def train():
         logger.info("run trainer")
         train_prog = t.get_trainer_program()
         train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var,
-                   args.trainers, args.trainer_id + 1)
+                   args.trainers, args.trainer_id)

 if __name__ == '__main__':
...
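The two train_loop changes make the trainer id that reaches the reader 0-based: local training now passes 0 instead of -1, and distributed training passes args.trainer_id directly instead of args.trainer_id + 1. A quick check of both conventions against the modulo test from the reader hunk (plain Python, variable names invented for illustration):

```python
# Local training after the fix: trainers = 1, trainer_id = 0.
# line_idx % 1 is always 0, so the skip condition never fires and the
# single trainer reads the whole file.
trainer_num, trainer_id = 1, 0
skipped = [idx for idx in range(1, 11) if idx % trainer_num != trainer_id]
assert skipped == []

# With the old 1-based convention (args.trainer_id + 1), ids ran 1..N,
# but line_idx % N only takes values 0..N-1, so the trainer whose id
# equalled N could never match its shard.
trainer_num = 4
assert all(idx % trainer_num != trainer_num for idx in range(1, 1000))
```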