提交 839193fd 编写于 作者: Q Qiao Longfei

fix unit test test=develop

上级 9450048a
...@@ -61,8 +61,7 @@ class TestDistCTR2x2(TestDistRunnerBase): ...@@ -61,8 +61,7 @@ class TestDistCTR2x2(TestDistRunnerBase):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name="deep_embedding", name="deep_embedding",
initializer=fluid.initializer.Constant(value=0.01)), initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE)
remote_prefetch=True)
dnn_pool = fluid.layers.sequence_pool( dnn_pool = fluid.layers.sequence_pool(
input=dnn_embedding, pool_type="sum") input=dnn_embedding, pool_type="sum")
dnn_out = dnn_pool dnn_out = dnn_pool
...@@ -84,8 +83,7 @@ class TestDistCTR2x2(TestDistRunnerBase): ...@@ -84,8 +83,7 @@ class TestDistCTR2x2(TestDistRunnerBase):
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name="wide_embedding", name="wide_embedding",
initializer=fluid.initializer.Constant(value=0.01)), initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=IS_SPARSE, is_sparse=IS_SPARSE)
remote_prefetch=True)
lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum")
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
......
...@@ -447,23 +447,19 @@ class TestEmptyPserverOptimizeBlocks(TranspilerTest): ...@@ -447,23 +447,19 @@ class TestEmptyPserverOptimizeBlocks(TranspilerTest):
class TestDistLookupTableBase(TranspilerTest): class TestDistLookupTableBase(TranspilerTest):
def network_with_table(self, def network_with_table(self, is_sparse, is_distributed):
is_sparse,
is_distributed,
remote_prefetch=False):
self.table_size = 1000 self.table_size = 1000
self.emb_size = 64 self.emb_size = 64
self.lookup_table_name = 'shared_w' self.lookup_table_name = 'shared_w'
def emb_pool(ids, table_name, is_distributed, remote_prefetch): def emb_pool(ids, table_name, is_distributed):
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=ids, input=ids,
size=[self.table_size, self.emb_size], size=[self.table_size, self.emb_size],
dtype='float32', dtype='float32',
param_attr=table_name, param_attr=table_name,
is_sparse=is_sparse, is_sparse=is_sparse,
is_distributed=is_distributed, is_distributed=is_distributed)
remote_prefetch=remote_prefetch)
pool = fluid.layers.sequence_pool(input=emb, pool_type='average') pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
return pool return pool
...@@ -473,12 +469,9 @@ class TestDistLookupTableBase(TranspilerTest): ...@@ -473,12 +469,9 @@ class TestDistLookupTableBase(TranspilerTest):
name='brand_ids', shape=[1], dtype='int64', lod_level=1) name='brand_ids', shape=[1], dtype='int64', lod_level=1)
profile_ids = fluid.layers.data( profile_ids = fluid.layers.data(
name='brand_ids', shape=[1], dtype='int64', lod_level=1) name='brand_ids', shape=[1], dtype='int64', lod_level=1)
title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed, title_emb = emb_pool(title_ids, self.lookup_table_name, is_distributed)
False) brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed)
brand_emb = emb_pool(brand_ids, self.lookup_table_name, is_distributed, profile_emb = emb_pool(profile_ids, "profile_emb", False)
False)
profile_emb = emb_pool(profile_ids, "profile_emb", False,
remote_prefetch)
fc0 = fluid.layers.concat( fc0 = fluid.layers.concat(
input=[title_emb, brand_emb, profile_emb], axis=1) input=[title_emb, brand_emb, profile_emb], axis=1)
predict = fluid.layers.fc(input=fc0, predict = fluid.layers.fc(input=fc0,
...@@ -794,8 +787,7 @@ class TestRemoteLookupTable(TestDistLookupTableBase): ...@@ -794,8 +787,7 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
def net_conf(self): def net_conf(self):
import os import os
os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1" os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
self.network_with_table( self.network_with_table(is_sparse=True, is_distributed=False)
is_sparse=True, is_distributed=False, remote_prefetch=True)
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep) pserver1, startup1 = self.get_pserver(self.pserver1_ep)
...@@ -826,7 +818,7 @@ class TestRemoteLookupTable(TestDistLookupTableBase): ...@@ -826,7 +818,7 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
'split_selected_rows', 'send', 'sequence_pool_grad', 'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv', 'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
'recv', 'recv', 'fetch_barrier', 'concat' 'recv', 'fetch_barrier'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册