Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b83138d0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b83138d0
编写于
7月 18, 2022
作者:
L
levi131
提交者:
GitHub
7月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add blacklist in prim2orig interface (#44383)
上级
02e9453f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
55 addition
and
6 deletion
+55
-6
python/paddle/fluid/tests/unittests/autograd/test_transform.py
...n/paddle/fluid/tests/unittests/autograd/test_transform.py
+44
-0
python/paddle/incubate/autograd/primx.py
python/paddle/incubate/autograd/primx.py
+11
-6
未找到文件。
python/paddle/fluid/tests/unittests/autograd/test_transform.py
浏览文件 @
b83138d0
...
@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
...
@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p'
,
'mul_p'
,
'mul_p'
'mul_p'
]
]
self
.
prim2orig_ops_with_blacklist
=
[
'tanh'
,
'tanh'
,
'add_p'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'sub_p'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_mul'
]
self
.
prim2orig_ops
=
[
self
.
prim2orig_ops
=
[
'tanh'
,
'tanh'
,
'elementwise_add'
,
'fill_constant'
,
'fill_constant'
,
'tanh'
,
'tanh'
,
'elementwise_add'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_sub'
,
'fill_constant'
,
'elementwise_mul'
,
'elementwise_sub'
,
...
@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
...
@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
for
k
,
v
in
self
.
ys_shape_map
.
items
():
for
k
,
v
in
self
.
ys_shape_map
.
items
():
self
.
assertEqual
(
flatten_ys_bar
[
k
].
shape
,
v
)
self
.
assertEqual
(
flatten_ys_bar
[
k
].
shape
,
v
)
# Test prim2orig with blacklist
prim2orig
(
block
=
self
.
main_program
.
block
(
0
),
blacklist
=
[
'add_p'
,
'sub_p'
])
prim2orig_ops
=
[
op
.
type
for
op
in
self
.
main_program
.
block
(
0
).
ops
]
self
.
assertEqual
(
sorted
(
prim2orig_ops
),
sorted
(
self
.
prim2orig_ops_with_blacklist
))
# Test prim2orig
# Test prim2orig
prim2orig
(
block
=
self
.
main_program
.
block
(
0
))
prim2orig
(
block
=
self
.
main_program
.
block
(
0
))
prim2orig_ops
=
[
op
.
type
for
op
in
self
.
main_program
.
block
(
0
).
ops
]
prim2orig_ops
=
[
op
.
type
for
op
in
self
.
main_program
.
block
(
0
).
ops
]
...
@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
...
@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
'reshape_p'
,
'reshape_p'
,
]
]
self
.
prim2orig_ops_with_blacklist
=
[
'reshape2'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'add_p'
,
'matmul_v2'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'transpose2'
,
'matmul_v2'
,
'transpose2'
,
'matmul_v2'
,
# 'elementwise_mul',
'reshape2'
,
]
self
.
prim2orig_ops
=
[
self
.
prim2orig_ops
=
[
'reshape2'
,
'reshape2'
,
'fill_constant'
,
'fill_constant'
,
...
@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
...
@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'add_p'
,
'add_p'
,
]
]
self
.
prim2orig_ops_with_blacklist
=
[
'expand_v2'
,
'add_p'
,
'reshape2'
,
'elementwise_mul'
,
'reduce_sum'
,
'sqrt'
,
'expand_v2'
,
'sub_p'
,
'concat'
,
'gather'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'elementwise_mul'
,
'reduce_sum'
,
'reshape2'
,
'reshape2'
,
'elementwise_mul'
,
'elementwise_mul'
,
'reshape2'
,
'expand_v2'
,
'elementwise_div'
,
'reduce_sum'
,
'reshape2'
,
'fill_constant'
,
'sub_p'
,
'split'
,
'fill_constant'
,
'fill_any_like'
,
'add_p'
,
'scatter'
,
'elementwise_add'
,
'add_p'
]
self
.
prim2orig_ops
=
[
self
.
prim2orig_ops
=
[
'expand_v2'
,
'elementwise_add'
,
'reshape2'
,
'elementwise_mul'
,
'expand_v2'
,
'elementwise_add'
,
'reshape2'
,
'elementwise_mul'
,
'reduce_sum'
,
'sqrt'
,
'expand_v2'
,
'elementwise_sub'
,
'concat'
,
'reduce_sum'
,
'sqrt'
,
'expand_v2'
,
'elementwise_sub'
,
'concat'
,
...
...
python/paddle/incubate/autograd/primx.py
浏览文件 @
b83138d0
...
@@ -408,7 +408,7 @@ class Transform(object):
...
@@ -408,7 +408,7 @@ class Transform(object):
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def
_lower
(
block
,
reverse
):
def
_lower
(
block
,
reverse
,
blacklist
):
# Some functions which are only used in _lower.
# Some functions which are only used in _lower.
def
bind
(
args
,
to_bind
,
value_table
):
def
bind
(
args
,
to_bind
,
value_table
):
for
i
in
range
(
len
(
args
)):
for
i
in
range
(
len
(
args
)):
...
@@ -452,7 +452,7 @@ def _lower(block, reverse):
...
@@ -452,7 +452,7 @@ def _lower(block, reverse):
for
op_idx
in
range
(
len
(
block
.
ops
)):
for
op_idx
in
range
(
len
(
block
.
ops
)):
op
=
block
.
ops
[
op_idx
]
op
=
block
.
ops
[
op_idx
]
ops_to_remove
.
append
(
op_idx
)
ops_to_remove
.
append
(
op_idx
)
if
lookup_fn
(
op
.
type
)
is
not
None
:
if
lookup_fn
(
op
.
type
)
is
not
None
and
op
.
type
not
in
blacklist
:
input_args
=
get_input_var_list
(
op
)
input_args
=
get_input_var_list
(
op
)
bind
(
input_args
,
to_bind
,
value_table
)
bind
(
input_args
,
to_bind
,
value_table
)
...
@@ -535,11 +535,11 @@ def orig2prim(block=None):
...
@@ -535,11 +535,11 @@ def orig2prim(block=None):
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
assert
block
==
default_main_program
().
current_block
(
assert
block
==
default_main_program
().
current_block
(
),
f
'block is neither None nor current block of main program'
),
f
'block is neither None nor current block of main program'
_lower
(
block
,
reverse
=
False
)
_lower
(
block
,
reverse
=
False
,
blacklist
=
[]
)
@
framework
.
static_only
@
framework
.
static_only
def
prim2orig
(
block
=
None
):
def
prim2orig
(
block
=
None
,
blacklist
=
None
):
"""
"""
.. note::
.. note::
**ONLY available in the static mode.**
**ONLY available in the static mode.**
...
@@ -554,6 +554,10 @@ def prim2orig(block=None):
...
@@ -554,6 +554,10 @@ def prim2orig(block=None):
block(paddle.static.Block|None, optional): The
block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will
target block to process on. Default None, and will
process on the current block of main program.
process on the current block of main program.
blacklist(list[string]|None, optional): The names of automatic
differential basic operator that will not be transformed
into original operators. Default None, and the blacklist
is treated as empty list.
Examples:
Examples:
...
@@ -576,4 +580,5 @@ def prim2orig(block=None):
...
@@ -576,4 +580,5 @@ def prim2orig(block=None):
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
block
=
default_main_program
().
current_block
()
if
block
is
None
else
block
assert
block
==
default_main_program
().
current_block
(
assert
block
==
default_main_program
().
current_block
(
),
f
'block is neither None nor current block of main program'
),
f
'block is neither None nor current block of main program'
_lower
(
block
,
reverse
=
True
)
blacklist
=
[]
if
blacklist
is
None
else
blacklist
_lower
(
block
,
reverse
=
True
,
blacklist
=
blacklist
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录