未验证 提交 b83138d0 编写于 作者: L levi131 提交者: GitHub

add blacklist in prim2orig interface (#44383)

上级 02e9453f
......@@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
'mul_p',
'mul_p'
]
self.prim2orig_ops_with_blacklist = [
'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant',
'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul',
'elementwise_mul'
]
self.prim2orig_ops = [
'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'elementwise_sub',
......@@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase):
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_bar[k].shape, v)
# Test prim2orig with blacklist
prim2orig(block=self.main_program.block(0),
blacklist=['add_p', 'sub_p'])
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
self.assertEqual(sorted(prim2orig_ops),
sorted(self.prim2orig_ops_with_blacklist))
# Test prim2orig
prim2orig(block=self.main_program.block(0))
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
......@@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd):
'reshape_p',
]
self.prim2orig_ops_with_blacklist = [
'reshape2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'add_p',
'matmul_v2',
'fill_constant',
'fill_constant',
'fill_constant',
'elementwise_mul',
'transpose2',
'matmul_v2',
'transpose2',
'matmul_v2',
# 'elementwise_mul',
'reshape2',
]
self.prim2orig_ops = [
'reshape2',
'fill_constant',
......@@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd):
'add_p',
]
self.prim2orig_ops_with_blacklist = [
'expand_v2', 'add_p', 'reshape2', 'elementwise_mul', 'reduce_sum',
'sqrt', 'expand_v2', 'sub_p', 'concat', 'gather', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'elementwise_mul', 'reduce_sum', 'reshape2',
'reshape2', 'elementwise_mul', 'elementwise_mul', 'reshape2',
'expand_v2', 'elementwise_div', 'reduce_sum', 'reshape2',
'fill_constant', 'sub_p', 'split', 'fill_constant', 'fill_any_like',
'add_p', 'scatter', 'elementwise_add', 'add_p'
]
self.prim2orig_ops = [
'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat',
......
......@@ -408,7 +408,7 @@ class Transform(object):
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def _lower(block, reverse):
def _lower(block, reverse, blacklist):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
for i in range(len(args)):
......@@ -452,7 +452,7 @@ def _lower(block, reverse):
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None:
if lookup_fn(op.type) is not None and op.type not in blacklist:
input_args = get_input_var_list(op)
bind(input_args, to_bind, value_table)
......@@ -535,11 +535,11 @@ def orig2prim(block=None):
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=False)
_lower(block, reverse=False, blacklist=[])
@framework.static_only
def prim2orig(block=None):
def prim2orig(block=None, blacklist=None):
"""
.. note::
**ONLY available in the static mode.**
......@@ -554,7 +554,11 @@ def prim2orig(block=None):
block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
blacklist(list[string]|None, optional): The names of automatic
differential basic operator that will not be transformed
into original operators. Default None, and the blacklist
is treated as empty list.
Examples:
.. code-block:: python
......@@ -576,4 +580,5 @@ def prim2orig(block=None):
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=True)
blacklist = [] if blacklist is None else blacklist
_lower(block, reverse=True, blacklist=blacklist)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册