diff --git a/python/paddle/fluid/tests/unittests/eager_op_test.py b/python/paddle/fluid/tests/unittests/eager_op_test.py index abcb671d47ac082598700bfa49e63ead46a55a1b..82b97fe5a0ec9ec84c5404ab8212305c95390bf3 100644 --- a/python/paddle/fluid/tests/unittests/eager_op_test.py +++ b/python/paddle/fluid/tests/unittests/eager_op_test.py @@ -338,6 +338,9 @@ class OpTest(unittest.TestCase): _set_use_system_allocator(cls._use_system_allocator) + if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'): + print("check prim end!") + def is_empty_grad_op(op_type): all_op_kernels = core._get_all_register_op_kernels() grad_op = op_type + '_grad' diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index d9c4cfce2a86e308fd1acb1d96ea960b55f5662a..6518480f5e982d819cf9dcd2caef6fde890481e2 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -352,6 +352,9 @@ class OpTest(unittest.TestCase): _set_use_system_allocator(cls._use_system_allocator) + if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'): + print("check prim end!") + def is_empty_grad_op(op_type): all_op_kernels = core._get_all_register_op_kernels() grad_op = op_type + '_grad' diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index 00665c23c36808cae1df726e088556617b3b0164..81d3b7d0ba6af0ba531d6e588a0b043143a0b74b 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import struct from collections import defaultdict @@ -277,6 +278,8 @@ class PrimForwardChecker: if hasattr(self.op_test, 'enable_cinn') else True ) + if os.getenv('FLAGS_enable_cinn'): + self.enable_cinn = True self.enable_check_eager_comp = ( self.op_test.enable_check_eager_comp if hasattr(self.op_test, 'enable_check_eager_comp') @@ -398,8 +401,8 @@ class PrimForwardChecker: eager_tensor_inputs, attrs_outputs, _, - ) = self.get_eager_input_attr_and_inputdict() - eager_tensor_outputs = self.get_eager_empty_output() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True) + eager_tensor_outputs = self.get_eager_empty_output(stop_gradient=True) kernel_sig = OpTestUtils._get_kernel_signature( self.op_type, eager_tensor_inputs, @@ -418,7 +421,7 @@ class PrimForwardChecker: eager_tensor_inputs, attrs_outputs, _, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) @@ -432,7 +435,7 @@ class PrimForwardChecker: ret = map_structure(lambda x: convert_uint16_to_float(x), ret) return ret - def get_eager_input_attr_and_inputdict(self): + def get_eager_input_attr_and_inputdict(self, stop_gradient): attrs_outputs = {} for attrs_name in self.attrs: if self.attrs[attrs_name] is not None: @@ -450,7 +453,7 @@ class PrimForwardChecker: x = paddle.to_tensor( data=tup[1], place=self.place, - stop_gradient=False, + stop_gradient=stop_gradient, dtype=dtype, ) eager_inputs[name].append(x) @@ -464,14 +467,14 @@ class PrimForwardChecker: x = paddle.to_tensor( data=item, place=self.place, - stop_gradient=False, + stop_gradient=stop_gradient, dtype=dtype, ) eager_inputs[name].append(x) input_dict.update({name: x}) return eager_inputs, attrs_outputs, input_dict - def get_eager_empty_output(self): + def get_eager_empty_output(self, stop_gradient): eager_outputs = defaultdict(list) for name, item in self.outputs.items(): if isinstance(item, list): @@ -484,7 +487,7 @@ class PrimForwardChecker: x = paddle.to_tensor( data=[], place=self.place, - stop_gradient=False, + stop_gradient=stop_gradient, dtype=dtype, ) eager_outputs[name].append(x) @@ -495,12 +498,15 @@ class PrimForwardChecker: else item.dtype ) x = paddle.to_tensor( - data=[], place=self.place, stop_gradient=False, dtype=dtype + data=[], + place=self.place, + stop_gradient=stop_gradient, + dtype=dtype, ) eager_outputs[name].append(x) return eager_outputs - def get_static_input_attr_inputdict_and_feed(self): + def get_static_input_attr_inputdict_and_feed(self, stop_gradient): attrs_outputs = {} for attrs_name in self.attrs: if self.attrs[attrs_name] is not None: @@ -519,7 +525,7 @@ class PrimForwardChecker: x = paddle.static.data( name=str(tup[0]), shape=tup[1].shape, dtype=dtype ) - x.stop_gradient = False + x.stop_gradient = stop_gradient static_inputs[name].append(x) feed.update({str(tup[0]): tup[1]}) input_dict.update({str(tup[0]): x}) @@ -530,7 +536,7 @@ class PrimForwardChecker: else item.dtype ) x = paddle.static.data(name=name, shape=item.shape, dtype=dtype) - x.stop_gradient = False + x.stop_gradient = stop_gradient static_inputs[name].append(x) feed.update({name: item}) input_dict.update({name: x}) @@ -555,7 +561,9 @@ class PrimForwardChecker: attrs, input_dict, feed, - ) = self.get_static_input_attr_inputdict_and_feed() + ) = self.get_static_input_attr_inputdict_and_feed( + stop_gradient=True + ) args = OpTestUtils.prepare_python_api_arguments( self.python_api, static_inputs, attrs, self.kernel_sig ) @@ -621,7 +629,7 @@ class PrimForwardChecker: eager_tensor_inputs, attrs_outputs, _, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) @@ -698,7 +706,7 @@ class PrimForwardChecker: eager_tensor_inputs, attrs_outputs, _, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) @@ -794,7 +802,9 @@ class PrimGradChecker(PrimForwardChecker): output_dict = {} for i in range(len(api_outputs)): output_name = outputs_sig[i] - if isinstance(np_outputs[output_name], list): + if output_name in np_outputs and isinstance( + np_outputs[output_name], list + ): for j, tup in enumerate(np_outputs[output_name]): output_dict.update({tup[0]: api_outputs[i][j]}) else: @@ -854,11 +864,13 @@ class PrimGradChecker(PrimForwardChecker): eager_tensor_inputs, attrs_outputs, inputs_dict, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) inputs_sig, _, outputs_sig = self.kernel_sig + if hasattr(self.op_test, "python_out_sig"): + outputs_sig = self.op_test.python_out_sig args = OpTestUtils.assumption_assert_and_transform( args, len(inputs_sig) ) @@ -954,11 +966,15 @@ class PrimGradChecker(PrimForwardChecker): attrs, inputs_dict, feed, - ) = self.get_static_input_attr_inputdict_and_feed() + ) = self.get_static_input_attr_inputdict_and_feed( + stop_gradient=False + ) args = OpTestUtils.prepare_python_api_arguments( self.python_api, static_inputs, attrs, self.kernel_sig ) inputs_sig, _, outputs_sig = self.kernel_sig + if hasattr(self.op_test, "python_out_sig"): + outputs_sig = self.op_test.python_out_sig args = OpTestUtils.assumption_assert_and_transform( args, len(inputs_sig) ) @@ -1055,7 +1071,7 @@ class PrimGradChecker(PrimForwardChecker): eager_tensor_inputs, attrs_outputs, inputs_dict, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) @@ -1066,6 +1082,8 @@ class PrimGradChecker(PrimForwardChecker): net = PrimNet(self.python_api) net = apply_to_static(net, False) out = _as_list(net(args)) + if hasattr(self.op_test, "python_out_sig"): + outputs_sig = self.op_test.python_out_sig outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig) ys = [] if isinstance(self.output_names, list): @@ -1163,7 +1181,7 @@ class PrimGradChecker(PrimForwardChecker): eager_tensor_inputs, attrs_outputs, inputs_dict, - ) = self.get_eager_input_attr_and_inputdict() + ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False) args = OpTestUtils.prepare_python_api_arguments( self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig ) @@ -1176,6 +1194,8 @@ class PrimGradChecker(PrimForwardChecker): net, core.is_compiled_with_cinn() and self.enable_cinn ) out = _as_list(net(args)) + if hasattr(self.op_test, "python_out_sig"): + outputs_sig = self.op_test.python_out_sig outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig) ys = [] if isinstance(self.output_names, list):