diff --git a/python/paddle/fluid/tests/unittests/autograd/test_transform.py b/python/paddle/fluid/tests/unittests/autograd/test_transform.py index 08626593e290488f5750e47c43fe4e801005479e..f976ef729cc7a00f403e1c0f5360e9b92dd974d8 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_transform.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_transform.py @@ -88,6 +88,12 @@ class TestAutoGradTransformForAdd(unittest.TestCase): 'mul_p', 'mul_p' ] + self.prim2orig_ops_with_blacklist = [ + 'tanh', 'tanh', 'add_p', 'fill_constant', 'fill_constant', + 'fill_constant', 'elementwise_mul', 'sub_p', 'fill_constant', + 'elementwise_mul', 'sub_p', 'fill_constant', 'elementwise_mul', + 'elementwise_mul' + ] self.prim2orig_ops = [ 'tanh', 'tanh', 'elementwise_add', 'fill_constant', 'fill_constant', 'fill_constant', 'elementwise_mul', 'elementwise_sub', @@ -132,6 +138,13 @@ class TestAutoGradTransformForAdd(unittest.TestCase): for k, v in self.ys_shape_map.items(): self.assertEqual(flatten_ys_bar[k].shape, v) + # Test prim2orig with blacklist + prim2orig(block=self.main_program.block(0), + blacklist=['add_p', 'sub_p']) + prim2orig_ops = [op.type for op in self.main_program.block(0).ops] + self.assertEqual(sorted(prim2orig_ops), + sorted(self.prim2orig_ops_with_blacklist)) + # Test prim2orig prim2orig(block=self.main_program.block(0)) prim2orig_ops = [op.type for op in self.main_program.block(0).ops] @@ -198,6 +211,26 @@ class TestAutoGradTransformForMatmul(TestAutoGradTransformForAdd): 'reshape_p', ] + self.prim2orig_ops_with_blacklist = [ + 'reshape2', + 'fill_constant', + 'fill_constant', + 'fill_constant', + 'elementwise_mul', + 'add_p', + 'matmul_v2', + 'fill_constant', + 'fill_constant', + 'fill_constant', + 'elementwise_mul', + 'transpose2', + 'matmul_v2', + 'transpose2', + 'matmul_v2', + # 'elementwise_mul', + 'reshape2', + ] + self.prim2orig_ops = [ 'reshape2', 'fill_constant', @@ -312,6 +345,17 @@ class TestAutoGradTransformForIndexSelect(TestAutoGradTransformForAdd): 'add_p', ] + self.prim2orig_ops_with_blacklist = [ + 'expand_v2', 'add_p', 'reshape2', 'elementwise_mul', 'reduce_sum', + 'sqrt', 'expand_v2', 'sub_p', 'concat', 'gather', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'elementwise_mul', 'reduce_sum', 'reshape2', + 'reshape2', 'elementwise_mul', 'elementwise_mul', 'reshape2', + 'expand_v2', 'elementwise_div', 'reduce_sum', 'reshape2', + 'fill_constant', 'sub_p', 'split', 'fill_constant', 'fill_any_like', + 'add_p', 'scatter', 'elementwise_add', 'add_p' + ] + self.prim2orig_ops = [ 'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul', 'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat', diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 260a97cdc16a43a1b2f230ad00a8b198e80da9db..19f87dd9292154ca17dcca8364e76544107b2d5d 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -408,7 +408,7 @@ class Transform(object): # TODO(lml): supporting control flow, nested blocks, and block other than current block of main program. -def _lower(block, reverse): +def _lower(block, reverse, blacklist): # Some functions which are only used in _lower. def bind(args, to_bind, value_table): for i in range(len(args)): @@ -452,7 +452,7 @@ def _lower(block, reverse): for op_idx in range(len(block.ops)): op = block.ops[op_idx] ops_to_remove.append(op_idx) - if lookup_fn(op.type) is not None: + if lookup_fn(op.type) is not None and op.type not in blacklist: input_args = get_input_var_list(op) bind(input_args, to_bind, value_table) @@ -535,11 +535,11 @@ def orig2prim(block=None): block = default_main_program().current_block() if block is None else block assert block == default_main_program().current_block( ), f'block is neither None nor current block of main program' - _lower(block, reverse=False) + _lower(block, reverse=False, blacklist=[]) @framework.static_only -def prim2orig(block=None): +def prim2orig(block=None, blacklist=None): """ .. note:: **ONLY available in the static mode.** @@ -554,7 +554,11 @@ def prim2orig(block=None): block(paddle.static.Block|None, optional): The target block to process on. Default None, and will process on the current block of main program. - + blacklist(list[string]|None, optional): The names of automatic + differential basic operator that will not be transformed + into original operators. Default None, and the blacklist + is treated as empty list. + Examples: .. code-block:: python @@ -576,4 +580,5 @@ def prim2orig(block=None): block = default_main_program().current_block() if block is None else block assert block == default_main_program().current_block( ), f'block is neither None nor current block of main program' - _lower(block, reverse=True) + blacklist = [] if blacklist is None else blacklist + _lower(block, reverse=True, blacklist=blacklist)