From eac8f84142c4211f79fd25a09410473754714a2a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 17 May 2021 12:44:50 +0800 Subject: [PATCH] fix(mge/jit): raise an error when dumping BatchNorm in training mode GitOrigin-RevId: edc7ea2962da24c8a680c0c6fb2effcfaf3508c2 --- imperative/python/megengine/jit/tracing.py | 6 +++++- .../python/test/integration/test_trace_dump.py | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index c4b70d6cd..c95f5402f 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import ( ) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device -from ..core.ops.builtin import BackwardGraph, OpDef +from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.utils import setscalar @@ -833,6 +833,10 @@ class trace: if isinstance(op, BackwardGraph): ovars = G.apply_backward_varnode(op, *ivars) else: + if isinstance(op, BatchNorm): + assert ( + op.fwd_mode == BatchNorm.FwdMode.INFERENCE + ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" 
ovars = G.apply_normal_varnode(op, *ivars) AutoNaming.record_opnode(ovars[0].op) diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py index e0b876c51..c719ee94b 100644 --- a/imperative/python/test/integration/test_trace_dump.py +++ b/imperative/python/test/integration/test_trace_dump.py @@ -11,6 +11,7 @@ import os import tempfile import numpy as np +import pytest import megengine as mge import megengine.functional as F @@ -140,3 +141,15 @@ def test_xornet_trace_dump(): with mkstemp() as out: pred_fun.dump(out, arg_names=["data"], output_names=["label"]) + + +def test_dump_bn_train_mode(): + @trace(symbolic=True, capture_as_const=True) + def bn_train(data): + pred = M.BatchNorm2d(10)(data).sum() + return pred + + data = mge.tensor(np.random.random((10, 10, 10, 10))) + bn_train(data) + with pytest.raises(AssertionError): + bn_train.dump("test.mge") -- GitLab