提交 ee8efb58 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #4160 from Xreki/core_fix_check_outputs

Check and only check the output varibles specified by self.outputs
......@@ -192,6 +192,9 @@ class OpTest(unittest.TestCase):
self.op.run(self.scope, ctx)
for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
if out_name not in self.outputs:
continue
if out_dup:
sub_out = self.outputs[out_name]
if not isinstance(sub_out, list):
......@@ -206,14 +209,12 @@ class OpTest(unittest.TestCase):
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
else:
var = self.scope.find_var(out_name)
if var is not None:
actual = np.array(var.get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + " has diff")
def check_output(self):
places = [core.CPUPlace()]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册