提交 e025fc97 编写于 作者: S seiriosPlus

fix unit test in cpu 1.1

上级 641369f9
...@@ -120,6 +120,8 @@ class OpDescCreationMethod(object): ...@@ -120,6 +120,8 @@ class OpDescCreationMethod(object):
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
elif attr.type == framework_pb2.BOOLEANS: elif attr.type == framework_pb2.BOOLEANS:
new_attr.bools.extend(user_defined_attr) new_attr.bools.extend(user_defined_attr)
elif attr.type == framework_pb2.LONGS:
new_attr.longs.extend(user_defined_attr)
elif attr.type == framework_pb2.INT_PAIRS: elif attr.type == framework_pb2.INT_PAIRS:
for p in user_defined_attr: for p in user_defined_attr:
pair = new_attr.int_pairs.add() pair = new_attr.int_pairs.add()
......
...@@ -480,7 +480,7 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -480,7 +480,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver1, startup1 = self.get_pserver(self.pserver1_ep) pserver1, startup1 = self.get_pserver(self.pserver1_ep)
self.assertEqual(len(pserver1.blocks), 6) self.assertEqual(len(pserver1.blocks), 5)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -491,23 +491,19 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -491,23 +491,19 @@ class TestDistLookupTable(TestDistLookupTableBase):
# 3 prefetch -> lookup_sparse_table for data0 # 3 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops], self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 prefetch -> lookup_sparse_table for data1 # 4 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
["lookup_sparse_table"])
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, trainer_startup = self.get_trainer() trainer, trainer_startup = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv', 'send', 'send_barrier', 'recv', 'recv', 'fetch_barrier'
'fetch_barrier'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
...@@ -563,7 +559,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -563,7 +559,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False) pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
self.assertEqual(len(pserver1.blocks), 6) self.assertEqual(len(pserver1.blocks), 5)
# 0 listen_and_serv # 0 listen_and_serv
# 1 optimize for fc_w or fc_b adam # 1 optimize for fc_w or fc_b adam
self.assertEqual([op.type for op in pserver1.blocks[1].ops], self.assertEqual([op.type for op in pserver1.blocks[1].ops],
...@@ -573,23 +569,21 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase): ...@@ -573,23 +569,21 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
# 3 prefetch -> lookup_sparse_table for data0 # 3 prefetch -> lookup_sparse_table for data0
self.assertEqual([op.type for op in pserver1.blocks[3].ops], self.assertEqual([op.type for op in pserver1.blocks[3].ops],
["lookup_sparse_table"]) ["lookup_sparse_table"])
# 4 prefetch -> lookup_sparse_table for data1 # 4 save table
self.assertEqual([op.type for op in pserver1.blocks[4].ops], self.assertEqual([op.type for op in pserver1.blocks[4].ops], ["save"])
["lookup_sparse_table"])
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer(config) trainer, _ = self.get_trainer(config)
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant', 'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad', 'sum', 'split_ids',
'sum', 'split_ids', 'send', 'recv', 'recv' 'send', 'recv', 'recv'
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册