提交 f7d05db7 编写于 作者: M Megvii Engine Team

fix(mge/module): fix named_tensors

GitOrigin-RevId: bb5aa1f41d3577c6b346c913337839d12f8c1559
上级 a29f1c8c
......@@ -67,6 +67,10 @@ def _is_parameter(obj):
return isinstance(obj, Parameter)
def _is_tensor(obj):
return isinstance(obj, Tensor)
def _is_buffer(obj):
return isinstance(obj, Tensor) and not isinstance(obj, Parameter)
......@@ -309,8 +313,9 @@ class Module(metaclass=ABCMeta):
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""
yield from self._flatten(with_key=False, recursive=recursive, **kwargs)
yield from self._flatten(
with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs
)
def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
......@@ -325,7 +330,11 @@ class Module(metaclass=ABCMeta):
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
with_key=True,
prefix=prefix,
predicate=_is_tensor,
recursive=recursive,
**kwargs,
)
def children(self, **kwargs) -> "Iterable[Module]":
......
......@@ -124,6 +124,30 @@ def test_module_api(test_traced_module):
("i.bn.weight", m.i.bn.weight),
("param", m.param),
]
assert list(m.tensors()) == [
m.bn.bias,
m.bn.running_mean,
m.bn.running_var,
m.bn.weight,
m.buff,
m.i.bn.bias,
m.i.bn.running_mean,
m.i.bn.running_var,
m.i.bn.weight,
m.param,
]
assert list(m.named_tensors()) == [
("bn.bias", m.bn.bias),
("bn.running_mean", m.bn.running_mean),
("bn.running_var", m.bn.running_var),
("bn.weight", m.bn.weight),
("buff", m.buff),
("i.bn.bias", m.i.bn.bias),
("i.bn.running_mean", m.i.bn.running_mean),
("i.bn.running_var", m.i.bn.running_var),
("i.bn.weight", m.i.bn.weight),
("param", m.param),
]
m.eval()
assert (
m.training == False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册