diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 6505fe724fa00c60139280803f2d9c37adebe578..96b62a6e1ceb70befd3ade85f3b92925192caa97 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -185,7 +185,6 @@ class PatchedFn: class Patcher: - patched_fn_ids = set() _builtin_functions = [] _builtin_modules = [ F, @@ -207,6 +206,7 @@ class Patcher: ] def __init__(self, wrap_fn): + self.patched_fn_ids = set() self.patched_fn = [] self.visited_frames_ids = set() self.wrap_fn = wrap_fn diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 41d90de827f87f97a66d8b2b303c7f678762a006..8391159351f9396cf2e6788e3f4559c33c944ccc 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -17,6 +17,7 @@ import re import weakref from inspect import getcallargs, getmembers, isclass, ismethod from itertools import chain +from types import FunctionType from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from megengine import tensor @@ -1150,6 +1151,11 @@ class TracedModuleBuilder(NodeMixin): else: attr = getattr(self._mod, name) full_name = None + if ( + isinstance(attr, FunctionType) + and id(attr) in active_module_tracer().patcher.patched_fn_ids + ): + return active_module_tracer().patcher.wrap_fn(attr) if id(attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(attr)] diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index fd0d8f610d6cd1053365714d704a0cd4b05416c2..49432d66c7c3a9963205a0b44e68a5d467b92d7d 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -1,8 +1,10 @@ import numpy as np +import megengine.functional as F import megengine.module as M from megengine import Tensor from megengine.traced_module import TracedModule, trace_module +from megengine.traced_module.expr import CallFunction class MyModule1(M.Module): @@ -38,6 +40,15 @@ class MyModule3(M.Module): return y +class MyModule4(M.Module): + def __init__(self): + super().__init__() + self.add = F.add + + def forward(self, x, y): + return self.add(x, y) + + def test_trace_module(): x = Tensor(1) @@ -67,3 +78,8 @@ def test_trace_module(): assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) assert isinstance(tm3.modules.__dict__["2"], TracedModule) assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) + + m4 = MyModule4() + tm4 = trace_module(m4, a, b) + assert len(tm4.graph._exprs) == 1 + assert isinstance(tm4.graph._exprs[0], CallFunction)