提交 1a098de6 编写于 作者: F frankwhzhang

update ssr model

上级 c6aad62a
......@@ -86,16 +86,16 @@ class SequenceSemanticRetrieval(object):
return correct
def train(self):
user_data = io.data(name="user", shape=[1], dtype="int64", lod_level=1)
pos_item_data = io.data(
name="p_item", shape=[1], dtype="int64", lod_level=1)
neg_item_data = io.data(
name="n_item", shape=[1], dtype="int64", lod_level=1)
user_emb = nn.embedding(
user_data = fluid.data(name="user", shape=[None, 1], dtype="int64", lod_level=1)
pos_item_data = fluid.data(
name="p_item", shape=[None, 1], dtype="int64", lod_level=1)
neg_item_data = fluid.data(
name="n_item", shape=[None, 1], dtype="int64", lod_level=1)
user_emb = fluid.embedding(
input=user_data, size=self.emb_shape, param_attr="emb.item")
pos_item_emb = nn.embedding(
pos_item_emb = fluid.embedding(
input=pos_item_data, size=self.emb_shape, param_attr="emb.item")
neg_item_emb = nn.embedding(
neg_item_emb = fluid.embedding(
input=neg_item_data, size=self.emb_shape, param_attr="emb.item")
user_enc = self.user_encoder.forward(user_emb)
pos_item_enc = self.item_encoder.forward(pos_item_emb)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册