提交 ee8efb58 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #4160 from Xreki/core_fix_check_outputs

Check and only check the output varibles specified by self.outputs
...@@ -192,6 +192,9 @@ class OpTest(unittest.TestCase): ...@@ -192,6 +192,9 @@ class OpTest(unittest.TestCase):
self.op.run(self.scope, ctx) self.op.run(self.scope, ctx)
for out_name, out_dup in Operator.get_op_outputs(self.op.type()): for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
if out_name not in self.outputs:
continue
if out_dup: if out_dup:
sub_out = self.outputs[out_name] sub_out = self.outputs[out_name]
if not isinstance(sub_out, list): if not isinstance(sub_out, list):
...@@ -206,14 +209,12 @@ class OpTest(unittest.TestCase): ...@@ -206,14 +209,12 @@ class OpTest(unittest.TestCase):
actual, expect, atol=1e-05), actual, expect, atol=1e-05),
"output name: " + out_name + " has diff") "output name: " + out_name + " has diff")
else: else:
var = self.scope.find_var(out_name) actual = np.array(self.scope.find_var(out_name).get_tensor())
if var is not None: expect = self.outputs[out_name]
actual = np.array(var.get_tensor()) self.assertTrue(
expect = self.outputs[out_name] np.allclose(
self.assertTrue( actual, expect, atol=1e-05),
np.allclose( "output name: " + out_name + " has diff")
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
def check_output(self): def check_output(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册