From 9a42c636414829bf0ecfc142148e00e942382942 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 16 Apr 2020 11:23:04 +0800 Subject: [PATCH] fix(mge/module): add the more warnings in load_state_dict In the before, the information of miss matching and unused operators won't be printed in the non-trict mode, this commit add the information. GitOrigin-RevId: b2543eb832ec4f2e562f136fc99cc7c283f8bea3 --- python_module/megengine/module/module.py | 28 ++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 6baadb3bd..f770fa164 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -360,14 +360,25 @@ class Module(metaclass=ABCMeta): loaded, skipped = self._load_state_dict_with_closure(closure) unused = set(unused) - loaded - if strict and len(unused) != 0: - raise KeyError( - "Unused params violate `strict=True`, unused={}".format(unused) - ) - if strict and len(skipped) != 0: - raise KeyError( - "Missing params violate `strict=True`, missing={}".format(skipped) - ) + if len(unused) != 0: + if strict: + raise KeyError( + "Unused params violate `strict=True`, unused={}".format(unused) + ) + else: + logger.warning( + "Unused params in `strict=False` mode, unused={}".format(unused) + ) + + if len(skipped) != 0: + if strict: + raise KeyError( + "Missing params violate `strict=True`, missing={}".format(skipped) + ) + else: + logger.warning( + "Missing params in `strict=False` mode, missing={}".format(skipped) + ) def _load_state_dict_with_closure(self, closure): """Advance state_dict load through callable `closure` whose signature is @@ -383,7 +394,6 @@ class Module(metaclass=ABCMeta): for k, var in local_state_dict.items(): to_be_load = closure(k, var) if to_be_load is None: - logger.warning("skip loading param `%s`", k) skipped.append(k) continue assert isinstance( -- GitLab