diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
index 3d1a1563860833604e47ef3d4dd4835f7e35b7f7..50c1acbd85ce5948166c543507af2df4efec9ad7 100644
--- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
+++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py
@@ -20,10 +20,12 @@ import autograd.numpy as anp
 import autograd.scipy as ascipy
 import config
 import numpy as np
+import parameterized as param
 import utils
 
 import paddle
-from paddle.incubate.autograd import primx
+from paddle.fluid import core
+from paddle.incubate.autograd import primapi, primx
 
 
 @utils.place(config.DEVICES)
@@ -1034,5 +1036,25 @@ class TestGradWithHigherOrder(unittest.TestCase):
             np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
 
 
+class TestToPrim(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        core._set_prim_forward_enabled(True)
+
+    def tearDown(self):
+        core._set_prim_forward_enabled(False)
+        paddle.disable_static()
+
+    @param.parameterized.expand([(('dropout',),)])
+    def test_exclude(self, exclude):
+        program = paddle.static.Program()
+        with paddle.static.program_guard(program):
+            x = paddle.rand((1,))
+            y = paddle.nn.functional.dropout(x)
+        primapi.to_prim(program.blocks, exclude)
+        ops = tuple(op.type for op in program.block(0).ops)
+        self.assertTrue(all(op in ops for op in exclude))
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py
index df4fc1c513ae56c109b56ad10b440c88ba320e9d..68d912b8589b863c55e281cad85204cab604a943 100644
--- a/python/paddle/incubate/autograd/primapi.py
+++ b/python/paddle/incubate/autograd/primapi.py
@@ -217,12 +217,18 @@ def grad(outputs, inputs, grad_outputs=None):
 
 
 @framework.static_only
-def to_prim(blocks):
-    """Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
+def to_prim(blocks, exclude=frozenset()):
+    """Search for non-basic ops that have registered composite rules and replace them with primitive ops.
+
+    Args:
+        blocks(paddle.static.Block|Sequence[paddle.static.Block]): The blocks to be lowered.
+        exclude(frozenset): The types of the operators that will be excluded
+            from lowering. Default: frozenset().
+    """
     if not core._is_fwd_prim_enabled():
         return
     if isinstance(blocks, paddle.fluid.framework.Block):
-        logging.info("Atomize composite op to primitive ops begin.")
+        logging.debug("Atomize composite op to primitive ops begin.")
         main_program = blocks.program
     elif isinstance(blocks, typing.Sequence):
         for item in blocks:
@@ -236,8 +242,9 @@
             f"Expect block or sequence of blocks, but got {type(blocks)}."
         )
     with framework.program_guard(main_program):
-        print("Lowering composite forward ops begin...")
-        primx._lower_composite(blocks, prim_config["forward_blacklist"])
+        logging.debug("Lowering composite forward ops begin...")
+        primx._lower_composite(
+            blocks, prim_config["forward_blacklist"] | set(exclude)
+        )
         replace_ops = prim_config["composite_ops_record"]
-        print(f"Lowering composite forward ops finish: {replace_ops}")
-        return
+        logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py
index ec4b75f13e69ba01af594b52158ac854df10f00c..13262d30e7113da13c9931f732caba08873636ae 100644
--- a/python/paddle/incubate/autograd/primx.py
+++ b/python/paddle/incubate/autograd/primx.py
@@ -550,7 +550,7 @@ def _lower(block, reverse, blacklist):
     block._sync_with_cpp()
 
 
-def _lower_composite(block, blacklist=[]):
+def _lower_composite(block, blacklist=frozenset()):
     # Some functions which are only used in _lower.
     def bind(args, to_bind, value_table):
         for i in range(len(args)):
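
Reviewer note: a minimal usage sketch of the new `exclude` parameter, not part of the diff. It mirrors the test's setup (static graph mode with forward prim lowering enabled) and assumes `gelu` also has a registered composite rule, so the contrast with the excluded `dropout` is visible.

    import paddle
    from paddle.fluid import core
    from paddle.incubate.autograd import primapi

    paddle.enable_static()
    core._set_prim_forward_enabled(True)

    program = paddle.static.Program()
    with paddle.static.program_guard(program):
        x = paddle.rand((1,))
        # gelu is assumed to have a composite rule; dropout is excluded below.
        y = paddle.nn.functional.dropout(paddle.nn.functional.gelu(x))

    # Lower all composite ops except those whose type appears in `exclude`.
    primapi.to_prim(program.blocks, exclude=frozenset({'dropout'}))

    # 'dropout' survives unchanged; 'gelu' is replaced by primitive ops.
    print([op.type for op in program.block(0).ops])

    core._set_prim_forward_enabled(False)
    paddle.disable_static()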