From 2037fa68db8a79ff4869afcf0ce6864d7e05449f Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 9 Mar 2022 11:49:44 +0800 Subject: [PATCH] [optest]: fix transpose, support different parameter name between python_api and KernelSignature. (#40258) * optest: fix transpose * fix --- .../paddle/fluid/tests/unittests/op_test.py | 75 ++++++++++++++----- 1 file changed, 57 insertions(+), 18 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 6455da9247..457f20ac5b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -50,6 +50,7 @@ from paddle.fluid.tests.unittests.white_list import ( no_check_set_white_list, op_threshold_white_list, no_grad_set_white_list, ) +from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs): @@ -698,19 +699,55 @@ class OpTest(unittest.TestCase): self.__class__.__name__) def _calc_python_api_output(self, place): - def prepare_python_api_arguments(op_proto_ins, op_proto_attrs, + def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs, kernel_sig): """ map from `op proto inputs and attrs` to `api input list and api attrs dict` """ + + class Empty: + pass + + def is_empty(a): + return isinstance(a, Empty) + + def get_default(idx, all_params_number, defaults): + related_idx = idx - all_params_number + len(defaults) + assert related_idx >= 0, "%d-th arguments don't have default value" % idx + return defaults[related_idx] + + def remove_name(x): + if isinstance(x, list): return [i for i in x if i != 'name'] + if isinstance(x, dict): + return {k: v for k, v in x.items() if k != 'name'} + assert False, "Only support list or dict." + + def to_defaults_list(params, defaults): + return [defaults[p] for p in params if p in defaults] + # NOTE(xiongkun): why don't use input arguments dicts ? # Because we don't know the python api name of each arguments. + # using parse_arg_and_kwargs, we can get the all api information we need. + api_params, api_defaults = [ + remove_name(item) for item in parse_arg_and_kwargs(api) + ] + api_defaults = to_defaults_list(api_params, api_defaults) inputs_sig, attrs_sig, outputs_sig = kernel_sig - input_arguments = [op_proto_ins[name] for name in inputs_sig] - attr_arguments = { - name: op_proto_attrs[name] - for name in attrs_sig if name in op_proto_attrs - } - return input_arguments, attr_arguments + inputs_and_attrs = inputs_sig + attrs_sig + assert ( + len(api_params) == len(inputs_and_attrs) + ), "inputs and attrs length must equals to python api length. (May be output is in argument list?)" + input_arguments = [op_proto_ins[name] for name in inputs_sig] + [ + op_proto_attrs[name] if name in op_proto_attrs else Empty() + for name in attrs_sig + ] + results = [] + for idx, arg in enumerate(input_arguments): + if is_empty(arg): + results.append( + get_default(idx, len(input_arguments), api_defaults)) + else: + results.append(arg) + return results def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): if not isinstance(ret_tuple, (tuple, list)): @@ -720,25 +757,27 @@ class OpTest(unittest.TestCase): len(output_sig), len(ret_tuple)) return {a: b for a, b in zip(output_sig, ret_tuple)} - def assumption_assert_and_transform(args, argvs): + def assumption_assert_and_transform(args, inp_num): """ - transform by the following rules: + transform inputs by the following rules: 1. [Tensor] -> Tensor 2. [Tensor, Tensor, ...] -> list of Tensors only support "X" is list of Tensor, currently don't support other structure like dict. """ - for inp in args: + for inp in args[:inp_num]: assert isinstance( inp, list ), "currently only support `X` is [Tensor], don't support other structure." - args = [inp[0] if len(inp) == 1 else inp for inp in args] - return args, argvs + args = [ + inp[0] if len(inp) == 1 else inp for inp in args[:inp_num] + ] + args[inp_num:] + return args - def cal_python_api(python_api, args, argvs, kernel_sig): - args, argvs = assumption_assert_and_transform(args, argvs) + def cal_python_api(python_api, args, kernel_sig): inputs_sig, attrs_sig, outputs_sig = kernel_sig - ret_tuple = python_api(*args, **argvs) + args = assumption_assert_and_transform(args, len(inputs_sig)) + ret_tuple = python_api(*args) return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) with fluid.dygraph.base.guard(place=place): @@ -764,11 +803,11 @@ class OpTest(unittest.TestCase): assert hasattr( self, "python_api" ), "Please set the `self.python_api` if you want to compare python api output." - arg, argv = prepare_python_api_arguments(inputs, attrs_outputs, - kernel_sig) + args = prepare_python_api_arguments(self.python_api, inputs, + attrs_outputs, kernel_sig) """ we directly return the cal_python_api value because the value is already tensor. """ - return cal_python_api(self.python_api, arg, argv, kernel_sig) + return cal_python_api(self.python_api, args, kernel_sig) def _calc_dygraph_output(self, place, parallel=False, no_check_set=None): self.__class__.op_type = self.op_type # for ci check, please not delete it for now -- GitLab