From c992afa960a65b0581ef3e30ac7ebe3cfbf40978 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 14 Apr 2021 18:26:52 +0800 Subject: [PATCH] test(mge/module): add module reset attribute test GitOrigin-RevId: 6c9adc4a7022ed453221d7c58d0b4b0033fd00b4 --- .../python/test/unit/module/test_module.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index 05540312e..01e354718 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -625,3 +625,34 @@ def test_repr_module_delete(): del net.softmax output = net.__repr__() assert output == ground_truth + + +def test_repr_module_reset_attr(): + class ResetAttrModule(Module): + def __init__(self, flag): + super().__init__() + if flag: + self.a = None + self.a = Linear(3, 5) + else: + self.a = Linear(3, 5) + self.a = None + + def forward(self, x): + if self.a: + x = self.a(x) + return x + + ground_truth = [ + ( + "ResetAttrModule(\n" + " (a): Linear(in_features=3, out_features=5, bias=True)\n" + ")" + ), + ("ResetAttrModule()"), + ] + + m0 = ResetAttrModule(True) + m1 = ResetAttrModule(False) + output = [m0.__repr__(), m1.__repr__()] + assert output == ground_truth -- GitLab