From 91264f3797ce8949a73827d02e99c33759106ddb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 13 Sep 2021 21:47:32 +0800 Subject: [PATCH] fix(traced_module): fix __getattr__ of TracedModuleBuilder GitOrigin-RevId: 94d91d6938ac3ee50fc0d8a72e3d36031b2ef34b --- .../megengine/traced_module/module_tracer.py | 2 +- .../megengine/traced_module/traced_module.py | 6 ++++++ .../test/unit/traced_module/test_trace_module.py | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 6505fe724..96b62a6e1 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -185,7 +185,6 @@ class PatchedFn: class Patcher: - patched_fn_ids = set() _builtin_functions = [] _builtin_modules = [ F, @@ -207,6 +206,7 @@ class Patcher: ] def __init__(self, wrap_fn): + self.patched_fn_ids = set() self.patched_fn = [] self.visited_frames_ids = set() self.wrap_fn = wrap_fn diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 41d90de82..839115935 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -17,6 +17,7 @@ import re import weakref from inspect import getcallargs, getmembers, isclass, ismethod from itertools import chain +from types import FunctionType from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from megengine import tensor @@ -1150,6 +1151,11 @@ class TracedModuleBuilder(NodeMixin): else: attr = getattr(self._mod, name) full_name = None + if ( + isinstance(attr, FunctionType) + and id(attr) in active_module_tracer().patcher.patched_fn_ids + ): + return active_module_tracer().patcher.wrap_fn(attr) if id(attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(attr)] diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index fd0d8f610..49432d66c 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -1,8 +1,10 @@ import numpy as np +import megengine.functional as F import megengine.module as M from megengine import Tensor from megengine.traced_module import TracedModule, trace_module +from megengine.traced_module.expr import CallFunction class MyModule1(M.Module): @@ -38,6 +40,15 @@ class MyModule3(M.Module): return y +class MyModule4(M.Module): + def __init__(self): + super().__init__() + self.add = F.add + + def forward(self, x, y): + return self.add(x, y) + + def test_trace_module(): x = Tensor(1) @@ -67,3 +78,8 @@ def test_trace_module(): assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) assert isinstance(tm3.modules.__dict__["2"], TracedModule) assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) + + m4 = MyModule4() + tm4 = trace_module(m4, a, b) + assert len(tm4.graph._exprs) == 1 + assert isinstance(tm4.graph._exprs[0], CallFunction) -- GitLab