diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index 05540312ef0a016612a06405c47122e4c30e28bf..01e354718a1fa0d10dcb8a45d0497782e9559e24 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -625,3 +625,34 @@ def test_repr_module_delete(): del net.softmax output = net.__repr__() assert output == ground_truth + + +def test_repr_module_reset_attr(): + class ResetAttrModule(Module): + def __init__(self, flag): + super().__init__() + if flag: + self.a = None + self.a = Linear(3, 5) + else: + self.a = Linear(3, 5) + self.a = None + + def forward(self, x): + if self.a: + x = self.a(x) + return x + + ground_truth = [ + ( + "ResetAttrModule(\n" + " (a): Linear(in_features=3, out_features=5, bias=True)\n" + ")" + ), + ("ResetAttrModule()"), + ] + + m0 = ResetAttrModule(True) + m1 = ResetAttrModule(False) + output = [m0.__repr__(), m1.__repr__()] + assert output == ground_truth