提交 eac8f841 编写于 作者: M Megvii Engine Team

fix(mge/jit): error out if dump bn in training mode

GitOrigin-RevId: edc7ea2962da24c8a680c0c6fb2effcfaf3508c2
上级 e509e7a0
......@@ -36,7 +36,7 @@ from ..core._imperative_rt.ops import (
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, OpDef
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
......@@ -833,6 +833,10 @@ class trace:
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
if isinstance(op, BatchNorm):
assert (
op.fwd_mode == BatchNorm.FwdMode.INFERENCE
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars = G.apply_normal_varnode(op, *ivars)
AutoNaming.record_opnode(ovars[0].op)
......
......@@ -11,6 +11,7 @@ import os
import tempfile
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
......@@ -140,3 +141,15 @@ def test_xornet_trace_dump():
with mkstemp() as out:
pred_fun.dump(out, arg_names=["data"], output_names=["label"])
def test_dump_bn_train_mode():
@trace(symbolic=True, capture_as_const=True)
def bn_train(data):
pred = M.BatchNorm2d(10)(data).sum()
return pred
data = mge.tensor(np.random.random((10, 10, 10, 10)))
bn_train(data)
with pytest.raises(AssertionError):
bn_train.dump("test.mge")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册