未验证 提交 4d78390e 编写于 作者: C chenjian 提交者: GitHub

Fix record operator input shapes segment fault in new dygraph (#45360)

* fix segment fault

* fix
上级 0d14e74a
......@@ -691,20 +691,40 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
else:
for input_name in single_tensor_names:
if input_name in self.optional_vars:
input_tensors = input_name_tensor_map[input_name]
input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<phi::DDim> {input_name}_record_shapes;"""
for input_tensor, _ in input_tensors:
input_tensor_code = input_tensor_code + f"""
{code_indent} if({input_tensor}){{
{code_indent} {input_name}_record_shapes.push_back((*{input_tensor}).dims());
{code_indent} }}"""
input_tensor_code = input_tensor_code + f"""
{code_indent} std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
for input_name in single_tensor_names[:-1]:
input_tensors = input_name_tensor_map[input_name]
if input_name in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{input_name}", {input_name}_record_shapes}},"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{input_name}", {{"""
input_tensors = input_name_tensor_map[input_name]
for input_tensor, _ in input_tensors[:-1]:
input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensor}).dims(),"""
input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensors[-1][0]}).dims()}}}},"""
input_tensors = input_name_tensor_map[single_tensor_names[-1]]
if single_tensor_names[-1] in self.optional_vars:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{single_tensor_names[-1]}",
{code_indent} {single_tensor_names[-1]}_record_shapes}}}};"""
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} {{"{single_tensor_names[-1]}", {{"""
input_tensors = input_name_tensor_map[single_tensor_names[-1]]
for input_tensor, _ in input_tensors[:-1]:
input_tensor_code = input_tensor_code + f"""
{code_indent} (*{input_tensor}).dims(),"""
......@@ -857,7 +877,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} if (kernel_result.has_fallback_cpu) {{
{fallback_kernel_output_trans}
{code_indent} }}
{code_indent} {self.gene_return_code()}"""
def get_condition_code(self, kernel_name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册