提交 68aeb4e7 编写于 作者: Q Qiao Longfei

add fake init test in test_dist_transpiler

上级 a13c788a
...@@ -68,9 +68,10 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,9 +68,10 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) Tensor of specified shape will be filled " "(Tensor) Tensor of specified shape will be filled "
"with the specified value"); "with the specified value");
AddComment(R"DOC( AddComment(R"DOC(
FakeInitBatchSizeLike Operator. FakeInit Operator.
Init an op but not alloc tensor for it, it is used for distributed lookup table. Init an variable but not alloc memory for it, it is used for init the
table parameter at trainer side in distributed lookup table.
)DOC"); )DOC");
} }
......
...@@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
# 5 save table # 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"]) self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer() trainer, trainer_startup = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1) self.assertEqual(len(trainer.blocks), 1)
ops = [ ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids', 'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
...@@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase): ...@@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase):
] ]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops) self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
'fetch_barrier', 'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestAsyncLocalLookupTable(TestDistLookupTableBase): class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def net_conf(self): def net_conf(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册