From 300f36c0700ab4202df2046b0f3537d2e6ecfd69 Mon Sep 17 00:00:00 2001 From: cxxly Date: Sun, 5 Mar 2023 09:04:55 +0000 Subject: [PATCH] [Prim] enable whitelist and blacklist for custom_vjp --- python/paddle/fluid/core.py | 4 + .../tests/unittests/autograd/test_primapi.py | 47 ++++++++- .../dygraph_to_static/test_cinn_prim.py | 2 + .../dygraph_to_static/test_cinn_prim_gelu.py | 2 + .../test_cinn_prim_layer_norm.py | 2 + .../dygraph_to_static/test_cinn_prim_mean.py | 2 + .../test_partial_program_hook.py | 71 ++++++++++++++ .../test_composite_batch_norm.py | 3 +- .../test_composite_batch_norm_grad.py | 3 +- .../composite_ops/test_composite_dropout.py | 3 +- .../prim/composite_ops/test_composite_gelu.py | 3 +- .../composite_ops/test_composite_gelu_grad.py | 5 +- .../test_composite_layer_norm.py | 5 +- .../test_composite_layer_norm_grad.py | 13 +-- .../prim/composite_ops/test_composite_mean.py | 3 +- .../composite_ops/test_composite_mean_grad.py | 5 +- .../composite_ops/test_composite_softmax.py | 3 +- .../test_composite_softmax_grad.py | 5 +- .../prim/prim/flags/test_prim_flags.py | 11 ++- .../unittests/prim/process/test_copy_op.py | 3 +- .../unittests/prim/test_comp_custom_vjp.py | 2 +- .../fluid/tests/unittests/prim_op_test.py | 5 +- python/paddle/incubate/autograd/__init__.py | 3 +- python/paddle/incubate/autograd/primapi.py | 34 +++++-- python/paddle/incubate/autograd/primx.py | 13 ++- .../paddle/jit/dy2static/partial_program.py | 31 +++--- .../jit/dy2static/program_translator.py | 96 +++++++++---------- python/paddle/jit/dy2static/utils.py | 2 +- 28 files changed, 269 insertions(+), 112 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program_hook.py diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index db3a7c29788..1793d06ce2a 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -446,6 +446,10 @@ def __sync_stat_with_flag(flag): ) +def _is_all_prim_enabled(): + return _is_fwd_prim_enabled() and _is_bwd_prim_enabled() + + # Alert!!! This method is only for test coveraget, user should never use it directly, this may cause serious system errors. 
def _test_use_sync(value): __sync_stat_with_flag(value) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index 84bbe7bd1a3..0095ab0233d 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -1046,14 +1046,51 @@ class TestToPrim(unittest.TestCase): paddle.disable_static() @param.parameterized.expand((({'dropout'},),)) - def test_exclude(self, exclude): + def test_blacklist(self, blacklist): program = paddle.static.Program() with paddle.static.program_guard(program): - x = paddle.rand((1,)) - y = paddle.nn.functional.dropout(x) - primapi.to_prim(program.blocks, exclude) + paddle.nn.functional.softmax( + paddle.nn.functional.dropout(paddle.rand((1,))) + ) + primapi.to_prim(program.blocks, blacklist=blacklist) + ops = tuple(op.type for op in program.block(0).ops) + self.assertTrue(all(tuple(op in ops for op in blacklist))) + + @param.parameterized.expand((({'dropout'},),)) + def test_whitelist(self, whitelist): + program = paddle.static.Program() + with paddle.static.program_guard(program): + paddle.nn.functional.softmax( + paddle.nn.functional.dropout(paddle.rand((1,))) + ) + primapi.to_prim(program.blocks, whitelist=whitelist) ops = tuple(op.type for op in program.block(0).ops) - self.assertTrue(all(tuple(op in ops for op in exclude))) + self.assertTrue(all(tuple(op not in ops for op in whitelist))) + + @param.parameterized.expand((({'softmax'}, {'softmax', 'dropout'}),)) + def test_both_not_empty(self, blacklist, whitelist): + program = paddle.static.Program() + with paddle.static.program_guard(program): + paddle.nn.functional.softmax( + paddle.nn.functional.dropout(paddle.rand((1,))) + ) + primapi.to_prim( + program.blocks, blacklist=blacklist, whitelist=whitelist + ) + ops = tuple(op.type for op in program.block(0).ops) + self.assertTrue(all(tuple(op in ops for op in blacklist))) + + @param.parameterized.expand(((('dropout',), 'softmax'),)) + def test_type_error(self, blacklist, whitelist): + program = paddle.static.Program() + with paddle.static.program_guard(program): + paddle.nn.functional.softmax( + paddle.nn.functional.dropout(paddle.rand((1,))) + ) + with self.assertRaises(TypeError): + primapi.to_prim( + program.blocks, blacklist=blacklist, whitelist=whitelist + ) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py index d25fe730308..a86cf18ade1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py @@ -77,6 +77,8 @@ class TestPrimForward(unittest.TestCase): def check_prim(self, net, use_prim): if not use_prim: return + # Please use PartialProgramLayer(second output parameter of get_concrete_program) rather than + # main_program here, as main_program is original program before to_prim. 
fwd_ops = [ op.type for op in net.forward.get_concrete_program(self.x)[1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py index ad68e1195a9..a4492f1bfdf 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py @@ -93,6 +93,8 @@ class TestPrimForwardAndBackward(unittest.TestCase): def check_prim(self, net, use_prim): if not use_prim: return + # Please use PartialProgramLayer(second output parameter of get_concrete_program) rather than + # main_program here, as main_program is original program before to_prim. fwd_ops = [ op.type for op in net.forward.get_concrete_program(self.x)[1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py index 28aac57b2f5..78fea41662e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py @@ -89,6 +89,8 @@ class TestPrimForward(unittest.TestCase): def check_prim(self, net, use_prim): if not use_prim: return + # Please use PartialProgramLayer(second output parameter of get_concrete_program) rather than + # main_program here, as main_program is original program before to_prim. fwd_ops = [ op.type for op in net.forward.get_concrete_program(self.x, self.w, self.b)[ diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py index ff18964f7a3..ff433f439e0 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py @@ -100,6 +100,8 @@ class TestPrimForward(unittest.TestCase): def check_prim(self, net, use_prim): if not use_prim: return + # Please use PartialProgramLayer(second output parameter of get_concrete_program) rather than + # main_program here, as main_program is original program before to_prim. fwd_ops = [ op.type for op in net.forward.get_concrete_program(self.x)[1] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program_hook.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program_hook.py new file mode 100644 index 00000000000..896dde419bf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_partial_program_hook.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle +from paddle.fluid import core +from paddle.jit.dy2static import partial_program, program_translator + + +class TestPartiaProgramLayerHook(unittest.TestCase): + def setUp(self): + self._hook = partial_program.PartialProgramLayerHook() + + def test_before_append_backward(self): + self.assertIsNone(self._hook.before_append_backward(None)) + + def test_after_append_backward(self): + self.assertIsNone(self._hook.after_append_backward(None, 0)) + + def test_after_infer(self): + self.assertIsNone(self._hook.after_infer(None)) + + +class TestPrimHook(unittest.TestCase): + def setUp(self): + core._set_prim_all_enabled(False) + + def f(): + return paddle.nn.functional.dropout(paddle.rand((1,))) + + concrete_program, partial_program = paddle.jit.to_static( + f + ).get_concrete_program() + self._hook = program_translator.PrimHooker( + concrete_program.main_program + ) + self._forward = partial_program.forward_program + self._whole = partial_program._train_program + + core._set_prim_all_enabled(True) + + def tearDown(self): + core._set_prim_all_enabled(False) + + def test_before_append_backward(self): + self._hook.before_append_backward(self._forward) + self.assertNotIn( + 'dropout', tuple(op.type for op in self._forward.blocks[0].ops) + ) + + def test_after_append_backward(self): + self._hook.after_append_backward(self._whole, 0) + self.assertNotIn( + 'dropout_grad', tuple(op.type for op in self._whole.blocks[0].ops) + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py index af183e8793e..2c5bc6f72e2 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm.py @@ -21,6 +21,7 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle.fluid import core, framework +from paddle.incubate.autograd import primapi from paddle.nn import BatchNorm from paddle.tensor import ones # noqa: F401 from paddle.tensor import zeros # noqa: F401 @@ -183,7 +184,7 @@ class TestCompositeBatchNorm(unittest.TestCase): attrs.use_global_stats, ) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) exe = paddle.static.Executor() exe.run(startup_program) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm_grad.py index 13e148e0a6a..ad92b9dc505 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_batch_norm_grad.py @@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi np.random.seed(2023) @@ -190,7 +191,7 @@ class TestCompositeBatchNorm(unittest.TestCase): attrs.use_global_stats, ) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], [x1]) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_dropout.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_dropout.py index c9be916edcc..d1dabef0d04 100644 --- 
a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_dropout.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_dropout.py @@ -19,6 +19,7 @@ import parameterized as param import paddle from paddle.fluid import core +from paddle.incubate.autograd import primapi np.random.seed(2023) @@ -154,7 +155,7 @@ class TestCompositeDropout(unittest.TestCase): input_, p, training=(not is_test), mode=mode ) if core._is_fwd_prim_enabled(): - paddle.incubate.autograd.to_prim(mp.blocks) + primapi.to_prim(mp.blocks) grad = paddle.static.gradients(output, input_)[0] exe = paddle.static.Executor(self.place) exe.run(sp) diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py index 3e5c10f803a..43a90318705 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu.py @@ -22,6 +22,7 @@ np.random.seed(2013) import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -89,7 +90,7 @@ class TestCompositeGelu(unittest.TestCase): # Ensure that gelu in original block self.assertTrue('gelu' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that gelu is splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py index dda900e2472..fbc2ad59155 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_gelu_grad.py @@ -22,6 +22,7 @@ np.random.seed(2013) import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -97,7 +98,7 @@ class TestCompositeGelu(unittest.TestCase): # Ensure that gelu in original block self.assertTrue('gelu' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that gelu is splitted into small ops @@ -164,7 +165,7 @@ class TestCompositeGeluPrimBackward(unittest.TestCase): x.stop_gradient = False y = fn(x) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py index d34003c5ae9..a9eb866e0cc 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm.py @@ -20,6 +20,7 @@ from utils import SUB_TOLERANCE import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape1, shape2, shape3, dtype="float32"): @@ -98,7 +99,7 @@ class TestCompositelayer_norm(unittest.TestCase): # Ensure that layer_norm in original block self.assertTrue('layer_norm' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + 
primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that layer_norm is splitted into small ops @@ -137,7 +138,7 @@ class TestCompositelayer_norm(unittest.TestCase): # Ensure that layer_norm in original block self.assertTrue('layer_norm' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that layer_norm is splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py index a4551732033..1c85e6e46d0 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_layer_norm_grad.py @@ -22,6 +22,7 @@ from utils import SUB_TOLERANCE import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi TOLERANCE_NUMPY = { "float32": {"rtol": 2e-5, "atol": 2e-5}, @@ -196,7 +197,7 @@ class TestCompositelayer_norm(unittest.TestCase): # Ensure that layer_norm in original block self.assertTrue('layer_norm' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that layer_norm is splitted into small ops @@ -242,7 +243,7 @@ class TestCompositelayer_norm(unittest.TestCase): # Ensure that layer_norm in original block self.assertTrue('layer_norm' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that layer_norm is splitted into small ops @@ -341,7 +342,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase): y = fn(x, norm_shape, w, b) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() @@ -374,7 +375,7 @@ class TestCompositelayer_normPrimBackward(unittest.TestCase): y = fn(x, norm_shape, weight, bias) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() @@ -480,7 +481,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase): # Ensure that layer_norm in original block self.assertTrue('layer_norm' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that layer_norm is splitted into small ops @@ -532,7 +533,7 @@ class TestCompositeNumpylayer_norm(unittest.TestCase): ) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean.py index 7a43fed8e6b..05ef7ecb4d9 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean.py @@ -20,6 +20,7 @@ from utils import TOLERANCE import paddle import paddle.tensor as tensor from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -93,7 +94,7 @@ class TestCompositeMean(unittest.TestCase): # Ensure that reduce_mean in 
original block self.assertTrue('reduce_mean' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that reduce_mean is splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean_grad.py index cd1e34ed147..6a067dddee9 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_mean_grad.py @@ -20,6 +20,7 @@ from utils import TOLERANCE import paddle import paddle.tensor as tensor from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -99,7 +100,7 @@ class TestCompositeMean(unittest.TestCase): # Ensure that reduce_mean in original block self.assertTrue('reduce_mean' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that reduce_mean is splitted into small ops @@ -173,7 +174,7 @@ class TestCompositeMeanPrimBackward(unittest.TestCase): x.stop_gradient = False y = fn(x) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py index 6be130bbc57..9a7be77b196 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax.py @@ -20,6 +20,7 @@ from utils import TOLERANCE import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -87,7 +88,7 @@ class TestCompositeSoftmax(unittest.TestCase): # Ensure that softmax in original block self.assertTrue('softmax' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that softmax is splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py index 87a2fafb50f..da0028d3367 100644 --- a/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/composite_ops/test_composite_softmax_grad.py @@ -20,6 +20,7 @@ from utils import TOLERANCE import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi def generate_data(shape, dtype="float32"): @@ -93,7 +94,7 @@ class TestCompositeSoftmax(unittest.TestCase): # Ensure that softmax in original block self.assertTrue('softmax' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that softmax is splitted into small ops @@ -158,7 +159,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase): x.stop_gradient = False y = fn(x) blocks = main_program.blocks - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) z = paddle.static.gradients([y], x) exe = paddle.static.Executor() diff 
--git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py index 2c6d5133123..f88e2c75242 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py @@ -20,6 +20,7 @@ import numpy as np import paddle import paddle.nn.functional as F from paddle.fluid import core +from paddle.incubate.autograd import primapi class TestPrimFlags(unittest.TestCase): @@ -64,6 +65,12 @@ class TestPrimFlags(unittest.TestCase): with self.assertRaises(TypeError): core._test_use_sync("aaaa") + core._set_prim_all_enabled(True) + self.assertTrue(core._is_all_prim_enabled()) + + core._set_prim_all_enabled(False) + self.assertFalse(core._is_all_prim_enabled()) + class TestPrimBlacklistFlags(unittest.TestCase): def not_in_blacklist(self): @@ -83,7 +90,7 @@ class TestPrimBlacklistFlags(unittest.TestCase): # Ensure that softmax in original block self.assertTrue('softmax' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that softmax is splitted into small ops @@ -113,7 +120,7 @@ class TestPrimBlacklistFlags(unittest.TestCase): # Ensure that softmax in original block self.assertTrue('softmax' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that softmax is splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/process/test_copy_op.py b/python/paddle/fluid/tests/unittests/prim/process/test_copy_op.py index 15a6994ecbf..de208f11632 100644 --- a/python/paddle/fluid/tests/unittests/prim/process/test_copy_op.py +++ b/python/paddle/fluid/tests/unittests/prim/process/test_copy_op.py @@ -18,6 +18,7 @@ import numpy as np import paddle from paddle.fluid import core +from paddle.incubate.autograd import primapi paddle.framework.random._manual_program_seed(2023) @@ -49,7 +50,7 @@ class TestCompositeCopyOp(unittest.TestCase): # Ensure that dropout in original block self.assertTrue('dropout' in fwd_ops) - paddle.incubate.autograd.to_prim(blocks) + primapi.to_prim(blocks) fwd_ops_new = [op.type for op in blocks[0].ops] # Ensure that dropout is not splitted into small ops diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_custom_vjp.py b/python/paddle/fluid/tests/unittests/prim/test_comp_custom_vjp.py index 90651c0c401..94800b6f5fb 100644 --- a/python/paddle/fluid/tests/unittests/prim/test_comp_custom_vjp.py +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_custom_vjp.py @@ -56,7 +56,7 @@ class TestCustomVJP(unittest.TestCase): 'elementwise_mul', 'scale', 'cast', - 'fill_constant', + 'fill_any_like', 'cast', 'elementwise_mul', 'fill_constant', diff --git a/python/paddle/fluid/tests/unittests/prim_op_test.py b/python/paddle/fluid/tests/unittests/prim_op_test.py index f3f780b05f9..980fdc5f7a5 100644 --- a/python/paddle/fluid/tests/unittests/prim_op_test.py +++ b/python/paddle/fluid/tests/unittests/prim_op_test.py @@ -22,6 +22,7 @@ import numpy as np import paddle import paddle.fluid.core as core from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode +from paddle.incubate.autograd import primapi from paddle.jit.dy2static.utils import parse_arg_and_kwargs @@ -588,7 +589,7 @@ class PrimForwardChecker: args, len(inputs_sig) ) ret = flatten(_as_list(self.python_api(*args))) - 
paddle.incubate.autograd.to_prim(main_program.blocks) + primapi.to_prim(main_program.blocks) exe = paddle.static.Executor(self.place) exe.run(startup_program) ret = exe.run(main_program, feed=feed, fetch_list=ret) @@ -1018,7 +1019,7 @@ class PrimGradChecker(PrimForwardChecker): outputs_dict = self.get_output_dict( self.outputs, fw_outs, outputs_sig ) - paddle.incubate.autograd.to_prim(main_program.blocks) + primapi.to_prim(main_program.blocks) ys = [] if isinstance(self.output_names, list): for output_name in self.output_names:
diff --git a/python/paddle/incubate/autograd/__init__.py b/python/paddle/incubate/autograd/__init__.py index 3e73ff571e5..d9b9e417819 100644 --- a/python/paddle/incubate/autograd/__init__.py +++ b/python/paddle/incubate/autograd/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .functional import Hessian, Jacobian, jvp, vjp -from .primapi import forward_grad, grad, to_prim +from .primapi import forward_grad, grad from .primx import prim2orig from .utils import disable_prim, enable_prim, prim_enabled @@ -25,5 +25,4 @@ __all__ = [ # noqa 'disable_prim', 'forward_grad', 'grad', - 'to_prim', ]
diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index 5bfd05156c3..38dbd591baf 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -217,11 +217,18 @@ def grad(outputs, inputs, grad_outputs=None): @framework.static_only -def to_prim(blocks, exclude=frozenset()): +def to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): """Search nonbasic ops which have been registered with composite rules and replace them with primitive ops. + The operators in blacklist will be excluded from the program when lowering into primitives, and only the + operators in whitelist will be lowered. The blacklist has higher priority than the whitelist, which means + an operator that appears in both will not be lowered. + + The final set of operators that will be lowered is: + (blocks.ops & ops that have a decomposite rule & whitelist) - blacklist Args: - exclude(frozenset): The Operators that will be exclude in lowering. + blacklist(frozenset): The operators that will be excluded when lowering into primitives. + whitelist(frozenset): Only the operators in whitelist will be lowered into primitives. """ if not core._is_fwd_prim_enabled(): return @@ -239,15 +246,28 @@ def to_prim(blocks, exclude=frozenset()): raise TypeError( f"Expect block or sequence of blocks, but got {type(blocks)}." ) - if not isinstance(exclude, (set, frozenset)): + if not isinstance(blacklist, (set, frozenset)): + raise TypeError( + f'Expected type of blacklist is set|frozenset, but got {type(blacklist)}.' + ) + if not isinstance(whitelist, (set, frozenset)): raise TypeError( - f'Expected type of exclude is set|frozenset, but got {type(exclude)}.' + f'Expected type of whitelist is set|frozenset, but got {type(whitelist)}.'
) + blacklist = prim_config["forward_blacklist"] | blacklist + with framework.program_guard(main_program): print("Lowering composite forward ops begin...") - primx._lower_composite( - blocks, prim_config["forward_blacklist"] | exclude - ) + + if len(blacklist) > 0 and len(whitelist) > 0: + filter_ = lambda x: x.type in whitelist and x.type not in blacklist + elif len(blacklist) > 0 and len(whitelist) == 0: + filter_ = lambda x: x.type not in blacklist + elif len(blacklist) == 0 and len(whitelist) > 0: + filter_ = lambda x: x.type in whitelist + else: + filter_ = lambda x: True + primx._lower_composite(blocks, filter_) replace_ops = prim_config["composite_ops_record"] print(f"Lowering composite forward ops finish: {replace_ops}")
diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index 5e071e465ec..a204f940e1d 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -550,8 +550,11 @@ def _lower(block, reverse, blacklist): block._sync_with_cpp() -def _lower_composite(block, blacklist=frozenset()): - # Some functions which are only used in _lower. +def _lower_composite( + block, filter_: typing.Callable[[framework.Operator], bool] = lambda x: True +): + """The operators in block which satisfy the filter condition will be decomposed into primitives.""" + def bind(args, to_bind, value_table): for i in range(len(args)): if isinstance(args[i], list): @@ -603,7 +606,7 @@ def _lower_composite(block, blacklist=frozenset()): for op_idx in range(len(block.ops)): op = block.ops[op_idx] ops_to_remove.append(op_idx) - if lookup_fn(op.type) is not None and op.type not in blacklist: + if lookup_fn(op.type) is not None and filter_(op): change = True op_name = op.type prim_config["composite_ops_record"].add(op_name) @@ -681,12 +684,12 @@ def _lower_composite(block, blacklist=frozenset()): # composite ops may contain other composite ops, thus, call _lower_composite again. if change: - _lower_composite(block, blacklist) + _lower_composite(block, filter_) return elif isinstance(block, typing.Sequence): for item in block: - _lower_composite(item, blacklist) + _lower_composite(item, filter_) return else: raise TypeError
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py index 4afa8c1f90f..55bd8ab7448 100644 --- a/python/paddle/jit/dy2static/partial_program.py +++ b/python/paddle/jit/dy2static/partial_program.py @@ -144,15 +144,13 @@ class ProgramInfo: class PartialProgramLayerHook: - def before_append_backward(self, partial_program_layer, forward_program): + def before_append_backward(self, forward_program): ... - def after_append_backward( - self, partial_program_layer, whole_program, backward_start_idx - ): + def after_append_backward(self, whole_program, backward_start_idx): ... - def after_infer(self, partial_program_layer, infer_program): + def after_infer(self, infer_program): ...
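
A minimal usage sketch of the new to_prim blacklist/whitelist parameters, for reference only and not part of the patch. The program built here is illustrative and mirrors the new TestToPrim cases above; it assumes static mode and that primitive lowering has been enabled.

import paddle
from paddle.fluid import core
from paddle.incubate.autograd import primapi

paddle.enable_static()
core._set_prim_all_enabled(True)

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.rand((4, 4))
    y = paddle.nn.functional.softmax(paddle.nn.functional.dropout(x))

# Keep dropout intact and lower every other op that has a composite rule.
primapi.to_prim(main_program.blocks, blacklist={'dropout'})

# Alternatively, lower only softmax and leave all other ops untouched:
# primapi.to_prim(main_program.blocks, whitelist={'softmax'})
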
@@ -264,7 +262,7 @@ class PartialProgramLayer: for_test=is_infer_mode ) if self._hooker: - infer_program = self._hooker.after_infer(self, infer_program) + infer_program = self._hooker.after_infer(infer_program) return infer_program else: train_program = self._append_backward_desc( @@ -298,11 +296,9 @@ class PartialProgramLayer: pure_fp16_program, self._amp_list, use_fp16_guard=False ) - core.check_and_set_prim_all_enabled() - from paddle.incubate.autograd.primapi import to_prim - - to_prim(pure_fp16_program.blocks) if is_infer_mode: + if self._hooker: + pure_fp16_program = self._hooker.after_infer(pure_fp16_program) return pure_fp16_program else: train_pure_fp16_program = self._append_backward_desc( @@ -314,7 +310,6 @@ class PartialProgramLayer: @switch_to_static_graph def _create_forward_backward_train_program(self): whole_program = self._train_program - # _, forward_end_op_index = self._infer_info('fp32', self._create_program) forward_end_op_index = self.get_forward_end_op_idx(whole_program) assert forward_end_op_index >= 0 @@ -437,7 +432,9 @@ class PartialProgramLayer: return _param_grad_names(self._train_program.desc, self._params) def get_forward_end_op_idx(self, program): - return self._forward_end_index_map[_hash_with_id(program, self)] + return self._forward_end_index_map[ + paddle.utils._hash_with_id(program, self) + ] @LazyInitialized def _out_grad_names(self): @@ -637,7 +634,7 @@ class PartialProgramLayer: # make sure all status of is_test are False in train mode. program = _change_is_test_status(main_program.clone(), is_test=False) if self._hooker: - program = self._hooker.before_append_backward(self, program) + program = self._hooker.before_append_backward(program) targets = [] for out in self._outputs.tolist(): if isinstance(out, framework.Variable): @@ -652,12 +649,14 @@ class PartialProgramLayer: if self._hooker: program, start_idx = self._hooker.after_append_backward( - self, program, start_idx + program, start_idx ) - self.prepare_gradient_aggregation(start_idx, main_program, program) + self.prepare_gradient_aggregation( + start_idx + 1, main_program, program + ) self._forward_end_index_map[ - _hash_with_id(program, self) + paddle.utils._hash_with_id(program, self) ] = start_idx - len(self._outputs.tolist()) return program diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index e201915310e..5b3eae6f8e3 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -19,7 +19,7 @@ import threading import warnings import weakref -from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard +from paddle.amp.auto_cast import _in_amp_guard from paddle.fluid import _non_static_mode, core, framework from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import layers @@ -194,7 +194,7 @@ class CacheKey: input_args_with_spec, input_kwargs_with_spec, class_instance, - **kwargs + **kwargs, ): """ Initializes a cache key. @@ -568,7 +568,7 @@ class StaticFunction: self._class_instance, **self._kwargs, with_hook=with_hook, - is_train=is_train + is_train=is_train, ) # 3. 
check whether hit the cache or build a new program for the input arguments @@ -674,7 +674,7 @@ class StaticFunction: concrete_program, _ = self.get_concrete_program( *desired_input_spec, with_hook=with_hook, - is_train=self._is_train_mode() + is_train=self._is_train_mode(), ) return concrete_program else: @@ -946,7 +946,7 @@ class ConcreteProgram: function, main_program, startup_program=None, - **kwargs + **kwargs, ): self.inputs = inputs self.outputs = outputs @@ -1049,7 +1049,7 @@ class ConcreteProgram: function=dygraph_function, main_program=main_program, startup_program=startup_program, - **kwargs + **kwargs, ) @@ -1152,7 +1152,7 @@ class ProgramCache: input_spec=cache_key.input_args_with_spec, input_kwargs_spec=cache_key.input_kwargs_with_spec, class_instance=cache_key.class_instance, - **cache_key.kwargs + **cache_key.kwargs, ) except Exception as e: if enable_fallback: @@ -1182,48 +1182,11 @@ class ProgramCache: ) ) - class PrimHooker(PartialProgramLayerHook): - def __init__(self): - self.custom_vjps = set() - if core._is_fwd_prim_enabled() and core._is_bwd_prim_enabled(): - self.custom_vjps = { - op.type - for op in concrete_program.main_program.block(0).ops - if core.has_comp_grad_op_maker(op.type) - } - - def before_append_backward( - self, partial_program_layer, forward_program - ): - if core._is_fwd_prim_enabled(): - to_prim(forward_program.block(0), self.custom_vjps) - return forward_program - - def after_append_backward( - self, partial_program_layer, whole_program, backward_start_idx - ): - backward_length = ( - len(whole_program.block(0).ops) - backward_start_idx - ) - if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0: - to_prim(whole_program.block(0)) - new_start_index = ( - len(whole_program.block(0).ops) - backward_length - ) - return whole_program, new_start_index - - def after_infer(self, partial_program_layer, infer_program): - if core._is_fwd_prim_enabled(): - to_prim(infer_program.block(0)) - return infer_program - partial_program = partial_program_from(concrete_program) - if ( - core._is_fwd_prim_enabled() - and not _in_amp_guard() - and not _in_pure_fp16_guard() - ): - partial_program.set_hooker(PrimHooker()) + if core._is_fwd_prim_enabled() and not _in_amp_guard(): + partial_program.set_hooker( + PrimHooker(concrete_program.main_program) + ) return concrete_program, partial_program def __getitem__(self, item): @@ -1279,6 +1242,38 @@ class ProgramCache: self._caches = collections.OrderedDict() +class PrimHooker(PartialProgramLayerHook): + def __init__(self, original_program): + if len(original_program.blocks) > 1: + raise ValueError( + 'The primitive mode only supports one block currently.'
+ ) + self.custom_vjps = set() + if core._is_all_prim_enabled(): + self.custom_vjps = { + op.type + for op in original_program.block(0).ops + if core.has_comp_grad_op_maker(op.type) + } + + def before_append_backward(self, forward_program): + if core._is_fwd_prim_enabled(): + _to_prim(forward_program.blocks, blacklist=self.custom_vjps) + return forward_program + + def after_append_backward(self, whole_program, backward_start_idx): + backward_length = len(whole_program.block(0).ops) - backward_start_idx + if core._is_fwd_prim_enabled() and len(self.custom_vjps) != 0: + _to_prim(whole_program.blocks, whitelist=self.custom_vjps) + new_start_index = len(whole_program.block(0).ops) - backward_length + return whole_program, new_start_index + + def after_infer(self, infer_program): + if core._is_fwd_prim_enabled(): + _to_prim(infer_program.block(0)) + return infer_program + + class ProgramTranslator: """ Class to translate dygraph function into static graph function. The object @@ -1696,8 +1691,9 @@ def enable_to_static(enable_to_static_bool): @switch_to_static_graph -def to_prim(blocks, exclude=frozenset()): +def _to_prim(blocks, blacklist=frozenset(), whitelist=frozenset()): + """Switch to static graph and call to_prim.""" # TODO(Aurelius84): Fix this cycle import problem from paddle.incubate.autograd import primapi - primapi.to_prim(blocks, exclude) + primapi.to_prim(blocks, blacklist=blacklist, whitelist=whitelist)
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index b37ee05f9f0..34d628c1d35 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size): min(fwd_end_op_index + out_size, program_desc.block(0).op_size()), ): op = program_desc.block(0).op(i) - if op.type() in ['fill_any_like', "fill_constant"]: + if op.type() == 'fill_any_like': var_name = op.output('Out')[0] names.append(var_name) return names -- GitLab
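
For reference, a short sketch of the two-phase lowering that the new PrimHooker performs: operators with a composite grad op maker (custom_vjp) are blacklisted while the forward program is lowered, so their composite backward rules can still be applied when the backward is appended, and are then whitelisted and lowered in the joint program. The helper name and the two program arguments below are assumptions for illustration, not part of the patch.

from paddle.fluid import core
from paddle.incubate.autograd import primapi


def lower_around_custom_vjp(forward_program, whole_program):
    # Hypothetical helper mirroring PrimHooker: both arguments are assumed to
    # be static-graph Programs owned by a PartialProgramLayer.
    custom_vjps = {
        op.type
        for op in forward_program.block(0).ops
        if core.has_comp_grad_op_maker(op.type)
    }
    # Before the backward pass is appended: keep custom_vjp ops whole so their
    # composite grad rules can still be used to build the backward program.
    primapi.to_prim(forward_program.blocks, blacklist=custom_vjps)
    # After the backward ops exist: lower only the remaining custom_vjp ops in
    # the joint forward+backward program.
    primapi.to_prim(whole_program.blocks, whitelist=custom_vjps)
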