未验证 提交 2037fa68 编写于 作者: X xiongkun 提交者: GitHub

[optest]: fix transpose, support different parameter name between python_api...

[optest]: fix transpose, support  different parameter name between python_api and KernelSignature. (#40258)

* optest: fix transpose

* fix
上级 0b597754
...@@ -50,6 +50,7 @@ from paddle.fluid.tests.unittests.white_list import ( ...@@ -50,6 +50,7 @@ from paddle.fluid.tests.unittests.white_list import (
no_check_set_white_list, no_check_set_white_list,
op_threshold_white_list, op_threshold_white_list,
no_grad_set_white_list, ) no_grad_set_white_list, )
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs): def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
...@@ -698,19 +699,55 @@ class OpTest(unittest.TestCase): ...@@ -698,19 +699,55 @@ class OpTest(unittest.TestCase):
self.__class__.__name__) self.__class__.__name__)
def _calc_python_api_output(self, place): def _calc_python_api_output(self, place):
def prepare_python_api_arguments(op_proto_ins, op_proto_attrs, def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs,
kernel_sig): kernel_sig):
""" map from `op proto inputs and attrs` to `api input list and api attrs dict` """ map from `op proto inputs and attrs` to `api input list and api attrs dict`
""" """
class Empty:
pass
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, all_params_number, defaults):
related_idx = idx - all_params_number + len(defaults)
assert related_idx >= 0, "%d-th arguments don't have default value" % idx
return defaults[related_idx]
def remove_name(x):
if isinstance(x, list): return [i for i in x if i != 'name']
if isinstance(x, dict):
return {k: v for k, v in x.items() if k != 'name'}
assert False, "Only support list or dict."
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
# NOTE(xiongkun): why don't use input arguments dicts ? # NOTE(xiongkun): why don't use input arguments dicts ?
# Because we don't know the python api name of each arguments. # Because we don't know the python api name of each arguments.
# using parse_arg_and_kwargs, we can get the all api information we need.
api_params, api_defaults = [
remove_name(item) for item in parse_arg_and_kwargs(api)
]
api_defaults = to_defaults_list(api_params, api_defaults)
inputs_sig, attrs_sig, outputs_sig = kernel_sig inputs_sig, attrs_sig, outputs_sig = kernel_sig
input_arguments = [op_proto_ins[name] for name in inputs_sig] inputs_and_attrs = inputs_sig + attrs_sig
attr_arguments = { assert (
name: op_proto_attrs[name] len(api_params) == len(inputs_and_attrs)
for name in attrs_sig if name in op_proto_attrs ), "inputs and attrs length must equals to python api length. (May be output is in argument list?)"
} input_arguments = [op_proto_ins[name] for name in inputs_sig] + [
return input_arguments, attr_arguments op_proto_attrs[name] if name in op_proto_attrs else Empty()
for name in attrs_sig
]
results = []
for idx, arg in enumerate(input_arguments):
if is_empty(arg):
results.append(
get_default(idx, len(input_arguments), api_defaults))
else:
results.append(arg)
return results
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig): def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if not isinstance(ret_tuple, (tuple, list)): if not isinstance(ret_tuple, (tuple, list)):
...@@ -720,25 +757,27 @@ class OpTest(unittest.TestCase): ...@@ -720,25 +757,27 @@ class OpTest(unittest.TestCase):
len(output_sig), len(ret_tuple)) len(output_sig), len(ret_tuple))
return {a: b for a, b in zip(output_sig, ret_tuple)} return {a: b for a, b in zip(output_sig, ret_tuple)}
def assumption_assert_and_transform(args, argvs): def assumption_assert_and_transform(args, inp_num):
""" """
transform by the following rules: transform inputs by the following rules:
1. [Tensor] -> Tensor 1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors 2. [Tensor, Tensor, ...] -> list of Tensors
only support "X" is list of Tensor, currently don't support other structure like dict. only support "X" is list of Tensor, currently don't support other structure like dict.
""" """
for inp in args: for inp in args[:inp_num]:
assert isinstance( assert isinstance(
inp, list inp, list
), "currently only support `X` is [Tensor], don't support other structure." ), "currently only support `X` is [Tensor], don't support other structure."
args = [inp[0] if len(inp) == 1 else inp for inp in args] args = [
return args, argvs inp[0] if len(inp) == 1 else inp for inp in args[:inp_num]
] + args[inp_num:]
return args
def cal_python_api(python_api, args, argvs, kernel_sig): def cal_python_api(python_api, args, kernel_sig):
args, argvs = assumption_assert_and_transform(args, argvs)
inputs_sig, attrs_sig, outputs_sig = kernel_sig inputs_sig, attrs_sig, outputs_sig = kernel_sig
ret_tuple = python_api(*args, **argvs) args = assumption_assert_and_transform(args, len(inputs_sig))
ret_tuple = python_api(*args)
return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
with fluid.dygraph.base.guard(place=place): with fluid.dygraph.base.guard(place=place):
...@@ -764,11 +803,11 @@ class OpTest(unittest.TestCase): ...@@ -764,11 +803,11 @@ class OpTest(unittest.TestCase):
assert hasattr( assert hasattr(
self, "python_api" self, "python_api"
), "Please set the `self.python_api` if you want to compare python api output." ), "Please set the `self.python_api` if you want to compare python api output."
arg, argv = prepare_python_api_arguments(inputs, attrs_outputs, args = prepare_python_api_arguments(self.python_api, inputs,
kernel_sig) attrs_outputs, kernel_sig)
""" we directly return the cal_python_api value because the value is already tensor. """ we directly return the cal_python_api value because the value is already tensor.
""" """
return cal_python_api(self.python_api, arg, argv, kernel_sig) return cal_python_api(self.python_api, args, kernel_sig)
def _calc_dygraph_output(self, place, parallel=False, no_check_set=None): def _calc_dygraph_output(self, place, parallel=False, no_check_set=None):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now self.__class__.op_type = self.op_type # for ci check, please not delete it for now
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册