提交 9a42c636 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/module): add the more warnings in load_state_dict

In the before, the information of miss matching and unused operators won't be printed in the non-trict mode, this commit add the information.

GitOrigin-RevId: b2543eb832ec4f2e562f136fc99cc7c283f8bea3
上级 9b74e6fc
......@@ -360,14 +360,25 @@ class Module(metaclass=ABCMeta):
loaded, skipped = self._load_state_dict_with_closure(closure)
unused = set(unused) - loaded
if strict and len(unused) != 0:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
if strict and len(skipped) != 0:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
if len(unused) != 0:
if strict:
raise KeyError(
"Unused params violate `strict=True`, unused={}".format(unused)
)
else:
logger.warning(
"Unused params in `strict=False` mode, unused={}".format(unused)
)
if len(skipped) != 0:
if strict:
raise KeyError(
"Missing params violate `strict=True`, missing={}".format(skipped)
)
else:
logger.warning(
"Missing params in `strict=False` mode, missing={}".format(skipped)
)
def _load_state_dict_with_closure(self, closure):
"""Advance state_dict load through callable `closure` whose signature is
......@@ -383,7 +394,6 @@ class Module(metaclass=ABCMeta):
for k, var in local_state_dict.items():
to_be_load = closure(k, var)
if to_be_load is None:
logger.warning("skip loading param `%s`", k)
skipped.append(k)
continue
assert isinstance(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册