提交 7cd846a5 编写于 作者: M Megvii Engine Team

fix(mge/batch_norm): fix batch_norm check when trace(symbolic=True)

GitOrigin-RevId: 2032eb5f7daa76b90a13acf3fe2830be06a85c7c
上级 120e719e
......@@ -394,6 +394,7 @@ class trace:
def _apply_graph_options(self, graph):
graph.options.no_force_inplace = True
graph.options.seq_opt.enable_seq_comp_node_opt = False
# graph opt level
if self._graph_opt_level is not None:
......@@ -417,7 +418,6 @@ class trace:
def _compile(self):
graph = self._graph = G.Graph()
graph.options.no_force_inplace = True
graph.options.async_exec_level = 0b100
self._apply_graph_options(graph)
# graph.options.graph_opt_level = 0
......
......@@ -13,7 +13,8 @@ import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import BatchNorm2d
from megengine.jit import trace
from megengine.module import BatchNorm2d, Module
def test_frozen_bn():
......@@ -89,3 +90,25 @@ def test_bn_no_track_stat3():
data = np.random.random((6, nchannel, 2, 2)).astype("float32")
with pytest.raises(Exception):
m(data)
def test_trace_bn_forward_twice():
class Simple(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(1)
def forward(self, inp):
x = self.bn(inp)
x = self.bn(x)
return x
@trace(symbolic=True)
def train_bn(inp, net=None):
net.train()
pred = net(inp)
return pred
x = np.ones((1, 1, 32, 32), dtype=np.float32)
y = train_bn(x, net=Simple())
np.testing.assert_equal(y.numpy(), 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册