diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 0e5202209e494136b3599e714e7cf80a51b1a04b..d3e4b632938c01e35cbdc169c3a32b98011fbe23 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -731,12 +731,14 @@ class OpTest(unittest.TestCase):
                 if name in op_proto_attrs:
                     return op_proto_attrs[name]
                 elif name in op_inputs:
-                    assert op_inputs[name].__len__(
-                    ) == 1, "currently don't support multi-input in attribute."
-                    # why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
-                    # why we reconstruct a tensor: because we want the tensor in cpu.
-                    return paddle.to_tensor(
-                        op_inputs[name][0].numpy(), place='cpu')
+                    if len(op_inputs[name]) == 1:
+                        # why we don't use numpy().item(): if the Tensor is float64, it would be converted to a python float and we would lose accuracy: [allclose_op]
+                        # why we reconstruct a tensor: because we want the tensor on cpu.
+                        return paddle.to_tensor(
+                            op_inputs[name][0].numpy(), place='cpu')
+                    else:
+                        # if this is a list (see test_unsqueeze2_op), we just pass it into the python api.
+                        return op_inputs[name]
                 else:
                     return Empty()
 
@@ -786,6 +788,8 @@ class OpTest(unittest.TestCase):
             return results
 
         def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
+            if hasattr(self, "python_out_sig"):
+                output_sig = self.python_out_sig
             if not isinstance(ret_tuple, (tuple, list)):
                 ret_tuple = [ret_tuple]
             if len(output_sig) == len(ret_tuple):
@@ -795,7 +799,7 @@ class OpTest(unittest.TestCase):
                 # [assumption]: return multi-Tensor in a single output. such as paddle.split()
                 assert len(
                     output_sig
-                ) == 1, "Don't support multi-output with multi-tensor output."
+                ) == 1, "Don't support multi-output with multi-tensor output. (Maybe you can set `python_out_sig`; see `test_squeeze2_op` as an example.)"
                 return {output_sig[0]: ret_tuple}
 
         def assumption_assert_and_transform(args, inp_num):
@@ -825,6 +829,9 @@ class OpTest(unittest.TestCase):
                 """ we think the kernel_sig is missing.
                 """
                 kernel_sig = None
+                print(
+                    "[Warning: op_test.py] Kernel Signature is not found for %s, falling back to intermediate state."
+                    % self.op_type)
             return kernel_sig
 
         def cal_python_api(python_api, args, kernel_sig):
@@ -1942,15 +1949,17 @@ class OpTest(unittest.TestCase):
                         attrs_outputs[attrs_name] = self.attrs[attrs_name]
 
             if check_eager:
-                outputs = self._calc_python_api_output(place, inputs, outputs)
-
+                eager_outputs = self._calc_python_api_output(place, inputs,
+                                                              outputs)
             # if outputs is None, kernel sig is empty or other error is happens.
-            if not check_eager or outputs is None:
+            if not check_eager or eager_outputs is None:
                 block.append_op(
                     type=self.op_type,
                     inputs=inputs,
                     outputs=outputs,
                     attrs=attrs_outputs if hasattr(self, "attrs") else None)
+            else:
+                outputs = eager_outputs
 
             if self.dtype == np.uint16:
                 cast_inputs = self._find_var_in_dygraph(outputs,
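
Note for reviewers: below is a minimal sketch of how a test case is expected to opt into the new `python_out_sig` hook added by this diff. The class name, shapes, and attribute values are hypothetical, and the real `test_squeeze2_op` referenced in the assert message may differ in detail. The idea is that `squeeze2` produces both `Out` and `XShape`, while the python api (`paddle.squeeze`) only returns `Out`, so the test narrows the output signature instead of relying on the kernel signature.

import numpy as np
import paddle
from op_test import OpTest


class TestSqueeze2EagerSketch(OpTest):
    # Hypothetical example mirroring the pattern this diff expects;
    # it is not a copy of the real test_squeeze2_op.
    def setUp(self):
        self.op_type = "squeeze2"
        self.python_api = paddle.squeeze
        # Only `Out` is returned by paddle.squeeze, so restrict the
        # output signature to it via the new hook.
        self.python_out_sig = ["Out"]
        self.inputs = {"X": np.random.random((1, 3, 1, 40)).astype("float64")}
        self.attrs = {"axes": [0, 2]}
        self.outputs = {
            "Out": self.inputs["X"].reshape((3, 40)),
            "XShape": np.random.random((1, 3, 1, 40)).astype("float64"),
        }

    def test_check_output(self):
        # XShape is an intermediate output, so skip comparing it.
        self.check_output(no_check_set=["XShape"], check_eager=True)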