未验证 提交 2037fa68 编写于 作者: X xiongkun 提交者: GitHub

[optest]: fix transpose, support different parameter name between python_api...

[optest]: fix transpose, support  different parameter name between python_api and KernelSignature. (#40258)

* optest: fix transpose

* fix
上级 0b597754
......@@ -50,6 +50,7 @@ from paddle.fluid.tests.unittests.white_list import (
no_check_set_white_list,
op_threshold_white_list,
no_grad_set_white_list, )
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
......@@ -698,19 +699,55 @@ class OpTest(unittest.TestCase):
self.__class__.__name__)
def _calc_python_api_output(self, place):
def prepare_python_api_arguments(op_proto_ins, op_proto_attrs,
def prepare_python_api_arguments(api, op_proto_ins, op_proto_attrs,
kernel_sig):
""" map from `op proto inputs and attrs` to `api input list and api attrs dict`
"""
class Empty:
pass
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, all_params_number, defaults):
related_idx = idx - all_params_number + len(defaults)
assert related_idx >= 0, "%d-th arguments don't have default value" % idx
return defaults[related_idx]
def remove_name(x):
if isinstance(x, list): return [i for i in x if i != 'name']
if isinstance(x, dict):
return {k: v for k, v in x.items() if k != 'name'}
assert False, "Only support list or dict."
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
# NOTE(xiongkun): why don't use input arguments dicts ?
# Because we don't know the python api name of each arguments.
# using parse_arg_and_kwargs, we can get the all api information we need.
api_params, api_defaults = [
remove_name(item) for item in parse_arg_and_kwargs(api)
]
api_defaults = to_defaults_list(api_params, api_defaults)
inputs_sig, attrs_sig, outputs_sig = kernel_sig
input_arguments = [op_proto_ins[name] for name in inputs_sig]
attr_arguments = {
name: op_proto_attrs[name]
for name in attrs_sig if name in op_proto_attrs
}
return input_arguments, attr_arguments
inputs_and_attrs = inputs_sig + attrs_sig
assert (
len(api_params) == len(inputs_and_attrs)
), "inputs and attrs length must equals to python api length. (May be output is in argument list?)"
input_arguments = [op_proto_ins[name] for name in inputs_sig] + [
op_proto_attrs[name] if name in op_proto_attrs else Empty()
for name in attrs_sig
]
results = []
for idx, arg in enumerate(input_arguments):
if is_empty(arg):
results.append(
get_default(idx, len(input_arguments), api_defaults))
else:
results.append(arg)
return results
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if not isinstance(ret_tuple, (tuple, list)):
......@@ -720,25 +757,27 @@ class OpTest(unittest.TestCase):
len(output_sig), len(ret_tuple))
return {a: b for a, b in zip(output_sig, ret_tuple)}
def assumption_assert_and_transform(args, argvs):
def assumption_assert_and_transform(args, inp_num):
"""
transform by the following rules:
transform inputs by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
only support "X" is list of Tensor, currently don't support other structure like dict.
"""
for inp in args:
for inp in args[:inp_num]:
assert isinstance(
inp, list
), "currently only support `X` is [Tensor], don't support other structure."
args = [inp[0] if len(inp) == 1 else inp for inp in args]
return args, argvs
args = [
inp[0] if len(inp) == 1 else inp for inp in args[:inp_num]
] + args[inp_num:]
return args
def cal_python_api(python_api, args, argvs, kernel_sig):
args, argvs = assumption_assert_and_transform(args, argvs)
def cal_python_api(python_api, args, kernel_sig):
inputs_sig, attrs_sig, outputs_sig = kernel_sig
ret_tuple = python_api(*args, **argvs)
args = assumption_assert_and_transform(args, len(inputs_sig))
ret_tuple = python_api(*args)
return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
with fluid.dygraph.base.guard(place=place):
......@@ -764,11 +803,11 @@ class OpTest(unittest.TestCase):
assert hasattr(
self, "python_api"
), "Please set the `self.python_api` if you want to compare python api output."
arg, argv = prepare_python_api_arguments(inputs, attrs_outputs,
kernel_sig)
args = prepare_python_api_arguments(self.python_api, inputs,
attrs_outputs, kernel_sig)
""" we directly return the cal_python_api value because the value is already tensor.
"""
return cal_python_api(self.python_api, arg, argv, kernel_sig)
return cal_python_api(self.python_api, args, kernel_sig)
def _calc_dygraph_output(self, place, parallel=False, no_check_set=None):
self.__class__.op_type = self.op_type # for ci check, please not delete it for now
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册