提交 a95f6d4f 编写于 作者: M Megvii Engine Team 提交者: huangxinda

perf(trace): add fastpath for const value assert

GitOrigin-RevId: 9a966f257f1129b1785ee1e5aa34509d677975fd
上级 2cd98232
......@@ -279,7 +279,16 @@ class trace:
# Const op is represented by a str
assert isinstance(op_, str) and op_ == "Const"
eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
expected = self._tinfo[ohandles[0]].bound_data.numpy()
shape = value.shape
if shape != expected.shape or dtype != expected.dtype:
eq = False
elif shape == ():
eq = expected.item() == value.item()
elif shape == (1,):
eq = expected[0] == value[0]
else:
eq = np.all(value == expected)
if not eq:
raise TraceMismatchError(
"const tensor violated: got a different tensor this time"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册