From 466d48fd23401bdf13fce4b658754ae4da4d459f Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Mon, 18 Sep 2017 10:48:23 +0000 Subject: [PATCH] Check and only check the output varibles specified by self.outputs. --- python/paddle/v2/framework/tests/op_test.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 6bbea22c5f..0a5673868c 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -192,6 +192,9 @@ class OpTest(unittest.TestCase): self.op.run(self.scope, ctx) for out_name, out_dup in Operator.get_op_outputs(self.op.type()): + if out_name not in self.outputs: + continue + if out_dup: sub_out = self.outputs[out_name] if not isinstance(sub_out, list): @@ -206,14 +209,12 @@ class OpTest(unittest.TestCase): actual, expect, atol=1e-05), "output name: " + out_name + " has diff") else: - var = self.scope.find_var(out_name) - if var is not None: - actual = np.array(var.get_tensor()) - expect = self.outputs[out_name] - self.assertTrue( - np.allclose( - actual, expect, atol=1e-05), - "output name: " + out_name + " has diff") + actual = np.array(self.scope.find_var(out_name).get_tensor()) + expect = self.outputs[out_name] + self.assertTrue( + np.allclose( + actual, expect, atol=1e-05), + "output name: " + out_name + " has diff") def check_output(self): places = [core.CPUPlace()] -- GitLab