提交 0185c1a9 编写于 作者: M Megvii Engine Team

feat(mge/traced_module): add argspec for top TracedModule

GitOrigin-RevId: 8e31a00c7e69b7efa15cfef6b7eee6861535eaea
上级 1daeba76
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import collections import collections
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from functools import partial from functools import partial
from inspect import FullArgSpec
from typing import Callable, NamedTuple from typing import Callable, NamedTuple
import numpy as np import numpy as np
...@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { ...@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
QuantMode, QuantMode,
ArgsIndex, ArgsIndex,
Group, Group,
FullArgSpec,
} }
USER_REGISTERED_LEAF_TYPE = [] USER_REGISTERED_LEAF_TYPE = []
......
...@@ -1928,8 +1928,11 @@ class TracedModule(Module): ...@@ -1928,8 +1928,11 @@ class TracedModule(Module):
self.watch_node_value = {} self.watch_node_value = {}
self.end_points = [] self.end_points = []
self.is_qat = is_qat self.is_qat = is_qat
self.argspec = None
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
if hasattr(self, "argspec") and self.argspec is not None:
args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True)
inputs, treedef = tree_flatten(((self, *args), kwargs)) inputs, treedef = tree_flatten(((self, *args), kwargs))
assert treedef in self.argdef_graph_map assert treedef in self.argdef_graph_map
inputs = filter( inputs = filter(
...@@ -2422,8 +2425,12 @@ def trace_module( ...@@ -2422,8 +2425,12 @@ def trace_module(
NodeMixin.wrap_safe( NodeMixin.wrap_safe(
builder, Input.make(name="top", type=ModuleNode, qualname=net_name) builder, Input.make(name="top", type=ModuleNode, qualname=net_name)
) )
args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) forward_argspec = (
mod.argspec
if hasattr(mod, "argspec")
else inspect.getfullargspec(mod.forward)
)
args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True)
inputs, _ = tree_flatten((args, kwargs)) inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
# assert isinstance(i, Tensor), "not support " # assert isinstance(i, Tensor), "not support "
...@@ -2439,6 +2446,7 @@ def trace_module( ...@@ -2439,6 +2446,7 @@ def trace_module(
builder(*args, **kwargs) builder(*args, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
traced_mod = builder.build() traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids() traced_mod.graph._reset_ids()
return traced_mod return traced_mod
finally: finally:
......
...@@ -9,7 +9,8 @@ import collections ...@@ -9,7 +9,8 @@ import collections
import copy import copy
import inspect import inspect
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence, Type from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from .. import get_logger from .. import get_logger
from ..module import Module from ..module import Module
...@@ -57,9 +58,14 @@ def replace_container_with_module_container(container): ...@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
return has_module, module_container return has_module, module_container
def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): def _convert_kwargs_to_args(
argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
):
# is_bounded = True when func is a method and provided args don't include 'self' # is_bounded = True when func is a method and provided args don't include 'self'
arg_specs = inspect.getfullargspec(func) arg_specs = (
inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
)
assert isinstance(arg_specs, FullArgSpec)
arg_specs_args = arg_specs.args arg_specs_args = arg_specs.args
if is_bounded: if is_bounded:
arg_specs_args = arg_specs.args[1:] arg_specs_args = arg_specs.args[1:]
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.module.module import Module
from megengine.traced_module import TracedModule, trace_module from megengine.traced_module import TracedModule, trace_module
from megengine.traced_module.expr import CallFunction from megengine.traced_module.expr import CallFunction
...@@ -89,5 +90,46 @@ def test_trace_module(): ...@@ -89,5 +90,46 @@ def test_trace_module():
m4 = MyModule4() m4 = MyModule4()
tm4 = trace_module(m4, a, b) tm4 = trace_module(m4, a, b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm4 = trace_module(m4, a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm4 = trace_module(m4, x=a, y=b)
np.testing.assert_equal(tm4(a, b).numpy(), 3)
np.testing.assert_equal(tm4(a, y=b).numpy(), 3)
np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, a, b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
tm5 = trace_module(tm4, x=a, y=b)
np.testing.assert_equal(tm5(a, b).numpy(), 3)
np.testing.assert_equal(tm5(a, y=b).numpy(), 3)
np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3)
assert len(tm4.graph._exprs) == 1 assert len(tm4.graph._exprs) == 1
assert isinstance(tm4.graph._exprs[0], CallFunction) assert isinstance(tm4.graph._exprs[0], CallFunction)
class MyModule5(Module):
def __init__(self):
super().__init__()
self.m1 = tm4
def forward(self, x, y):
return self.m1(x, y)
tm6 = trace_module(MyModule5(), a, b)
assert tm6.m1.argspec is None
assert tm6.m1._is_top is False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册