From 1a098de69cff5e69b875d980892c5cea7bf928db Mon Sep 17 00:00:00 2001 From: frankwhzhang Date: Sat, 12 Oct 2019 13:57:32 +0800 Subject: [PATCH] update ssr model --- PaddleRec/ssr/nets.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/PaddleRec/ssr/nets.py b/PaddleRec/ssr/nets.py index 4df23573..0e853429 100644 --- a/PaddleRec/ssr/nets.py +++ b/PaddleRec/ssr/nets.py @@ -86,16 +86,16 @@ class SequenceSemanticRetrieval(object): return correct def train(self): - user_data = io.data(name="user", shape=[1], dtype="int64", lod_level=1) - pos_item_data = io.data( - name="p_item", shape=[1], dtype="int64", lod_level=1) - neg_item_data = io.data( - name="n_item", shape=[1], dtype="int64", lod_level=1) - user_emb = nn.embedding( + user_data = fluid.data(name="user", shape=[None, 1], dtype="int64", lod_level=1) + pos_item_data = fluid.data( + name="p_item", shape=[None, 1], dtype="int64", lod_level=1) + neg_item_data = fluid.data( + name="n_item", shape=[None, 1], dtype="int64", lod_level=1) + user_emb = fluid.embedding( input=user_data, size=self.emb_shape, param_attr="emb.item") - pos_item_emb = nn.embedding( + pos_item_emb = fluid.embedding( input=pos_item_data, size=self.emb_shape, param_attr="emb.item") - neg_item_emb = nn.embedding( + neg_item_emb = fluid.embedding( input=neg_item_data, size=self.emb_shape, param_attr="emb.item") user_enc = self.user_encoder.forward(user_emb) pos_item_enc = self.item_encoder.forward(pos_item_emb) -- GitLab