未验证 提交 91992dac 编写于 作者: C Charles-hit 提交者: GitHub

fix prim_op_test when python api outs is different with kernel sig (#50788)

上级 82f170b6
......@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
(
fwd_outs,
fwd_fetch_list,
......@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
prog = Program()
scope = core.Scope()
block = prog.global_block()
......
......@@ -930,7 +930,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1259,7 +1259,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
(
fwd_outs,
fwd_fetch_list,
......@@ -2472,7 +2472,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.fluid.framework._dygraph_guard(None):
with paddle.static.program_guard(paddle.static.Program()):
prog = Program()
scope = core.Scope()
block = prog.global_block()
......
......@@ -244,7 +244,7 @@ class PrimForwardChecker:
def init_checker(self):
assert hasattr(
self.op_test, 'prim_op_type'
), "if you want to test comp op, please set prim_op_type in setUp function."
), "if you want to test comp op, please set prim_op_type with \'prim\' or \'comp\' in setUp function."
assert self.op_test.prim_op_type in [
"comp",
"prim",
......@@ -786,11 +786,12 @@ class PrimGradChecker(PrimForwardChecker):
self.recover_eager_or_static_status()
def get_output_dict(self, np_outputs, api_outputs, outputs_sig):
assert len(api_outputs) == len(outputs_sig), (
"forward api outputs length must be the same as KernelSignature outputs,but recive %s and %s"
assert len(api_outputs) <= len(outputs_sig), (
"forward api outputs length must be the less than or equal to KernelSignature outputs,but recive %s and %s"
) % (len(api_outputs), len(outputs_sig))
output_dict = {}
for i, output_name in enumerate(outputs_sig):
for i in range(len(api_outputs)):
output_name = outputs_sig[i]
if isinstance(np_outputs[output_name], list):
for j, tup in enumerate(np_outputs[output_name]):
output_dict.update({tup[0]: api_outputs[i][j]})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册