diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 2933742c7a55d0b0990c60febdce11a53c588216..f76bc688ec25e7516ea4503a027f29296ce8ebec 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -691,24 +691,44 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d input_tensor_code = input_tensor_code + f""" {code_indent} std::vector>> input_shapes;""" else: + for input_name in single_tensor_names: + if input_name in self.optional_vars: + input_tensors = input_name_tensor_map[input_name] + input_tensor_code = input_tensor_code + f""" +{code_indent} std::vector {input_name}_record_shapes;""" + for input_tensor, _ in input_tensors: + input_tensor_code = input_tensor_code + f""" +{code_indent} if({input_tensor}){{ +{code_indent} {input_name}_record_shapes.push_back((*{input_tensor}).dims()); +{code_indent} }}""" + input_tensor_code = input_tensor_code + f""" {code_indent} std::vector>> input_shapes{{""" for input_name in single_tensor_names[:-1]: - input_tensors = input_name_tensor_map[input_name] - input_tensor_code = input_tensor_code + f""" -{code_indent} {{"{input_name}", {{""" - for input_tensor, _ in input_tensors[:-1]: + if input_name in self.optional_vars: + input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{input_name}", {input_name}_record_shapes}},""" + else: input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{input_name}", {{""" + input_tensors = input_name_tensor_map[input_name] + for input_tensor, _ in input_tensors[:-1]: + input_tensor_code = input_tensor_code + f""" {code_indent} (*{input_tensor}).dims(),""" - input_tensor_code = input_tensor_code + f""" + input_tensor_code = input_tensor_code + f""" {code_indent} (*{input_tensors[-1][0]}).dims()}}}},""" - input_tensors = input_name_tensor_map[single_tensor_names[-1]] - input_tensor_code = input_tensor_code + f""" -{code_indent} {{"{single_tensor_names[-1]}", {{""" - for input_tensor, _ in input_tensors[:-1]: + if single_tensor_names[-1] in self.optional_vars: input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{single_tensor_names[-1]}", +{code_indent} {single_tensor_names[-1]}_record_shapes}}}};""" + else: + input_tensor_code = input_tensor_code + f""" +{code_indent} {{"{single_tensor_names[-1]}", {{""" + input_tensors = input_name_tensor_map[single_tensor_names[-1]] + for input_tensor, _ in input_tensors[:-1]: + input_tensor_code = input_tensor_code + f""" {code_indent} (*{input_tensor}).dims(),""" - input_tensor_code = input_tensor_code + f""" + input_tensor_code = input_tensor_code + f""" {code_indent} (*{input_tensors[-1][0]}).dims()}}}}}};""" if list_tensor_names: input_tensor_code = input_tensor_code + f""" @@ -743,8 +763,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d input_tensor_code = input_tensor_code + f""" {code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);""" - input_tensor_code = input_tensor_code + f""" -{code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes); + input_tensor_code = input_tensor_code + f""" +{code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes); {code_indent} }}""" kernel_args = ["*dev_ctx"] for param in kernel_param: @@ -857,7 +877,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} if (kernel_result.has_fallback_cpu) {{ {fallback_kernel_output_trans} {code_indent} }} - {code_indent} {self.gene_return_code()}""" def get_condition_code(self, kernel_name):