Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d83fe716
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d83fe716
编写于
2月 12, 2023
作者:
C
cyber-pioneer
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add input map check
上级
3e85dbb6
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
11 addition
and
3 deletion
+11
-3
python/paddle/incubate/autograd/primx.py
python/paddle/incubate/autograd/primx.py
+1
-1
python/paddle/incubate/autograd/utils.py
python/paddle/incubate/autograd/utils.py
+10
-2
未找到文件。
python/paddle/incubate/autograd/primx.py
浏览文件 @
d83fe716
...
@@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]):
...
@@ -684,7 +684,7 @@ def _lower_composite(block, blacklist=[]):
del
block
.
vars
[
var_name
]
del
block
.
vars
[
var_name
]
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
# composite ops may contain other ops, thus, call _lower_composite again.
# composite ops may contain other
composite
ops, thus, call _lower_composite again.
if
change
:
if
change
:
_lower_composite
(
block
,
blacklist
)
_lower_composite
(
block
,
blacklist
)
return
return
...
...
python/paddle/incubate/autograd/utils.py
浏览文件 @
d83fe716
...
@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
...
@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
arg_type
,
arg_name
=
_solve_arg
(
item
)
arg_type
,
arg_name
=
_solve_arg
(
item
)
op_content
=
op_map
[
op
.
type
]
op_content
=
op_map
[
op
.
type
]
if
arg_type
in
(
"Tensor"
,
"Tensor[]"
):
if
arg_type
in
(
"Tensor"
,
"Tensor[]"
):
# assume Tensor type must belong to inputs
if
(
if
(
"inputs"
in
op_content
.
keys
()
"inputs"
in
op_content
.
keys
()
and
arg_name
in
op_content
[
"inputs"
].
keys
()
and
arg_name
in
op_content
[
"inputs"
].
keys
()
...
@@ -182,7 +183,9 @@ def _get_args_values(op, phi_name):
...
@@ -182,7 +183,9 @@ def _get_args_values(op, phi_name):
"attrs"
in
op_content
.
keys
()
"attrs"
in
op_content
.
keys
()
and
arg_name
in
op_content
[
"attrs"
].
keys
()
and
arg_name
in
op_content
[
"attrs"
].
keys
()
):
):
attrs
.
append
(
op
.
attr
(
op_content
[
"attrs"
][
arg_name
]))
arg_name
=
op_content
[
"attrs"
][
arg_name
]
if
arg_name
not
in
op
.
attr_names
:
attrs
.
append
(
None
)
else
:
else
:
attrs
.
append
(
op
.
attr
(
arg_name
))
attrs
.
append
(
op
.
attr
(
arg_name
))
...
@@ -203,7 +206,12 @@ def prepare_python_api_arguments(op):
...
@@ -203,7 +206,12 @@ def prepare_python_api_arguments(op):
else
:
else
:
phi_name
=
op
.
type
phi_name
=
op
.
type
inputs
,
attrs
=
_get_args_values
(
op
,
phi_name
)
inputs
,
attrs
=
_get_args_values
(
op
,
phi_name
)
res
=
[
get_var_block
(
op
.
block
,
op
.
input
(
n
))
for
n
in
inputs
]
res
=
[]
for
item
in
inputs
:
if
item
in
op
.
input_names
:
res
.
append
(
get_var_block
(
op
.
block
,
op
.
input
(
item
)))
else
:
res
.
append
(
None
)
if
attrs
:
if
attrs
:
res
.
extend
(
attrs
)
res
.
extend
(
attrs
)
return
res
return
res
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录