From a95f6d4f75978a5e8cbb98fcffc7ebf70c6ff7db Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 11 Jul 2021 01:45:34 +0800 Subject: [PATCH] perf(trace): add fastpath for const value assert GitOrigin-RevId: 9a966f257f1129b1785ee1e5aa34509d677975fd --- imperative/python/megengine/jit/tracing.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 8f617364..54b913d9 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -279,7 +279,16 @@ class trace: # Const op is represented by a str assert isinstance(op_, str) and op_ == "Const" - eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy()) + expected = self._tinfo[ohandles[0]].bound_data.numpy() + shape = value.shape + if shape != expected.shape or dtype != expected.dtype: + eq = False + elif shape == (): + eq = expected.item() == value.item() + elif shape == (1,): + eq = expected[0] == value[0] + else: + eq = np.all(value == expected) if not eq: raise TraceMismatchError( "const tensor violated: got a different tensor this time" -- GitLab