未验证 提交 a8df3901 编写于 作者: X xiongkun 提交者: GitHub

Polish optest: refine the optest parameter logic. support name, dtype, out,...

Polish optest: refine the optest parameter logic. support name, dtype, out, output in arbitrary position (#40824)

* 1. add the python api grad 2. add final and intermediate state vlog 3. change the python_api error logic

* add python api or close the check_eager=True

* fix the compatibility

* matmul

* disable unittests: test_elementwise_add_op test_scatter_nd_op test_gather_nd_op test_scatter_op test_index_sample_op test_elementwise_add_mkldnn_op

* refine the logic of prepara_parameter logic

* fix Tensor(gpu) 2 Scalar segment fault.
上级 4ccd5cb8
......@@ -713,44 +713,76 @@ class OpTest(unittest.TestCase):
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, all_params_number, defaults):
related_idx = idx - all_params_number + len(defaults)
assert related_idx >= 0, "%d-th arguments don't have default value" % idx
return defaults[related_idx]
def filter_by_name(x):
names = set(['name', 'out', 'output'])
if isinstance(x, list): return [i for i in x if i not in names]
if isinstance(x, dict):
return {k: v for k, v in x.items() if k not in names}
assert False, "Only support list or dict."
def get_default(idx, defaults):
assert not isinstance(
defaults[idx], Empty
), "%d-th params of python api don't have default value." % idx
return defaults[idx]
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
# NOTE(xiongkun): why don't use input arguments dicts ?
# Because we don't know the python api name of each arguments.
# using parse_arg_and_kwargs, we can get the all api information we need.
api_params, api_defaults = [
filter_by_name(item) for item in parse_arg_and_kwargs(api)
]
def parse_attri_value(name, op_inputs, op_attrs):
""" parse true value from inputs and attrs, if there is no name passed by OpTest, return Empty
1. if the name in op_attrs, use the op_attrs[name]
2. if the name in op_inputs, convert the op_inputs to [type of default value]
3. if the name not in op_attrs ans op_inputs, return Empty. (this will use the default value from python api)
"""
if name in op_proto_attrs:
return op_proto_attrs[name]
elif name in op_inputs:
assert op_inputs[name].__len__(
) == 1, "currently don't support multi-input in attribute."
# why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
# why we reconstruct a tensor: because we want the tensor in cpu.
return paddle.to_tensor(
op_inputs[name][0].numpy(), place='cpu')
else:
return Empty()
# NOTE(xiongkun): the logic of constructing parameters:
# for example:
# python api: cumprod(x, dim, dtype=None, name=None)
# kernel sig: [["x"], ["dim"], ["out"]]"
#
# we will construct a lot of list with the same length : len == len(api_params), here is 4
# api_params = ["x", "dim", "dtype", "name"]
# api_defaults = [Empty, Empty, None, None]; empty means no defaults.
# inputs_and_attrs = ["x", "dim"] , the length may shorter or longer than api_params
# input_arguments = [RealValue in self.inputs and self.attrs]
# then ,we will loop for the api_params, construct a result list:
# if the name in ['name', 'dtype', 'out', 'output'], we will use the default value
# else, we will consume a input_arguments. (because the name is not corresponding, so we only use the order)
api_params, api_defaults = parse_arg_and_kwargs(api)
api_defaults = to_defaults_list(api_params, api_defaults)
api_defaults = [
Empty() for i in range(len(api_params) - len(api_defaults))
] + api_defaults
assert len(api_defaults) == len(
api_params), "Error happens. contack xiongkun03 to solve."
inputs_sig, attrs_sig, outputs_sig = kernel_sig
inputs_and_attrs = inputs_sig + attrs_sig
assert (
len(api_params) == len(inputs_and_attrs)
), "inputs and attrs length must equals to python api length. (May be output is in argument list?)"
input_arguments = [op_proto_ins[name] for name in inputs_sig] + [
op_proto_attrs[name] if name in op_proto_attrs else Empty()
parse_attri_value(name, op_proto_ins, op_proto_attrs)
for name in attrs_sig
]
results = []
for idx, arg in enumerate(input_arguments):
if is_empty(arg):
results.append(
get_default(idx, len(input_arguments), api_defaults))
api_ignore_param_list = set(['name', 'dtype', 'out', 'output'])
idx_of_op_proto_arguments = 0
for idx, arg_name in enumerate(api_params):
if arg_name in api_ignore_param_list:
results.append(get_default(idx, api_defaults))
else:
results.append(arg)
assert idx_of_op_proto_arguments < len(
input_arguments), "Assert False."
tmp = input_arguments[idx_of_op_proto_arguments]
idx_of_op_proto_arguments += 1
if isinstance(tmp, Empty):
results.append(get_default(idx, api_defaults))
else:
results.append(tmp)
assert len(results) == len(api_params)
return results
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册