From 2df1ab963361998e7871ccd31e9dd1e1d7f65fe4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Sep 2020 11:34:32 +0800 Subject: [PATCH] refactor(mge/jit): skip seed when checking equal rng op GitOrigin-RevId: dae2086b362531183cd0a28aaaba324a4ddc58f1 --- imperative/python/megengine/jit/tracing.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index b3a1af89..e93b7087 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -10,6 +10,7 @@ import weakref import numpy as np from ..core._imperative_rt import GraphProfiler +from ..core._imperative_rt.ops import OprAttr from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply @@ -127,8 +128,12 @@ class trace: record = self._seq[self._pc] op_, ihandles, ohandles = record if op != op_: - if op.type == "UniformRNG": - pass + # FIXME: will be removed once better rng implementation is done + if isinstance(op, OprAttr) and ( + op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type + ): + if op.param[8:] != op_.param[8:]: + raise TraceMismatchError("op different from last time") else: raise TraceMismatchError("op different from last time") if len(ihandles) != len(args): -- GitLab