Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b83138d0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b83138d0
编写于
7月 18, 2022
作者:
L
levi131
提交者:
GitHub
7月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add blacklist in prim2orig interface (#44383)
上级
02e9453f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
55 addition
and
6 deletion
+55
-6
python/paddle/fluid/tests/unittests/autograd/test_transform.py
...n/paddle/fluid/tests/unittests/autograd/test_transform.py
+44
-0
python/paddle/incubate/autograd/primx.py
python/paddle/incubate/autograd/primx.py
+11
-6
未找到文件。
python/paddle/fluid/tests/unittests/autograd/test_transform.py
浏览文件 @
b83138d0
...
...
@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p'
,
'mul_p'
]
self
.
prim2orig_ops_with_blacklist
=
[
'tanh'
,
'tanh'
,
'add_p'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_mul'
]
self
.
prim2orig_ops
=
[
'tanh'
,
'tanh'
,
'elementwise_add'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_sub'
,
...
...
@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
for
k
,
v
in
self
.
ys_shape_map
.
items
():
self
.
assertEqual
(
flatten_ys_bar
[
k
].
shape
,
v
)
# Test prim2orig with blacklist
prim2orig
(
block
=
self
.
main_program
.
block
(
0
),
blacklist
=
[
'add_p'
,
'sub_p'
])
prim2orig_ops
=
[
op
.
type
for
op
in
self
.
main_program
.
block
(
0
).
ops
]
self
.
assertEqual
(
sorted
(
prim2orig_ops
),
sorted
(
self
.
prim2orig_ops_with_blacklist
))
# Test prim2orig
prim2orig
(
block
=
self
.
main_program
.
block
(
0
))
prim2orig_ops
=
[
op
.
type
for
op
in
self
.
main_program
.
block
(
0
).
ops
]
...
...
@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
'reshape_p'
,
]
self
.
prim2orig_ops_with_blacklist
=
[
'reshape2'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'add_p'
,
'matmul_v2'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'transpose2'
,
'matmul_v2'
,
'transpose2'
,
'matmul_v2'
,
# 'elementwise_mul',
'reshape2'
,
]
self
.
prim2orig_ops
=
[
'reshape2'
,
'fill_constant'
,
...
...
@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'add_p'
,
]
self
.
prim2orig_ops_with_blacklist
=
[
'expand_v2'
,
'add_p'
,
'reshape2'
,
'elementwise_mul'
,
'reduce_sum'
,
'sqrt'
,
'expand_v2'
,
'sub_p'
,
'concat'
,
'gather'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'reduce_sum'
,
'reshape2'
,
'reshape2'
,
'elementwise_mul'
,
'elementwise_mul'
,
'reshape2'
,
'expand_v2'
,
'elementwise_div'
,
'reduce_sum'
,
'reshape2'
,
'fill_constant'
,
'sub_p'
,
'split'
,
'fill_constant'
,
'fill_any_like'
,
'add_p'
,
'scatter'
,
'elementwise_add'
,
'add_p'
]
self
.
prim2orig_ops
=
[
'expand_v2'
,
'elementwise_add'
,
'reshape2'
,
'elementwise_mul'
,
'reduce_sum'
,
'sqrt'
,
'expand_v2'
,
'elementwise_sub'
,
'concat'
,
...
...
python/paddle/incubate/autograd/primx.py
浏览文件 @
b83138d0
...
...
@@ -408,7 +408,7 @@ class Transform(object):
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def
_lower
(
block
,
reverse
):
def
_lower
(
block
,
reverse
,
blacklist
):
# Some functions which are only used in _lower.
def
bind
(
args
,
to_bind
,
value_table
):
for
i
in
range
(
len
(
args
)):
...
...
@@ -452,7 +452,7 @@ def _lower(block, reverse):
for
op_idx
in
range
(
len
(
block
.
ops
)):
op
=
block
.
ops
[
op_idx
]
ops_to_remove
.
append
(
op_idx
)
if
lookup_fn
(
op
.
type
)
is
not
None
:
if
lookup_fn
(
op
.
type
)
is
not
None
and
op
.
type
not
in
blacklist
:
input_args
=
get_input_var_list
(
op
)
bind
(
input_args
,
to_bind
,
value_table
)
...
...
@@ -535,11 +535,11 @@ def orig2prim(block=None):
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
assert
block
==
default_main_program
().
current_block
(
),
f
'block is neither None nor current block of main program'
_lower
(
block
,
reverse
=
False
)
_lower
(
block
,
reverse
=
False
,
blacklist
=
[]
)
@
framework
.
static_only
def
prim2orig
(
block
=
None
):
def
prim2orig
(
block
=
None
,
blacklist
=
None
):
"""
.. note::
**ONLY available in the static mode.**
...
...
@@ -554,6 +554,10 @@ def prim2orig(block=None):
block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
blacklist(list[string]|None, optional): The names of automatic
differential basic operator that will not be transformed
into original operators. Default None, and the blacklist
is treated as empty list.
Examples:
...
...
@@ -576,4 +580,5 @@ def prim2orig(block=None):
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
assert
block
==
default_main_program
().
current_block
(
),
f
'block is neither None nor current block of main program'
_lower
(
block
,
reverse
=
True
)
blacklist
=
[]
if
blacklist
is
None
else
blacklist
_lower
(
block
,
reverse
=
True
,
blacklist
=
blacklist
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录