提交 6d1a4f20 编写于 作者: M Megvii Engine Team

feat(traced_module): support tracing submodules in list/dict

GitOrigin-RevId: 4076b47a89ff5fdbe7c94778a649f8a01d6cc0b6
上级 a3f9073c
......@@ -58,6 +58,7 @@ from .module_tracer import (
)
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, tree_flatten
from .utils import replace_container_with_module_container
logger = get_logger(__name__)
......@@ -988,7 +989,9 @@ class TracedModuleBuilder(NodeMixin):
if k not in TracedModuleBuilder.__builder_attributes__:
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(traced_module, k, v)
setattr(traced_module, k, v)
elif isinstance(v, RawTensor):
setattr(traced_module, k, v)
if isinstance(self._mod, QATModule):
unset_module_tracing()
......@@ -1146,7 +1149,16 @@ class TracedModuleBuilder(NodeMixin):
if id(attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(attr)]
if isinstance(attr, (List, Dict)):
unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr)
if m_container:
attr = m_container
if has_module and not m_container:
raise ValueError(
"Can not trace the module that uses the same container to store Module and Non-Module objects "
)
set_module_tracing()
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
......@@ -1178,17 +1190,22 @@ class TracedModuleBuilder(NodeMixin):
return object.__getattribute__(self, name)
else:
wrapped = object.__getattribute__(self, name)
class_members = dict(inspect.getmembers(self.__class__))
if name in self._mod.__dict__:
mod_attr = getattr(self._mod, name)
if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
wrapped = mod_attr
setattr(self, name, wrapped)
if isinstance(mod_attr, Module):
assert mod_attr is wrapped._mod
if name in class_members:
if (
not isinstance(wrapped, TracedModuleBuilder)
and wrapped is not mod_attr
):
wrapped = self.__getattr__(name)
if isinstance(wrapped, TracedModuleBuilder):
if not isinstance(mod_attr, (List, Dict)):
assert mod_attr is wrapped._mod
else:
assert mod_attr is wrapped
full_name = None
if id(mod_attr) in active_module_tracer().id2name:
full_name = active_module_tracer().id2name[id(mod_attr)]
......@@ -1679,7 +1696,6 @@ def _register_all_builtin_module():
isclass(m[1])
and issubclass(m[1], M.Module)
and m[1] is not M.Sequential
and m[1] is not M.ModuleList
):
module_tracer.register_as_builtin(m[1])
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence
from ...module import Module
def replace_container_with_module_container(container):
has_module = False
module_container = None
if isinstance(container, Dict):
m_dic = copy.copy(container)
for key, value in container.items():
if isinstance(value, Module):
has_module = True
elif isinstance(value, (List, Dict)):
(
_has_module,
_module_container,
) = replace_container_with_module_container(value)
m_dic[key] = _module_container
if _has_module:
has_module = True
if not all(isinstance(v, Module) for v in m_dic.values()):
return has_module, None
else:
return has_module, _ModuleDict(m_dic)
elif isinstance(container, List):
m_list = copy.copy(container)
for ind, value in enumerate(container):
if isinstance(value, Module):
has_module = True
elif isinstance(value, (List, Dict)):
(
_has_module,
_module_container,
) = replace_container_with_module_container(value)
m_list[ind] = _module_container
if _has_module:
has_module = True
if not all(isinstance(v, Module) for v in m_list):
return has_module, None
else:
return has_module, _ModuleList(m_list)
return has_module, module_container
class _ModuleList(Module, MutableSequence):
r"""
A List-like container.
Using a ``ModuleList``, one can visit, add, delete and modify submodules
just like an ordinary python list.
"""
def __init__(self, modules: Optional[Iterable[Module]] = None):
super().__init__()
self._size = 0
if modules is None:
return
for mod in modules:
self.append(mod)
@classmethod
def _ikey(cls, idx):
return "{}".format(idx)
def _check_idx(self, idx):
L = len(self)
if idx < 0:
idx = L + idx
if idx < 0 or idx >= L:
raise IndexError("list index out of range")
return idx
def __getitem__(self, idx: int):
if isinstance(idx, slice):
idx = range(self._size)[idx]
if not isinstance(idx, Sequence):
idx = [
idx,
]
rst = []
for i in idx:
i = self._check_idx(i)
key = self._ikey(i)
try:
rst.append(getattr(self, key))
except AttributeError:
raise IndexError("list index out of range")
return rst if len(rst) > 1 else rst[0]
def __setitem__(self, idx: int, mod: Module):
if not isinstance(mod, Module):
raise ValueError("invalid sub-module")
idx = self._check_idx(idx)
setattr(self, self._ikey(idx), mod)
def __delitem__(self, idx):
idx = self._check_idx(idx)
L = len(self)
for orig_idx in range(idx + 1, L):
new_idx = orig_idx - 1
self[new_idx] = self[orig_idx]
delattr(self, self._ikey(L - 1))
self._size -= 1
def __len__(self):
return self._size
def insert(self, idx, mod: Module):
assert isinstance(mod, Module)
L = len(self)
if idx < 0:
idx = L - idx
# clip idx to (0, L)
if idx > L:
idx = L
elif idx < 0:
idx = 0
for new_idx in range(L, idx, -1):
orig_idx = new_idx - 1
key = self._ikey(new_idx)
setattr(self, key, self[orig_idx])
key = self._ikey(idx)
setattr(self, key, mod)
self._size += 1
def forward(self):
raise RuntimeError("ModuleList is not callable")
class _ModuleDict(Module, MutableMapping):
r"""
A Dict-like container.
Using a ``ModuleDict``, one can visit, add, delete and modify submodules
just like an ordinary python dict.
"""
def __init__(self, modules: Optional[Dict[str, Module]] = None):
super().__init__()
self._size = 0
if modules is not None:
self.update(modules)
def __delitem__(self, key):
delattr(self, key)
self._size -= 1
def __getitem__(self, key):
return getattr(self, key)
def __setitem__(self, key, value):
if not isinstance(value, Module):
raise ValueError("invalid sub-module")
setattr(self, key, value)
self._size += 1
def __iter__(self):
return iter(self.keys())
def __len__(self):
return self._size
def items(self):
return dict(self.named_children()).items()
def values(self):
return dict(self.named_children()).values()
def keys(self):
return dict(self.named_children()).keys()
def forward(self):
raise RuntimeError("ModuleList is not callable")
import numpy as np
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module as M
from megengine.experimental.traced_module import TracedModule, trace_module
class MyModule1(M):
class MyModule1(M.Module):
def forward(self, x):
y = Tensor(x)
y += 1
......@@ -13,7 +13,7 @@ class MyModule1(M):
return x, y
class MyModule2(M):
class MyModule2(M.Module):
def forward(self, x):
y = Tensor([1, x, 1])
y += 1
......@@ -21,6 +21,23 @@ class MyModule2(M):
return x, y
class MyModule3(M.Module):
def __init__(self):
super().__init__()
self.modules = [
M.Elemwise("ADD"),
M.Elemwise("ADD"),
{"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")},
]
def forward(self, a, b):
x = self.modules[0](a, b)
y = self.modules[1](a, b)
y = self.modules[2]["a"](x, y)
y = self.modules[2]["b"](x, y)
return y
def test_trace_module():
x = Tensor(1)
......@@ -40,3 +57,13 @@ def test_trace_module():
for a, b in zip(output1, gt1):
np.testing.assert_equal(a.numpy(), b.numpy())
a, b = Tensor(1), Tensor(2)
m3 = MyModule3()
gt = m3(a, b)
tm3 = trace_module(m3, a, b)
out = tm3(a, b)
np.testing.assert_equal(out.numpy(), gt.numpy())
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise)
assert isinstance(tm3.modules.__dict__["2"], TracedModule)
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册