Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
4d78390e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4d78390e
编写于
8月 25, 2022
作者:
C
chenjian
提交者:
GitHub
8月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix record operator input shapes segment fault in new dygraph (#45360)
* fix segment fault * fix
上级
0d14e74a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
32 addition
and
13 deletion
+32
-13
paddle/phi/api/yaml/generator/api_base.py
paddle/phi/api/yaml/generator/api_base.py
+32
-13
未找到文件。
paddle/phi/api/yaml/generator/api_base.py
浏览文件 @
4d78390e
...
...
@@ -691,24 +691,44 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes;"""
else
:
for
input_name
in
single_tensor_names
:
if
input_name
in
self
.
optional_vars
:
input_tensors
=
input_name_tensor_map
[
input_name
]
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
std::vector<phi::DDim>
{
input_name
}
_record_shapes;"""
for
input_tensor
,
_
in
input_tensors
:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
if(
{
input_tensor
}
){{
{
code_indent
}
{
input_name
}
_record_shapes.push_back((*
{
input_tensor
}
).dims());
{
code_indent
}
}}"""
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
std::vector<std::pair<const char*, std::vector<phi::DDim>>> input_shapes{{"""
for
input_name
in
single_tensor_names
[:
-
1
]:
i
nput_tensors
=
input_name_tensor_map
[
input_name
]
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
input_name
}
", {
{
"""
for
input_tensor
,
_
in
input_tensors
[:
-
1
]
:
i
f
input_name
in
self
.
optional_vars
:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
input_name
}
",
{
input_name
}
_record_shapes}},
"""
else
:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
input_name
}
", {{"""
input_tensors
=
input_name_tensor_map
[
input_name
]
for
input_tensor
,
_
in
input_tensors
[:
-
1
]:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
(*
{
input_tensor
}
).dims(),"""
input_tensor_code
=
input_tensor_code
+
f
"""
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
(*
{
input_tensors
[
-
1
][
0
]
}
).dims()}}}},"""
input_tensors
=
input_name_tensor_map
[
single_tensor_names
[
-
1
]]
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
single_tensor_names
[
-
1
]
}
", {{"""
for
input_tensor
,
_
in
input_tensors
[:
-
1
]:
if
single_tensor_names
[
-
1
]
in
self
.
optional_vars
:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
single_tensor_names
[
-
1
]
}
",
{
code_indent
}
{
single_tensor_names
[
-
1
]
}
_record_shapes}}}};"""
else
:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
{{"
{
single_tensor_names
[
-
1
]
}
", {{"""
input_tensors
=
input_name_tensor_map
[
single_tensor_names
[
-
1
]]
for
input_tensor
,
_
in
input_tensors
[:
-
1
]:
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
(*
{
input_tensor
}
).dims(),"""
input_tensor_code
=
input_tensor_code
+
f
"""
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
(*
{
input_tensors
[
-
1
][
0
]
}
).dims()}}}}}};"""
if
list_tensor_names
:
input_tensor_code
=
input_tensor_code
+
f
"""
...
...
@@ -743,8 +763,8 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
input_shapes.emplace_back("
{
input_name
}
", ddims_vec);"""
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
platform::RecordOpInfoSupplement("
{
self
.
api
}
", input_shapes);
input_tensor_code
=
input_tensor_code
+
f
"""
{
code_indent
}
platform::RecordOpInfoSupplement("
{
self
.
api
}
", input_shapes);
{
code_indent
}
}}"""
kernel_args
=
[
"*dev_ctx"
]
for
param
in
kernel_param
:
...
...
@@ -857,7 +877,6 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{
code_indent
}
if (kernel_result.has_fallback_cpu) {{
{
fallback_kernel_output_trans
}
{
code_indent
}
}}
{
code_indent
}
{
self
.
gene_return_code
()
}
"""
def
get_condition_code
(
self
,
kernel_name
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录