未验证 提交 91992dac 编写于 作者: C Charles-hit 提交者: GitHub

fix prim_op_test when python api outs is different with kernel sig (#50788)

上级 82f170b6
...@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase): ...@@ -920,7 +920,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None, enable_inplace=None,
for_inplace_test=None, for_inplace_test=None,
): ):
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
program = Program() program = Program()
block = program.global_block() block = program.global_block()
op = self._append_ops(block) op = self._append_ops(block)
...@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase): ...@@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase):
Returns: Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
""" """
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
( (
fwd_outs, fwd_outs,
fwd_fetch_list, fwd_fetch_list,
...@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase): ...@@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None, user_defined_grad_outputs=None,
parallel=False, parallel=False,
): ):
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
prog = Program() prog = Program()
scope = core.Scope() scope = core.Scope()
block = prog.global_block() block = prog.global_block()
......
...@@ -930,7 +930,7 @@ class OpTest(unittest.TestCase): ...@@ -930,7 +930,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None, enable_inplace=None,
for_inplace_test=None, for_inplace_test=None,
): ):
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
program = Program() program = Program()
block = program.global_block() block = program.global_block()
op = self._append_ops(block) op = self._append_ops(block)
...@@ -1259,7 +1259,7 @@ class OpTest(unittest.TestCase): ...@@ -1259,7 +1259,7 @@ class OpTest(unittest.TestCase):
Returns: Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
""" """
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
( (
fwd_outs, fwd_outs,
fwd_fetch_list, fwd_fetch_list,
...@@ -2472,7 +2472,7 @@ class OpTest(unittest.TestCase): ...@@ -2472,7 +2472,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None, user_defined_grad_outputs=None,
parallel=False, parallel=False,
): ):
with paddle.fluid.framework._dygraph_guard(None): with paddle.static.program_guard(paddle.static.Program()):
prog = Program() prog = Program()
scope = core.Scope() scope = core.Scope()
block = prog.global_block() block = prog.global_block()
......
...@@ -244,7 +244,7 @@ class PrimForwardChecker: ...@@ -244,7 +244,7 @@ class PrimForwardChecker:
def init_checker(self): def init_checker(self):
assert hasattr( assert hasattr(
self.op_test, 'prim_op_type' self.op_test, 'prim_op_type'
), "if you want to test comp op, please set prim_op_type in setUp function." ), "if you want to test comp op, please set prim_op_type with \'prim\' or \'comp\' in setUp function."
assert self.op_test.prim_op_type in [ assert self.op_test.prim_op_type in [
"comp", "comp",
"prim", "prim",
...@@ -786,11 +786,12 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -786,11 +786,12 @@ class PrimGradChecker(PrimForwardChecker):
self.recover_eager_or_static_status() self.recover_eager_or_static_status()
def get_output_dict(self, np_outputs, api_outputs, outputs_sig): def get_output_dict(self, np_outputs, api_outputs, outputs_sig):
assert len(api_outputs) == len(outputs_sig), ( assert len(api_outputs) <= len(outputs_sig), (
"forward api outputs length must be the same as KernelSignature outputs,but recive %s and %s" "forward api outputs length must be the less than or equal to KernelSignature outputs,but recive %s and %s"
) % (len(api_outputs), len(outputs_sig)) ) % (len(api_outputs), len(outputs_sig))
output_dict = {} output_dict = {}
for i, output_name in enumerate(outputs_sig): for i in range(len(api_outputs)):
output_name = outputs_sig[i]
if isinstance(np_outputs[output_name], list): if isinstance(np_outputs[output_name], list):
for j, tup in enumerate(np_outputs[output_name]): for j, tup in enumerate(np_outputs[output_name]):
output_dict.update({tup[0]: api_outputs[i][j]}) output_dict.update({tup[0]: api_outputs[i][j]})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册