未验证 提交 4d78390e 编写于 作者: C chenjian 提交者: GitHub

Fix record operator input shapes segment fault in new dygraph (#45360)

* fix segment fault

* fix
上级 0d14e74a
...@@ -691,24 +691,44 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -691,24 +691,44 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;""" {code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
else: else:
for input_name in single_tensor_names:
if input_name in self.optional_vars:
input_tensors = input_name_tensor_map[input_name]
input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<phi::DDim> {input_name}_record_shapes;"""
for input_tensor, _ in input_tensors:
input_tensor_code = input_tensor_code + f"""
{code_indent} if({input_tensor}){{
{code_indent} {input_name}_record_shapes.push_back((*{input_tensor}).dims());
{code_indent} }}"""
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{""" {code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
for input_name in single_tensor_names[:-1]: for input_name in single_tensor_names[:-1]:
input_tensors = input_name_tensor_map[input_name] if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{input_name}", {{""" {code_indent} {{"{input_name}", {input_name}_record_shapes}},"""
for input_tensor, _ in input_tensors[:-1]: else:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{input_name}", {{"""
input_tensors = input_name_tensor_map[input_name]
for input_tensor, _ in input_tensors[:-1]:
input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensor}).dims(),""" {code_indent} (*{input_tensor}).dims(),"""
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensors[-1][0]}).dims()}}}},""" {code_indent} (*{input_tensors[-1][0]}).dims()}}}},"""
input_tensors = input_name_tensor_map[single_tensor_names[-1]] if single_tensor_names[-1] in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{single_tensor_names[-1]}", {{"""
for input_tensor, _ in input_tensors[:-1]:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{single_tensor_names[-1]}",
{code_indent} {single_tensor_names[-1]}_record_shapes}}}};"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{single_tensor_names[-1]}", {{"""
input_tensors = input_name_tensor_map[single_tensor_names[-1]]
for input_tensor, _ in input_tensors[:-1]:
input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensor}).dims(),""" {code_indent} (*{input_tensor}).dims(),"""
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensors[-1][0]}).dims()}}}}}};""" {code_indent} (*{input_tensors[-1][0]}).dims()}}}}}};"""
if list_tensor_names: if list_tensor_names:
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
...@@ -743,8 +763,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -743,8 +763,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);""" {code_indent} input_shapes.emplace_back("{input_name}", ddims_vec);"""
input_tensor_code = input_tensor_code + f""" input_tensor_code = input_tensor_code + f"""
{code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes); {code_indent} platform::RecordOpInfoSupplement("{self.api}", input_shapes);
{code_indent} }}""" {code_indent} }}"""
kernel_args = ["*dev_ctx"] kernel_args = ["*dev_ctx"]
for param in kernel_param: for param in kernel_param:
...@@ -857,7 +877,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ...@@ -857,7 +877,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} if (kernel_result.has_fallback_cpu) {{ {code_indent} if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans} {fallback_kernel_output_trans}
{code_indent} }} {code_indent} }}
{code_indent} {self.gene_return_code()}""" {code_indent} {self.gene_return_code()}"""
def get_condition_code(self, kernel_name): def get_condition_code(self, kernel_name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册