提交 723f6872 编写于 作者: T tangwei12

add ut about nce in transpiler

上级 50fce879
...@@ -870,9 +870,21 @@ class TestRemoteNce(TestDistLookupTableBase): ...@@ -870,9 +870,21 @@ class TestRemoteNce(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
trainer, _ = self.get_trainer() trainer, _ = self.get_trainer()
out_vars = ["nce_w.block0", "nce_w.block1"]
in_vars = ["nce_b.block0", "nce_b.block1"]
recv_var_names = []
for op in trainer.blocks[0].ops: for op in trainer.blocks[0].ops:
if op.type == "recv": if op.type == "recv":
pass for var in op.output("Out"):
recv_var_names.append(var)
for out_var in out_vars:
self.assertFalse(out_var in recv_var_names)
for in_var in in_vars:
self.assertTrue(in_var in recv_var_names)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册