diff --git a/PaddleRec/gru4rec/net.py b/PaddleRec/gru4rec/net.py index 7e9f65bb4e589e02620248b17fd9851f216817a4..f049643402c38c7f5501c3e7965fca262310bb6f 100644 --- a/PaddleRec/gru4rec/net.py +++ b/PaddleRec/gru4rec/net.py @@ -69,6 +69,8 @@ def train_bpr_network(vocab_size, neg_size, hid_size, drop_out=0.2): name="emb", initializer=fluid.initializer.XavierInitializer(), learning_rate=emb_lr_x)) + emb_src = fluid.layers.squeeze(input=emb_src, axes=[1]) + emb_src_drop = fluid.layers.dropout(emb_src, dropout_prob=drop_out) @@ -134,6 +136,7 @@ def train_cross_entropy_network(vocab_size, neg_size, hid_size, drop_out=0.2): name="emb", initializer=fluid.initializer.XavierInitializer(), learning_rate=emb_lr_x)) + emb_src = fluid.layers.squeeze(input=emb_src, axes=[1]) emb_src_drop = fluid.layers.dropout(emb_src, dropout_prob=drop_out)