From c2d4523888b5ce46cf849d81ffb1ffd382de5088 Mon Sep 17 00:00:00 2001 From: zhang wenhui Date: Fri, 15 May 2020 12:09:39 +0800 Subject: [PATCH] 0515 fix ssr (#4630) * update api 1.8 * fix ssr infer --- PaddleRec/ssr/infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PaddleRec/ssr/infer.py b/PaddleRec/ssr/infer.py index 915cea92..42a26f14 100644 --- a/PaddleRec/ssr/infer.py +++ b/PaddleRec/ssr/infer.py @@ -40,13 +40,13 @@ def model(vocab_size, emb_size, hidden_size): user_data = fluid.data( name="user", shape=[None, 1], dtype="int64", lod_level=1) all_item_data = fluid.data( - name="all_item", shape=[None, vocab_size, 1], dtype="int64") + name="all_item", shape=[None, vocab_size], dtype="int64") user_emb = fluid.embedding( input=user_data, size=[vocab_size, emb_size], param_attr="emb.item") all_item_emb = fluid.embedding( input=all_item_data, size=[vocab_size, emb_size], param_attr="emb.item") - all_item_emb_re = all_item_emb + all_item_emb_re = fluid.layers.reshape(x=all_item_emb, shape=[-1, emb_size]) user_encoder = net.GrnnEncoder(hidden_size=hidden_size) user_enc = user_encoder.forward(user_emb) @@ -94,7 +94,7 @@ def infer(args, vocab_size, test_reader): user_data, pos_label = utils.infer_data(data, place) all_item_numpy = np.tile( np.arange(vocab_size), len(pos_label)).reshape( - len(pos_label), vocab_size, 1).astype("int64") + len(pos_label), vocab_size).astype("int64") para = exe.run(copy_program, feed={ "user": user_data, -- GitLab