diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 878d2fc6bb97523069b8e818704861f292c1346a..acd4afa0dfad1133708b5bb934bfbdc0acb1d2bb 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -130,7 +130,8 @@ def convbn2d_module_loader(expr): @register_opdef_loader(BatchNorm) def bn_opdef_loader(expr): # mge 1.6 - if not hasattr(expr, "version"): + if not hasattr(expr, "version") and len(expr.outputs) != 6: + assert len(expr.outputs) == 5 output = expr.outputs[-1] oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) expr.outputs.insert(4, oup)