diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 2d678db4dfcb48ddefb3170ad4285112b1ba8391..0e5202209e494136b3599e714e7cf80a51b1a04b 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -713,44 +713,76 @@ class OpTest(unittest.TestCase): def is_empty(a): return isinstance(a, Empty) - def get_default(idx, all_params_number, defaults): - related_idx = idx - all_params_number + len(defaults) - assert related_idx >= 0, "%d-th arguments don't have default value" % idx - return defaults[related_idx] - - def filter_by_name(x): - names = set(['name', 'out', 'output']) - if isinstance(x, list): return [i for i in x if i not in names] - if isinstance(x, dict): - return {k: v for k, v in x.items() if k not in names} - assert False, "Only support list or dict." + def get_default(idx, defaults): + assert not isinstance( + defaults[idx], Empty + ), "%d-th params of python api don't have default value." % idx + return defaults[idx] def to_defaults_list(params, defaults): return [defaults[p] for p in params if p in defaults] - # NOTE(xiongkun): why don't use input arguments dicts ? - # Because we don't know the python api name of each arguments. - # using parse_arg_and_kwargs, we can get the all api information we need. - api_params, api_defaults = [ - filter_by_name(item) for item in parse_arg_and_kwargs(api) - ] + def parse_attri_value(name, op_inputs, op_attrs): + """ parse true value from inputs and attrs, if there is no name passed by OpTest, return Empty + 1. if the name in op_attrs, use the op_attrs[name] + 2. if the name in op_inputs, convert the op_inputs to [type of default value] + 3. if the name not in op_attrs ans op_inputs, return Empty. (this will use the default value from python api) + """ + if name in op_proto_attrs: + return op_proto_attrs[name] + elif name in op_inputs: + assert op_inputs[name].__len__( + ) == 1, "currently don't support multi-input in attribute." + # why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op] + # why we reconstruct a tensor: because we want the tensor in cpu. + return paddle.to_tensor( + op_inputs[name][0].numpy(), place='cpu') + else: + return Empty() + + # NOTE(xiongkun): the logic of constructing parameters: + # for example: + # python api: cumprod(x, dim, dtype=None, name=None) + # kernel sig: [["x"], ["dim"], ["out"]]" + # + # we will construct a lot of list with the same length : len == len(api_params), here is 4 + # api_params = ["x", "dim", "dtype", "name"] + # api_defaults = [Empty, Empty, None, None]; empty means no defaults. + # inputs_and_attrs = ["x", "dim"] , the length may shorter or longer than api_params + # input_arguments = [RealValue in self.inputs and self.attrs] + # then ,we will loop for the api_params, construct a result list: + # if the name in ['name', 'dtype', 'out', 'output'], we will use the default value + # else, we will consume a input_arguments. (because the name is not corresponding, so we only use the order) + + api_params, api_defaults = parse_arg_and_kwargs(api) api_defaults = to_defaults_list(api_params, api_defaults) + api_defaults = [ + Empty() for i in range(len(api_params) - len(api_defaults)) + ] + api_defaults + assert len(api_defaults) == len( + api_params), "Error happens. contack xiongkun03 to solve." inputs_sig, attrs_sig, outputs_sig = kernel_sig inputs_and_attrs = inputs_sig + attrs_sig - assert ( - len(api_params) == len(inputs_and_attrs) - ), "inputs and attrs length must equals to python api length. (May be output is in argument list?)" input_arguments = [op_proto_ins[name] for name in inputs_sig] + [ - op_proto_attrs[name] if name in op_proto_attrs else Empty() + parse_attri_value(name, op_proto_ins, op_proto_attrs) for name in attrs_sig ] results = [] - for idx, arg in enumerate(input_arguments): - if is_empty(arg): - results.append( - get_default(idx, len(input_arguments), api_defaults)) + api_ignore_param_list = set(['name', 'dtype', 'out', 'output']) + idx_of_op_proto_arguments = 0 + for idx, arg_name in enumerate(api_params): + if arg_name in api_ignore_param_list: + results.append(get_default(idx, api_defaults)) else: - results.append(arg) + assert idx_of_op_proto_arguments < len( + input_arguments), "Assert False." + tmp = input_arguments[idx_of_op_proto_arguments] + idx_of_op_proto_arguments += 1 + if isinstance(tmp, Empty): + results.append(get_default(idx, api_defaults)) + else: + results.append(tmp) + assert len(results) == len(api_params) return results def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):