提交 814a4278 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/module): add `with_parent` argument in `_flatten`

GitOrigin-RevId: 1c88559ece139f7cdff3e155083af48353e7030b
上级 2a1e0624
......@@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta):
*,
recursive: bool = True,
with_key: bool = False,
with_parent: bool = False,
prefix: Optional[str] = None,
predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None
......@@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta):
:param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects.
:param with_parent: Whether to yield ``self`` along with yielded objects.
:param prefix: The prefix appended to the yielded keys.
:param predicate: The predicate function applied to scanned objects.
:param seen: A dict that records whether a module has been traversed yet.
......@@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta):
seen = set([id(self)])
module_dict = vars(self)
_prefix = "" if not prefix else prefix + "."
_prefix = "" if prefix is None else prefix + "."
for key in sorted(module_dict):
for expanded_key, leaf in _expand_structure(key, module_dict[key]):
......@@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta):
seen.add(leaf_id)
if predicate(leaf):
if with_key:
if with_key and with_parent:
yield _prefix + expanded_key, leaf, self
elif with_key:
yield _prefix + expanded_key, leaf
elif with_parent:
yield leaf, self
else:
yield leaf
......@@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta):
yield from leaf._flatten(
recursive=recursive,
with_key=with_key,
prefix=None if prefix is None else _prefix + expanded_key,
with_parent=with_parent,
prefix=_prefix + expanded_key if with_key else None,
predicate=predicate,
seen=seen,
)
def parameters(
self, requires_grad: Optional[bool] = None, recursive: bool = True
self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
) -> Iterable[Parameter]:
r"""Returns an iterable for the :class:`~.Parameter` of the module.
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`.Parameter`. ``None`` for
no limitation.
attribute of returned :class:`.Parameter`. ``None`` for no limitation.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct
attributes of this module.
module, else only returns :class:`~.Parameter` that are direct attributes
of this module.
"""
def predicate(obj) -> bool:
......@@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta):
requires_grad is None or obj.requires_grad == requires_grad
)
yield from self._flatten(predicate=predicate, recursive=recursive)
yield from self._flatten(
with_key=False, predicate=predicate, recursive=recursive, **kwargs
)
def named_parameters(
self,
requires_grad: Optional[bool] = None,
prefix: str = "",
prefix: Optional[str] = None,
recursive: bool = True,
**kwargs
) -> Iterable[Tuple[str, Parameter]]:
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Parameter` .
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`~.Parameter` . ``None`` for
no limitation.
attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
:param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct
attributes of this module.
module, else only returns :class:`~.Parameter` that are direct attributes
of this module.
"""
def predicate(obj) -> bool:
......@@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta):
)
yield from self._flatten(
with_key=True, prefix=prefix, predicate=predicate, recursive=recursive
with_key=True,
prefix=prefix,
predicate=predicate,
recursive=recursive,
**kwargs,
)
def buffers(self, recursive: bool = True) -> Iterable[Buffer]:
def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]:
"""Returns an iterable for the :class:`~.Buffer` of the module.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct
attributes of this module.
module, else only returns :class:`~.Buffer` that are direct attributes
of this module.
"""
yield from self._flatten(predicate=_is_buffer, recursive=recursive)
yield from self._flatten(
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
)
def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
......@@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta):
return offset
def named_buffers(
self, prefix: str = "", recursive: bool = True
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]:
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Buffer` .
:param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct
attributes of this module.
module, else only returns :class:`~.Buffer` that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive
with_key=True,
prefix=prefix,
predicate=_is_buffer,
recursive=recursive,
**kwargs,
)
def children(self) -> "Iterable[Module]":
def children(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the submodules that are direct attributes of this
module.
"""
yield from self._flatten(predicate=_is_module, recursive=False)
yield from self._flatten(
with_key=False, predicate=_is_module, recursive=False, **kwargs
)
def named_children(self) -> "Iterable[Tuple[str, Module]]":
def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-submodule pairs for all the submodules that are
direct attributes of this module, where 'key' is the attribute name of
submodules.
"""
yield from self._flatten(with_key=True, predicate=_is_module, recursive=False)
yield from self._flatten(
with_key=True, predicate=_is_module, recursive=False, **kwargs
)
def modules(self) -> "Iterable[Module]":
def modules(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the modules within this module, including itself.
"""
yield self
yield from self._flatten(predicate=_is_module)
if "with_parent" in kwargs and kwargs["with_parent"]:
yield self, None
else:
yield self
yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)
def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]":
def named_modules(
self, prefix: Optional[str] = None, **kwargs
) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-module pairs for all the modules within this
module, including itself, where 'key' is the dotted path from this module to the
submodules.
:param prefix: The prefix prepended to the path.
"""
yield prefix, self
yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module)
if "with_parent" in kwargs and kwargs["with_parent"]:
yield ("" if prefix is None else prefix), self, None
else:
yield ("" if prefix is None else prefix), self
yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_module, **kwargs
)
def apply(self, fn: "Callable[[Module], Any]") -> None:
"""Apply function ``fn`` to all the modules within this module, including
......
......@@ -53,9 +53,6 @@ class Optimizer(metaclass=ABCMeta):
if isinstance(params, (Parameter, dict)):
params = [params]
else:
assert isinstance(
params, Iterable
), "params argument given to the optimizer should be Parameter or dict"
if not isinstance(params, Iterable):
raise TypeError(
"params argument given to the optimizer should be "
......@@ -65,13 +62,15 @@ class Optimizer(metaclass=ABCMeta):
self.param_groups = [] # type: list
param_groups = list(params)
assert len(param_groups) != 0, "optimizer got an empty parameter list"
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
param_type = type(param_groups[0])
for param in param_groups:
assert isinstance(
param, param_type
), "types of params argument given to the optimizer shoud be same"
if not isinstance(param, param_type):
raise TypeError(
"types of params argument given to the optimizer shoud be same"
)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
......@@ -150,7 +149,7 @@ class Optimizer(metaclass=ABCMeta):
def backward(self, loss: Tensor):
"""Computes the back-propagation of the network given loss.
:param loss: The obtained loss tensor
:param loss: The obtained loss tensor
"""
rst = []
key = 0
......
......@@ -15,7 +15,7 @@ from helpers import MLP
import megengine as mge
from megengine.core import Buffer, Parameter, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose
......@@ -156,7 +156,7 @@ class MyModule2(Module):
return x
def test_mode_api_expand_structure():
def test_expand_structure():
m = MyModule2()
assert list(m.named_modules()) == [
("", m),
......@@ -171,6 +171,62 @@ def test_mode_api_expand_structure():
]
def test_flatten_with_parent():
m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [
("", m, None),
("a.0", m.a[0], m),
("a.1.x", m.a[1]["x"], m),
("a.1.y.0", m.a[1]["y"][0], m),
("a.1.y.1", m.a[1]["y"][1], m),
("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
("a.2.0", m.a[2][0], m),
("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
("bn", m.bn, m),
]
assert list(m.modules(with_parent=True)) == [
(m, None),
(m.a[0], m),
(m.a[1]["x"], m),
(m.a[1]["y"][0], m),
(m.a[1]["y"][1], m),
(m.a[1]["y"][1].bn, m.a[1]["y"][1]),
(m.a[2][0], m),
(m.a[2][0].bn, m.a[2][0]),
(m.bn, m),
]
class MyModule3(Module):
class InnerModule(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
def forward(self, x):
x = self.bn(x)
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)
def forward(self, x):
return x
def test_module_api_with_sequential():
m = MyModule3()
assert list(m.named_modules()) == [
("", m),
("bn", m.bn),
("seq", m.seq),
("seq.0", m.seq[0]),
("seq.1", m.seq[1]),
("seq.1.bn", m.seq[1].bn),
]
def test_state_dict():
data_shape = (2, 28)
data = tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册