提交 539d05c6 编写于 作者: A Aurelius84 提交者: Xiaoxu Chen

[CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

* add unittest

* fix typo

* fix typo

* fix map.at

* fix find

* fix test

* fix cinn cache key structure realize

* using ordered map for attributes

* add test by review advice

---------
Co-authored-by: Njiangcheng <thisjiang@qq.com>
上级 3a3ff942
......@@ -20,10 +20,12 @@ import autograd.numpy as anp
import autograd.scipy as ascipy
import config
import numpy as np
import parameterized as param
import utils
import paddle
from paddle.incubate.autograd import primx
from paddle.fluid import core
from paddle.incubate.autograd import primapi, primx
@utils.place(config.DEVICES)
......@@ -1034,5 +1036,25 @@ class TestGradWithHigherOrder(unittest.TestCase):
np.testing.assert_allclose(i, j, rtol=self._rtol, atol=self._atol)
class TestToPrim(unittest.TestCase):
def setUp(self):
paddle.enable_static()
core._set_prim_forward_enabled(True)
def tearDown(self):
core._set_prim_forward_enabled(False)
paddle.disable_static()
@param.parameterized((('dropout',),))
def test_exclude(self, exclude):
program = paddle.static.Program()
with paddle.static.program_guard(program):
x = paddle.rand((1,))
y = paddle.nn.functional.dropout(x)
primapi.to_prim(program, exclude)
ops = tuple(op.type for op in program.block(0).ops)
self.assertTrue(all(tuple(op in ops for op in exclude)))
if __name__ == '__main__':
unittest.main()
......@@ -217,12 +217,16 @@ def grad(outputs, inputs, grad_outputs=None):
@framework.static_only
def to_prim(blocks):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops."""
def to_prim(blocks, exclude=frozenset()):
"""Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
Args:
exclude(frozenset): The Operators that will be exclude in lowering.
"""
if not core._is_fwd_prim_enabled():
return
if isinstance(blocks, paddle.fluid.framework.Block):
logging.info("Atomize composite op to primitive ops begin.")
logging.debug("Atomize composite op to primitive ops begin.")
main_program = blocks.program
elif isinstance(blocks, typing.Sequence):
for item in blocks:
......@@ -236,8 +240,9 @@ def to_prim(blocks):
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
with framework.program_guard(main_program):
print("Lowering composite forward ops begin...")
primx._lower_composite(blocks, prim_config["forward_blacklist"])
logging.debug("Lowering composite forward ops begin...")
primx._lower_composite(
blocks, prim_config["forward_blacklist"] | exclude
)
replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}")
return
logging.debug(f"Lowering composite forward ops finish: {replace_ops}")
......@@ -550,7 +550,7 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp()
def _lower_composite(block, blacklist=[]):
def _lower_composite(block, blacklist=frozenset()):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
for i in range(len(args)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册