From 723f68727db273902674e6046ead5f0ebdb78bf4 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 14 Dec 2018 17:00:48 +0800 Subject: [PATCH] add ut about nce in transpiler --- .../fluid/tests/unittests/test_dist_transpiler.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 650a745cdc4..8abd7d9e0cf 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -870,9 +870,21 @@ class TestRemoteNce(TestDistLookupTableBase): def transpiler_test_impl(self): trainer, _ = self.get_trainer() + + out_vars = ["nce_w.block0", "nce_w.block1"] + in_vars = ["nce_b.block0", "nce_b.block1"] + + recv_var_names = [] + for op in trainer.blocks[0].ops: if op.type == "recv": - pass + for var in op.output("Out"): + recv_var_names.append(var) + + for out_var in out_vars: + self.assertFalse(out_var in recv_var_names) + for in_var in in_vars: + self.assertTrue(in_var in recv_var_names) if __name__ == "__main__": -- GitLab