diff --git a/imperative/python/megengine/traced_module/module_tracer.py b/imperative/python/megengine/traced_module/module_tracer.py index 59eb5531719494a9052b17f9dd2cf856c2d800f4..a0d5dceb945996b1f95698d253299213e343bf65 100644 --- a/imperative/python/megengine/traced_module/module_tracer.py +++ b/imperative/python/megengine/traced_module/module_tracer.py @@ -211,6 +211,8 @@ class Patcher: self.wrap_fn = wrap_fn for module in self._builtin_modules: self.patch_module(module) + # some functions in F.nn are import from other module, and not in __all__ + self.auto_patch(F.nn.__dict__, False) for meth in BUILTIN_ARRAY_METHOD: self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) self.patch_method(Tensor, "detach", self.wrap_fn) @@ -256,8 +258,8 @@ class Patcher: self.patch_function(module.__dict__, k, self.wrap_fn) self.visited_frames_ids.add(id(module.__dict__)) - def auto_patch(self, frame_dict): - if id(frame_dict) not in self.visited_frames_ids: + def auto_patch(self, frame_dict, check_frame_id=True): + if id(frame_dict) not in self.visited_frames_ids or not check_frame_id: for k, v in frame_dict.items(): if id(v) in self.patched_fn_ids: self.patch_function(frame_dict, k, self.wrap_fn)