提交 c6aad62a 编写于 作者: F frankwhzhang

update ms bug

上级 85febfd5
......@@ -125,34 +125,34 @@ class MultiviewSimnet(object):
def train_net(self):
# input fields for query, pos_title, neg_title
q_slots = [
io.data(
name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
fluid.data(
name="q%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.query_encoders))
]
pt_slots = [
io.data(
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
fluid.data(
name="pt%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
nt_slots = [
io.data(
name="nt%d" % i, shape=[1], lod_level=1, dtype='int64')
fluid.data(
name="nt%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
# lookup embedding for each slot
q_embs = [
nn.embedding(
fluid.embedding(
input=query, size=self.emb_shape, param_attr="emb")
for query in q_slots
]
pt_embs = [
nn.embedding(
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in pt_slots
]
nt_embs = [
nn.embedding(
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in nt_slots
]
......@@ -205,23 +205,23 @@ class MultiviewSimnet(object):
def pred_net(self, query_fields, pos_title_fields, neg_title_fields):
q_slots = [
io.data(
name="q%d" % i, shape=[1], lod_level=1, dtype='int64')
fluid.data(
name="q%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.query_encoders))
]
pt_slots = [
io.data(
name="pt%d" % i, shape=[1], lod_level=1, dtype='int64')
fluid.data(
name="pt%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
# lookup embedding for each slot
q_embs = [
nn.embedding(
fluid.embedding(
input=query, size=self.emb_shape, param_attr="emb")
for query in q_slots
]
pt_embs = [
nn.embedding(
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in pt_slots
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册