diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 412ad2707ebb0efd14a54bd2953e1fda7437cb2a..a2d141b4a3041280add51736305c2d2a1611ca2b 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -394,6 +394,7 @@ class trace: def _apply_graph_options(self, graph): + graph.options.no_force_inplace = True graph.options.seq_opt.enable_seq_comp_node_opt = False # graph opt level if self._graph_opt_level is not None: @@ -417,7 +418,6 @@ class trace: def _compile(self): graph = self._graph = G.Graph() - graph.options.no_force_inplace = True graph.options.async_exec_level = 0b100 self._apply_graph_options(graph) # graph.options.graph_opt_level = 0 diff --git a/imperative/python/test/integration/test_bn.py b/imperative/python/test/integration/test_bn.py index 8767af309a48fb33f8f0015e229ec4a86190a797..89e61dd6556a3cab8851d203aa2e6f9e3d2a8f0e 100644 --- a/imperative/python/test/integration/test_bn.py +++ b/imperative/python/test/integration/test_bn.py @@ -13,7 +13,8 @@ import megengine import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor -from megengine.module import BatchNorm2d +from megengine.jit import trace +from megengine.module import BatchNorm2d, Module def test_frozen_bn(): @@ -89,3 +90,25 @@ def test_bn_no_track_stat3(): data = np.random.random((6, nchannel, 2, 2)).astype("float32") with pytest.raises(Exception): m(data) + + +def test_trace_bn_forward_twice(): + class Simple(Module): + def __init__(self): + super().__init__() + self.bn = BatchNorm2d(1) + + def forward(self, inp): + x = self.bn(inp) + x = self.bn(x) + return x + + @trace(symbolic=True) + def train_bn(inp, net=None): + net.train() + pred = net(inp) + return pred + + x = np.ones((1, 1, 32, 32), dtype=np.float32) + y = train_bn(x, net=Simple()) + np.testing.assert_equal(y.numpy(), 0)