提交 0185c1a9 编写于 作者: M Megvii Engine Team

feat(mge/traced_module): add argspec for top TracedModule

GitOrigin-RevId: 8e31a00c7e69b7efa15cfef6b7eee6861535eaea
上级 1daeba76
......@@ -9,6 +9,7 @@
import collections
from collections import OrderedDict, defaultdict
from functools import partial
from inspect import FullArgSpec
from typing import Callable, NamedTuple
import numpy as np
......@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
QuantMode,
ArgsIndex,
Group,
FullArgSpec,
}
USER_REGISTERED_LEAF_TYPE = []
......
......@@ -1928,8 +1928,11 @@ class TracedModule(Module):
self.watch_node_value = {}
self.end_points = []
self.is_qat = is_qat
self.argspec = None
def forward(self, *args, **kwargs):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map
inputs = filter(
......@@ -2422,8 +2425,12 @@ def trace_module(
NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
)
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True)
forward_argspec = (
mod.argspec
if hasattr(mod, "argspec")
else inspect.getfullargspec(mod.forward)
)
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support "
......@@ -2439,6 +2446,7 @@ def trace_module(
builder(*args, **kwargs)
active_module_tracer().pop_scope()
traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids()
return traced_mod
finally:
......
......@@ -9,7 +9,8 @@ import collections
import copy
import inspect
from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence, Type
from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from .. import get_logger
from ..module import Module
......@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
return has_module, module_container
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False):
def _convert_kwargs_to_args(
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func)
arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
)
assert isinstance(arg_specs, FullArgSpec)
arg_specs_args = arg_specs.args
if is_bounded:
arg_specs_args = arg_specs.args[1:]
......
......@@ -5,6 +5,7 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction
......@@ -89,5 +90,46 @@ def test_trace_module():
m4 = MyModule4()
tm4 = trace_module(m4, a, b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm4 = trace_module(m4, a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm4 = trace_module(m4, x=a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, a, b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, x=a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
assert len(tm4.graph._exprs) == 1
assert isinstance(tm4.graph._exprs[0], CallFunction)
class MyModule5(Module):
def __init__(self):
super().__init__()
self.m1 = tm4
def forward(self, x, y):
return self.m1(x, y)
tm6 = trace_module(MyModule5(), a, b)
assert tm6.m1.argspec is None
assert tm6.m1._is_top is False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册