From 539d05c6c0b28d4cd2f7f0a80b83747a8573256a Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 24 Feb 2023 10:26:45 +0800 Subject: [PATCH] [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557) * [CINN]Enhance CacheKey hash logic by considering input dtypes * add unittest * fix typo * fix typo * fix map.at * fix find * fix test * fix cinn cache key structure realize * using ordered map for attributes * add test by review advice --------- Co-authored-by: jiangcheng --- .../tests/unittests/autograd/test_primapi.py | 24 ++++++++++++++++++- python/paddle/incubate/autograd/primapi.py | 19 +++++++++------ python/paddle/incubate/autograd/primx.py | 2 +- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py index 3d1a1563860..50c1acbd85c 100644 --- a/python/paddle/fluid/tests/unittests/autograd/test_primapi.py +++ b/python/paddle/fluid/tests/unittests/autograd/test_primapi.py @@ -20,10 +20,12 @@ import autograd.numpy as anp import autograd.scipy as ascipy import config import numpy as np +import parameterized as param import utils import paddle -from paddle.incubate.autograd import primx +from paddle.fluid import core +from paddle.incubate.autograd import primapi, primx @utils.place(config.DEVICES) @@ -1034,5 +1036,25 @@ class TestGradWithHigherOrder(unittest.TestCase): np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) +class TestToPrim(unittest.TestCase): + def setUp(self): + paddle.enable_static() + core._set_prim_forward_enabled(True) + + def tearDown(self): + core._set_prim_forward_enabled(False) + paddle.disable_static() + + @param.parameterized((('dropout',),)) + def test_exclude(self, exclude): + program = paddle.static.Program() + with paddle.static.program_guard(program): + x = paddle.rand((1,)) + y = paddle.nn.functional.dropout(x) + primapi.to_prim(program, exclude) + ops = tuple(op.type for op in program.block(0).ops) + self.assertTrue(all(tuple(op in ops for op in exclude))) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/incubate/autograd/primapi.py b/python/paddle/incubate/autograd/primapi.py index df4fc1c513a..68d912b8589 100644 --- a/python/paddle/incubate/autograd/primapi.py +++ b/python/paddle/incubate/autograd/primapi.py @@ -217,12 +217,16 @@ def grad(outputs, inputs, grad_outputs=None): @framework.static_only -def to_prim(blocks): - """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.""" +def to_prim(blocks, exclude=frozenset()): + """Search nonbasic ops which have be registered composite rules and replace them with primitive ops. + + Args: + exclude(frozenset): The Operators that will be exclude in lowering. + """ if not core._is_fwd_prim_enabled(): return if isinstance(blocks, paddle.fluid.framework.Block): - logging.info("Atomize composite op to primitive ops begin.") + logging.debug("Atomize composite op to primitive ops begin.") main_program = blocks.program elif isinstance(blocks, typing.Sequence): for item in blocks: @@ -236,8 +240,9 @@ def to_prim(blocks): f"Expect block or sequence of blocks, but got {type(blocks)}." ) with framework.program_guard(main_program): - print("Lowering composite forward ops begin...") - primx._lower_composite(blocks, prim_config["forward_blacklist"]) + logging.debug("Lowering composite forward ops begin...") + primx._lower_composite( + blocks, prim_config["forward_blacklist"] | exclude + ) replace_ops = prim_config["composite_ops_record"] - print(f"Lowering composite forward ops finish: {replace_ops}") - return + logging.debug(f"Lowering composite forward ops finish: {replace_ops}") diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py index ec4b75f13e6..13262d30e71 100644 --- a/python/paddle/incubate/autograd/primx.py +++ b/python/paddle/incubate/autograd/primx.py @@ -550,7 +550,7 @@ def _lower(block, reverse, blacklist): block._sync_with_cpp() -def _lower_composite(block, blacklist=[]): +def _lower_composite(block, blacklist=frozenset()): # Some functions which are only used in _lower. def bind(args, to_bind, value_table): for i in range(len(args)): -- GitLab