diff --git a/imperative/python/megengine/traced_module/pytree.py b/imperative/python/megengine/traced_module/pytree.py index 0c62dc2830abab5371c76872f28b5dacadcaf88a..98d19f1e4a9be17bf9a2789b599ac1b3cb3ad7cf 100644 --- a/imperative/python/megengine/traced_module/pytree.py +++ b/imperative/python/megengine/traced_module/pytree.py @@ -9,6 +9,7 @@ import collections from collections import OrderedDict, defaultdict from functools import partial +from inspect import FullArgSpec from typing import Callable, NamedTuple import numpy as np @@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = { QuantMode, ArgsIndex, Group, + FullArgSpec, } USER_REGISTERED_LEAF_TYPE = [] diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index b4a7f9b1d802accb7fcb573929691f6b3fcc6c1e..edb9f2fbec8102ff435c9998dfdd4ec0ba7f64f2 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -1928,8 +1928,11 @@ class TracedModule(Module): self.watch_node_value = {} self.end_points = [] self.is_qat = is_qat + self.argspec = None def forward(self, *args, **kwargs): + if hasattr(self, "argspec") and self.argspec is not None: + args, kwargs = _convert_kwargs_to_args(self.argspec, args, kwargs, True) inputs, treedef = tree_flatten(((self, *args), kwargs)) assert treedef in self.argdef_graph_map inputs = filter( @@ -2422,8 +2425,12 @@ def trace_module( NodeMixin.wrap_safe( builder, Input.make(name="top", type=ModuleNode, qualname=net_name) ) - args, kwargs = _convert_kwargs_to_args(mod.forward, args, kwargs, True) - + forward_argspec = ( + mod.argspec + if hasattr(mod, "argspec") + else inspect.getfullargspec(mod.forward) + ) + args, kwargs = _convert_kwargs_to_args(forward_argspec, args, kwargs, True) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " @@ -2439,6 +2446,7 @@ def trace_module( builder(*args, **kwargs) active_module_tracer().pop_scope() traced_mod = builder.build() + traced_mod.argspec = forward_argspec traced_mod.graph._reset_ids() return traced_mod finally: diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index 9038a51183bfa7f32f067f4857448500fe0dab18..48094d5a4ecdc7c3fd28ebdb92265d5cfa5eb193 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -9,7 +9,8 @@ import collections import copy import inspect from collections.abc import MutableMapping, MutableSequence -from typing import Dict, Iterable, List, Optional, Sequence, Type +from inspect import FullArgSpec +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from .. import get_logger from ..module import Module @@ -57,9 +58,14 @@ def replace_container_with_module_container(container): return has_module, module_container -def _convert_kwargs_to_args(func, args, kwargs, is_bounded=False): +def _convert_kwargs_to_args( + argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False +): # is_bounded = True when func is a method and provided args don't include 'self' - arg_specs = inspect.getfullargspec(func) + arg_specs = ( + inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs + ) + assert isinstance(arg_specs, FullArgSpec) arg_specs_args = arg_specs.args if is_bounded: arg_specs_args = arg_specs.args[1:] diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index d18d08351240d7995e7b0dc512dee9f563b8b227..e4441c49f372cf4e8c3e4073874322f3d7349026 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -5,6 +5,7 @@ import numpy as np import megengine.functional as F import megengine.module as M from megengine import Tensor +from megengine.module.module import Module from megengine.traced_module import TracedModule, trace_module from megengine.traced_module.expr import CallFunction @@ -89,5 +90,46 @@ def test_trace_module(): m4 = MyModule4() tm4 = trace_module(m4, a, b) + np.testing.assert_equal(tm4(a, b).numpy(), 3) + np.testing.assert_equal(tm4(a, y=b).numpy(), 3) + np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) + + tm4 = trace_module(m4, a, y=b) + np.testing.assert_equal(tm4(a, b).numpy(), 3) + np.testing.assert_equal(tm4(a, y=b).numpy(), 3) + np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) + + tm4 = trace_module(m4, x=a, y=b) + np.testing.assert_equal(tm4(a, b).numpy(), 3) + np.testing.assert_equal(tm4(a, y=b).numpy(), 3) + np.testing.assert_equal(tm4(x=a, y=b).numpy(), 3) + + tm5 = trace_module(tm4, a, b) + np.testing.assert_equal(tm5(a, b).numpy(), 3) + np.testing.assert_equal(tm5(a, y=b).numpy(), 3) + np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) + + tm5 = trace_module(tm4, a, y=b) + np.testing.assert_equal(tm5(a, b).numpy(), 3) + np.testing.assert_equal(tm5(a, y=b).numpy(), 3) + np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) + + tm5 = trace_module(tm4, x=a, y=b) + np.testing.assert_equal(tm5(a, b).numpy(), 3) + np.testing.assert_equal(tm5(a, y=b).numpy(), 3) + np.testing.assert_equal(tm5(x=a, y=b).numpy(), 3) + assert len(tm4.graph._exprs) == 1 assert isinstance(tm4.graph._exprs[0], CallFunction) + + class MyModule5(Module): + def __init__(self): + super().__init__() + self.m1 = tm4 + + def forward(self, x, y): + return self.m1(x, y) + + tm6 = trace_module(MyModule5(), a, b) + assert tm6.m1.argspec is None + assert tm6.m1._is_top is False