未验证 提交 b83138d0 编写于 作者: L levi131 提交者: GitHub

add blacklist in prim2orig interface (#44383)

上级 02e9453f
...@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase): ...@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p', 'mul_p',
'mul_p' 'mul_p'
] ]
self.prim2orig_ops_with_blacklist = [
'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant',
'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul',
'elementwise_mul'
]
self.prim2orig_ops = [ self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant', 'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub', 'fill_constant', 'elementwise_mul', 'elementwise_sub',
...@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase): ...@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
for k, v in self.ys_shape_map.items(): for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_bar[k].shape, v) self.assertEqual(flatten_ys_bar[k].shape, v)
# Test prim2orig with blacklist
prim2orig(block=self.main_program.block(0),
blacklist=['add_p', 'sub_p'])
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(prim2orig_ops),
sorted(self.prim2orig_ops_with_blacklist))
# Test prim2orig # Test prim2orig
prim2orig(block=self.main_program.block(0)) prim2orig(block=self.main_program.block(0))
prim2orig_ops = [op.type for op in self.main_program.block(0).ops] prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
...@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd): ...@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
'reshape_p', 'reshape_p',
] ]
self.prim2orig_ops_with_blacklist = [
'reshape2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'add_p',
'matmul_v2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'transpose2',
'matmul_v2',
'transpose2',
'matmul_v2',
# 'elementwise_mul',
'reshape2',
]
self.prim2orig_ops = [ self.prim2orig_ops = [
'reshape2', 'reshape2',
'fill_constant', 'fill_constant',
...@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): ...@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'add_p', 'add_p',
] ]
self.prim2orig_ops_with_blacklist = [
'expand_v2', 'add_p', 'reshape2', 'elementwise_mul', 'reduce_sum',
'sqrt', 'expand_v2', 'sub_p', 'concat', 'gather', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'reduce_sum', 'reshape2',
'reshape2', 'elementwise_mul', 'elementwise_mul', 'reshape2',
'expand_v2', 'elementwise_div', 'reduce_sum', 'reshape2',
'fill_constant', 'sub_p', 'split', 'fill_constant', 'fill_any_like',
'add_p', 'scatter', 'elementwise_add', 'add_p'
]
self.prim2orig_ops = [ self.prim2orig_ops = [
'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul', 'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat', 'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat',
......
...@@ -408,7 +408,7 @@ class Transform(object): ...@@ -408,7 +408,7 @@ class Transform(object):
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program. # TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def _lower(block, reverse): def _lower(block, reverse, blacklist):
# Some functions which are only used in _lower. # Some functions which are only used in _lower.
def bind(args, to_bind, value_table): def bind(args, to_bind, value_table):
for i in range(len(args)): for i in range(len(args)):
...@@ -452,7 +452,7 @@ def _lower(block, reverse): ...@@ -452,7 +452,7 @@ def _lower(block, reverse):
for op_idx in range(len(block.ops)): for op_idx in range(len(block.ops)):
op = block.ops[op_idx] op = block.ops[op_idx]
ops_to_remove.append(op_idx) ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None: if lookup_fn(op.type) is not None and op.type not in blacklist:
input_args = get_input_var_list(op) input_args = get_input_var_list(op)
bind(input_args, to_bind, value_table) bind(input_args, to_bind, value_table)
...@@ -535,11 +535,11 @@ def orig2prim(block=None): ...@@ -535,11 +535,11 @@ def orig2prim(block=None):
block = default_main_program().current_block() if block is None else block block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block( assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program' ), f'block is neither None nor current block of main program'
_lower(block, reverse=False) _lower(block, reverse=False, blacklist=[])
@framework.static_only @framework.static_only
def prim2orig(block=None): def prim2orig(block=None, blacklist=None):
""" """
.. note:: .. note::
**ONLY available in the static mode.** **ONLY available in the static mode.**
...@@ -554,7 +554,11 @@ def prim2orig(block=None): ...@@ -554,7 +554,11 @@ def prim2orig(block=None):
block(paddle.static.Block|None, optional): The block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will target block to process on. Default None, and will
process on the current block of main program. process on the current block of main program.
blacklist(list[string]|None, optional): The names of automatic
differential basic operator that will not be transformed
into original operators. Default None, and the blacklist
is treated as empty list.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -576,4 +580,5 @@ def prim2orig(block=None): ...@@ -576,4 +580,5 @@ def prim2orig(block=None):
block = default_main_program().current_block() if block is None else block block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block( assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program' ), f'block is neither None nor current block of main program'
_lower(block, reverse=True) blacklist = [] if blacklist is None else blacklist
_lower(block, reverse=True, blacklist=blacklist)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册