diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 6baadb3bdeaa8ad17166278e38852029226b0756..f770fa1646dd3c651b019e46135cb16a723c0037 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -360,14 +360,25 @@ class Module(metaclass=ABCMeta): loaded, skipped = self._load_state_dict_with_closure(closure) unused = set(unused) - loaded - if strict and len(unused) != 0: - raise KeyError( - "Unused params violate `strict=True`, unused={}".format(unused) - ) - if strict and len(skipped) != 0: - raise KeyError( - "Missing params violate `strict=True`, missing={}".format(skipped) - ) + if len(unused) != 0: + if strict: + raise KeyError( + "Unused params violate `strict=True`, unused={}".format(unused) + ) + else: + logger.warning( + "Unused params in `strict=False` mode, unused={}".format(unused) + ) + + if len(skipped) != 0: + if strict: + raise KeyError( + "Missing params violate `strict=True`, missing={}".format(skipped) + ) + else: + logger.warning( + "Missing params in `strict=False` mode, missing={}".format(skipped) + ) def _load_state_dict_with_closure(self, closure): """Advance state_dict load through callable `closure` whose signature is @@ -383,7 +394,6 @@ class Module(metaclass=ABCMeta): for k, var in local_state_dict.items(): to_be_load = closure(k, var) if to_be_load is None: - logger.warning("skip loading param `%s`", k) skipped.append(k) continue assert isinstance(