From 91992dac492e326a589cc1ab51e03cd63d080773 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Thu, 23 Feb 2023 13:15:52 +0800 Subject: [PATCH] fix prim_op_test when python api outs is different with kernel sig (#50788) --- python/paddle/fluid/tests/unittests/eager_op_test.py | 6 +++--- python/paddle/fluid/tests/unittests/op_test.py | 6 +++--- python/paddle/fluid/tests/unittests/prim_op_test.py | 9 +++++---- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index 189c59e5783..46c612cd075 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -920,7 +920,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1249,7 +1249,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): ( fwd_outs, fwd_fetch_list, @@ -2360,7 +2360,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 1e3170dfc97..81a2cb12f7f 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -930,7 +930,7 @@ class OpTest(unittest.TestCase): enable_inplace=None, for_inplace_test=None, ): - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): program = Program() block = program.global_block() op = self._append_ops(block) @@ -1259,7 +1259,7 @@ class OpTest(unittest.TestCase): Returns: res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc. """ - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): ( fwd_outs, fwd_fetch_list, @@ -2472,7 +2472,7 @@ class OpTest(unittest.TestCase): user_defined_grad_outputs=None, parallel=False, ): - with paddle.fluid.framework._dygraph_guard(None): + with paddle.static.program_guard(paddle.static.Program()): prog = Program() scope = core.Scope() block = prog.global_block() diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index fb5b8e5088b..7d9c706801b 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -244,7 +244,7 @@ class PrimForwardChecker: def init_checker(self): assert hasattr( self.op_test, 'prim_op_type' - ), "if you want to test comp op, please set prim_op_type in setUp function." + ), "if you want to test comp op, please set prim_op_type with \'prim\' or \'comp\' in setUp function." assert self.op_test.prim_op_type in [ "comp", "prim", @@ -786,11 +786,12 @@ class PrimGradChecker(PrimForwardChecker): self.recover_eager_or_static_status() def get_output_dict(self, np_outputs, api_outputs, outputs_sig): - assert len(api_outputs) == len(outputs_sig), ( - "forward api outputs length must be the same as KernelSignature outputs,but recive %s and %s" + assert len(api_outputs) <= len(outputs_sig), ( + "forward api outputs length must be the less than or equal to KernelSignature outputs,but recive %s and %s" ) % (len(api_outputs), len(outputs_sig)) output_dict = {} - for i, output_name in enumerate(outputs_sig): + for i in range(len(api_outputs)): + output_name = outputs_sig[i] if isinstance(np_outputs[output_name], list): for j, tup in enumerate(np_outputs[output_name]): output_dict.update({tup[0]: api_outputs[i][j]}) -- GitLab