提交 ef2b1af4 编写于 作者: S seiriosPlus

fix UT

上级 4ce9e564
...@@ -67,15 +67,13 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -67,15 +67,13 @@ class TestPSPassWithBow(unittest.TestCase):
q = fluid.layers.data( q = fluid.layers.data(
name="query_ids", shape=[1], dtype="int64", lod_level=1) name="query_ids", shape=[1], dtype="int64", lod_level=1)
# embedding # embedding
q_emb = fluid.layers.embedding( q_emb = fluid.contrib.layers.sparse_embedding(
input=q, input=pt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01), initializer=fluid.initializer.Constant(value=0.01),
name="__emb__", name="__emb__",
learning_rate=emb_lr), learning_rate=emb_lr))
is_sparse=is_sparse)
q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim]) q_emb = fluid.layers.reshape(q_emb, [-1, emb_dim])
# vsum # vsum
q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum') q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
...@@ -118,15 +116,13 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -118,15 +116,13 @@ class TestPSPassWithBow(unittest.TestCase):
nt = fluid.layers.data( nt = fluid.layers.data(
name="neg_title_ids", shape=[1], dtype="int64", lod_level=1) name="neg_title_ids", shape=[1], dtype="int64", lod_level=1)
# embedding # embedding
nt_emb = fluid.layers.embedding( nt_emb = fluid.contrib.layers.sparse_embedding(
input=nt, input=pt,
is_distributed=is_distributed,
size=[dict_dim, emb_dim], size=[dict_dim, emb_dim],
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01), initializer=fluid.initializer.Constant(value=0.01),
name="__emb__", name="__emb__",
learning_rate=emb_lr), learning_rate=emb_lr))
is_sparse=is_sparse)
nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim]) nt_emb = fluid.layers.reshape(nt_emb, [-1, emb_dim])
# vsum # vsum
nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum') nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册