diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 650a745cdc415edf0c7b733c95a72454b1305cfd..8abd7d9e0cf35e0997c82446657cdba273ae2b98 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -870,9 +870,21 @@ class TestRemoteNce(TestDistLookupTableBase): def transpiler_test_impl(self): trainer, _ = self.get_trainer() + + out_vars = ["nce_w.block0", "nce_w.block1"] + in_vars = ["nce_b.block0", "nce_b.block1"] + + recv_var_names = [] + for op in trainer.blocks[0].ops: if op.type == "recv": - pass + for var in op.output("Out"): + recv_var_names.append(var) + + for out_var in out_vars: + self.assertFalse(out_var in recv_var_names) + for in_var in in_vars: + self.assertTrue(in_var in recv_var_names) if __name__ == "__main__":