提交 814a4278 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/module): add `with_parent` argument in `_flatten`

GitOrigin-RevId: 1c88559ece139f7cdff3e155083af48353e7030b
上级 2a1e0624
...@@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta): ...@@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta):
*, *,
recursive: bool = True, recursive: bool = True,
with_key: bool = False, with_key: bool = False,
with_parent: bool = False,
prefix: Optional[str] = None, prefix: Optional[str] = None,
predicate: Callable[[Any], bool] = lambda _: True, predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None seen: Optional[Set[int]] = None
...@@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta): ...@@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta):
:param recursive: Whether to recursively scan all the submodules. :param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects. :param with_key: Whether to yield keys along with yielded objects.
:param with_parent: Whether to yield ``self`` along with yielded objects.
:param prefix: The prefix appended to the yielded keys. :param prefix: The prefix appended to the yielded keys.
:param predicate: The predicate function applied to scanned objects. :param predicate: The predicate function applied to scanned objects.
:param seen: A dict that records whether a module has been traversed yet. :param seen: A dict that records whether a module has been traversed yet.
...@@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta): ...@@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta):
seen = set([id(self)]) seen = set([id(self)])
module_dict = vars(self) module_dict = vars(self)
_prefix = "" if not prefix else prefix + "." _prefix = "" if prefix is None else prefix + "."
for key in sorted(module_dict): for key in sorted(module_dict):
for expanded_key, leaf in _expand_structure(key, module_dict[key]): for expanded_key, leaf in _expand_structure(key, module_dict[key]):
...@@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta): ...@@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta):
seen.add(leaf_id) seen.add(leaf_id)
if predicate(leaf): if predicate(leaf):
if with_key: if with_key and with_parent:
yield _prefix + expanded_key, leaf, self
elif with_key:
yield _prefix + expanded_key, leaf yield _prefix + expanded_key, leaf
elif with_parent:
yield leaf, self
else: else:
yield leaf yield leaf
...@@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta): ...@@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta):
yield from leaf._flatten( yield from leaf._flatten(
recursive=recursive, recursive=recursive,
with_key=with_key, with_key=with_key,
prefix=None if prefix is None else _prefix + expanded_key, with_parent=with_parent,
prefix=_prefix + expanded_key if with_key else None,
predicate=predicate, predicate=predicate,
seen=seen, seen=seen,
) )
def parameters( def parameters(
self, requires_grad: Optional[bool] = None, recursive: bool = True self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
) -> Iterable[Parameter]: ) -> Iterable[Parameter]:
r"""Returns an iterable for the :class:`~.Parameter` of the module. r"""Returns an iterable for the :class:`~.Parameter` of the module.
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`.Parameter`. ``None`` for attribute of returned :class:`.Parameter`. ``None`` for no limitation.
no limitation.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this :param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct module, else only returns :class:`~.Parameter` that are direct attributes
attributes of this module. of this module.
""" """
def predicate(obj) -> bool: def predicate(obj) -> bool:
...@@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta): ...@@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta):
requires_grad is None or obj.requires_grad == requires_grad requires_grad is None or obj.requires_grad == requires_grad
) )
yield from self._flatten(predicate=predicate, recursive=recursive) yield from self._flatten(
with_key=False, predicate=predicate, recursive=recursive, **kwargs
)
def named_parameters( def named_parameters(
self, self,
requires_grad: Optional[bool] = None, requires_grad: Optional[bool] = None,
prefix: str = "", prefix: Optional[str] = None,
recursive: bool = True, recursive: bool = True,
**kwargs
) -> Iterable[Tuple[str, Parameter]]: ) -> Iterable[Tuple[str, Parameter]]:
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Parameter` . ``key`` is the dotted path from this module to the :class:`~.Parameter` .
:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`~.Parameter` . ``None`` for attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
no limitation.
:param prefix: The prefix prepended to the keys. :param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this :param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct module, else only returns :class:`~.Parameter` that are direct attributes
attributes of this module. of this module.
""" """
def predicate(obj) -> bool: def predicate(obj) -> bool:
...@@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta): ...@@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta):
) )
yield from self._flatten( yield from self._flatten(
with_key=True, prefix=prefix, predicate=predicate, recursive=recursive with_key=True,
prefix=prefix,
predicate=predicate,
recursive=recursive,
**kwargs,
) )
def buffers(self, recursive: bool = True) -> Iterable[Buffer]: def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]:
"""Returns an iterable for the :class:`~.Buffer` of the module. """Returns an iterable for the :class:`~.Buffer` of the module.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this :param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct module, else only returns :class:`~.Buffer` that are direct attributes
attributes of this module. of this module.
""" """
yield from self._flatten(predicate=_is_buffer, recursive=recursive) yield from self._flatten(
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
)
def replace_param( def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
...@@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta): ...@@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta):
return offset return offset
def named_buffers( def named_buffers(
self, prefix: str = "", recursive: bool = True self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]: ) -> Iterable[Tuple[str, Buffer]]:
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where """Returns an iterable for key :class:`~.Buffer` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Buffer` . ``key`` is the dotted path from this module to the :class:`~.Buffer` .
:param prefix: The prefix prepended to the keys. :param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this :param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct module, else only returns :class:`~.Buffer` that are direct attributes
attributes of this module. of this module.
""" """
yield from self._flatten( yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive with_key=True,
prefix=prefix,
predicate=_is_buffer,
recursive=recursive,
**kwargs,
) )
def children(self) -> "Iterable[Module]": def children(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the submodules that are direct attributes of this """Returns an iterable for all the submodules that are direct attributes of this
module. module.
""" """
yield from self._flatten(predicate=_is_module, recursive=False) yield from self._flatten(
with_key=False, predicate=_is_module, recursive=False, **kwargs
)
def named_children(self) -> "Iterable[Tuple[str, Module]]": def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-submodule pairs for all the submodules that are """Returns an iterable of key-submodule pairs for all the submodules that are
direct attributes of this module, where 'key' is the attribute name of direct attributes of this module, where 'key' is the attribute name of
submodules. submodules.
""" """
yield from self._flatten(with_key=True, predicate=_is_module, recursive=False) yield from self._flatten(
with_key=True, predicate=_is_module, recursive=False, **kwargs
)
def modules(self) -> "Iterable[Module]": def modules(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the modules within this module, including itself. """Returns an iterable for all the modules within this module, including itself.
""" """
yield self if "with_parent" in kwargs and kwargs["with_parent"]:
yield from self._flatten(predicate=_is_module) yield self, None
else:
yield self
yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)
def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]": def named_modules(
self, prefix: Optional[str] = None, **kwargs
) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-module pairs for all the modules within this """Returns an iterable of key-module pairs for all the modules within this
module, including itself, where 'key' is the dotted path from this module to the module, including itself, where 'key' is the dotted path from this module to the
submodules. submodules.
:param prefix: The prefix prepended to the path. :param prefix: The prefix prepended to the path.
""" """
yield prefix, self if "with_parent" in kwargs and kwargs["with_parent"]:
yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module) yield ("" if prefix is None else prefix), self, None
else:
yield ("" if prefix is None else prefix), self
yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_module, **kwargs
)
def apply(self, fn: "Callable[[Module], Any]") -> None: def apply(self, fn: "Callable[[Module], Any]") -> None:
"""Apply function ``fn`` to all the modules within this module, including """Apply function ``fn`` to all the modules within this module, including
......
...@@ -53,9 +53,6 @@ class Optimizer(metaclass=ABCMeta): ...@@ -53,9 +53,6 @@ class Optimizer(metaclass=ABCMeta):
if isinstance(params, (Parameter, dict)): if isinstance(params, (Parameter, dict)):
params = [params] params = [params]
else: else:
assert isinstance(
params, Iterable
), "params argument given to the optimizer should be Parameter or dict"
if not isinstance(params, Iterable): if not isinstance(params, Iterable):
raise TypeError( raise TypeError(
"params argument given to the optimizer should be " "params argument given to the optimizer should be "
...@@ -65,13 +62,15 @@ class Optimizer(metaclass=ABCMeta): ...@@ -65,13 +62,15 @@ class Optimizer(metaclass=ABCMeta):
self.param_groups = [] # type: list self.param_groups = [] # type: list
param_groups = list(params) param_groups = list(params)
assert len(param_groups) != 0, "optimizer got an empty parameter list" if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")
param_type = type(param_groups[0]) param_type = type(param_groups[0])
for param in param_groups: for param in param_groups:
assert isinstance( if not isinstance(param, param_type):
param, param_type raise TypeError(
), "types of params argument given to the optimizer shoud be same" "types of params argument given to the optimizer shoud be same"
)
if not isinstance(param_groups[0], dict): if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}] param_groups = [{"params": param_groups}]
...@@ -150,7 +149,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -150,7 +149,7 @@ class Optimizer(metaclass=ABCMeta):
def backward(self, loss: Tensor): def backward(self, loss: Tensor):
"""Computes the back-propagation of the network given loss. """Computes the back-propagation of the network given loss.
:param loss: The obtained loss tensor :param loss: The obtained loss tensor
""" """
rst = [] rst = []
key = 0 key = 0
......
...@@ -15,7 +15,7 @@ from helpers import MLP ...@@ -15,7 +15,7 @@ from helpers import MLP
import megengine as mge import megengine as mge
from megengine.core import Buffer, Parameter, tensor from megengine.core import Buffer, Parameter, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -156,7 +156,7 @@ class MyModule2(Module): ...@@ -156,7 +156,7 @@ class MyModule2(Module):
return x return x
def test_mode_api_expand_structure(): def test_expand_structure():
m = MyModule2() m = MyModule2()
assert list(m.named_modules()) == [ assert list(m.named_modules()) == [
("", m), ("", m),
...@@ -171,6 +171,62 @@ def test_mode_api_expand_structure(): ...@@ -171,6 +171,62 @@ def test_mode_api_expand_structure():
] ]
def test_flatten_with_parent():
m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [
("", m, None),
("a.0", m.a[0], m),
("a.1.x", m.a[1]["x"], m),
("a.1.y.0", m.a[1]["y"][0], m),
("a.1.y.1", m.a[1]["y"][1], m),
("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
("a.2.0", m.a[2][0], m),
("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
("bn", m.bn, m),
]
assert list(m.modules(with_parent=True)) == [
(m, None),
(m.a[0], m),
(m.a[1]["x"], m),
(m.a[1]["y"][0], m),
(m.a[1]["y"][1], m),
(m.a[1]["y"][1].bn, m.a[1]["y"][1]),
(m.a[2][0], m),
(m.a[2][0].bn, m.a[2][0]),
(m.bn, m),
]
class MyModule3(Module):
class InnerModule(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
def forward(self, x):
x = self.bn(x)
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)
def forward(self, x):
return x
def test_module_api_with_sequential():
m = MyModule3()
assert list(m.named_modules()) == [
("", m),
("bn", m.bn),
("seq", m.seq),
("seq.0", m.seq[0]),
("seq.1", m.seq[1]),
("seq.1.bn", m.seq[1].bn),
]
def test_state_dict(): def test_state_dict():
data_shape = (2, 28) data_shape = (2, 28)
data = tensor() data = tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册