提交 539d05c6 编写于 作者: A Aurelius84 提交者: Xiaoxu Chen

[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>
上级 3a3ff942
...@@ -20,10 +20,12 @@ import autograd.numpy as anp ...@@ -20,10 +20,12 @@ import autograd.numpy as anp
import autograd.scipy as ascipy import autograd.scipy as ascipy
import config import config
import numpy as np import numpy as np
import parameterized as param
import utils import utils
import paddle import paddle
from paddle.incubate.autograd import primx from paddle.fluid import core
from paddle.incubate.autograd import primapi, primx
@utils.place(config.DEVICES) @utils.place(config.DEVICES)
...@@ -1034,5 +1036,25 @@ class TestGradWithHigherOrder(unittest.TestCase): ...@@ -1034,5 +1036,25 @@ class TestGradWithHigherOrder(unittest.TestCase):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol) np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
class TestToPrim(unittest.TestCase):
def setUp(self):
paddle.enable_static()
core._set_prim_forward_enabled(True)
def tearDown(self):
core._set_prim_forward_enabled(False)
paddle.disable_static()
@param.parameterized((('dropout',),))
def test_exclude(self, exclude):
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.rand((1,))
y = paddle.nn.functional.dropout(x)
primapi.to_prim(program, exclude)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude)))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -217,12 +217,16 @@ def grad(outputs, inputs, grad_outputs=None): ...@@ -217,12 +217,16 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only @framework.static_only
def to_prim(blocks): def to_prim(blocks, exclude=frozenset()):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops.""" """Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
Args:
exclude(frozenset): The Operators that will be exclude in lowering.
"""
if not core._is_fwd_prim_enabled(): if not core._is_fwd_prim_enabled():
return return
if isinstance(blocks, paddle.fluid.framework.Block): if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.") logging.debug("Atomize composite op to primitive ops begin.")
main_program = blocks.program main_program = blocks.program
elif isinstance(blocks, typing.Sequence): elif isinstance(blocks, typing.Sequence):
for item in blocks: for item in blocks:
...@@ -236,8 +240,9 @@ def to_prim(blocks): ...@@ -236,8 +240,9 @@ def to_prim(blocks):
f"Expect block or sequence of blocks, but got {type(blocks)}." f"Expect block or sequence of blocks, but got {type(blocks)}."
) )
with framework.program_guard(main_program): with framework.program_guard(main_program):
print("Lowering composite forward ops begin...") logging.debug("Lowering composite forward ops begin...")
primx._lower_composite(blocks, prim_config["forward_blacklist"]) primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude
)
replace_ops = prim_config["composite_ops_record"] replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}") logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
return
...@@ -550,7 +550,7 @@ def _lower(block, reverse, blacklist): ...@@ -550,7 +550,7 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp() block._sync_with_cpp()
def _lower_composite(block, blacklist=[]): def _lower_composite(block, blacklist=frozenset()):
# Some functions which are only used in _lower. # Some functions which are only used in _lower.
def bind(args, to_bind, value_table): def bind(args, to_bind, value_table):
for i in range(len(args)): for i in range(len(args)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册