提交 b8316de5 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): patch missed functional in functional.nn

GitOrigin-RevId: 5aa1316fb11dd897a7788aacb36bade57e747d44
上级 dbca3270
...@@ -211,6 +211,8 @@ class Patcher: ...@@ -211,6 +211,8 @@ class Patcher:
self.wrap_fn = wrap_fn self.wrap_fn = wrap_fn
for module in self._builtin_modules: for module in self._builtin_modules:
self.patch_module(module) self.patch_module(module)
# some functions in F.nn are import from other module, and not in __all__
self.auto_patch(F.nn.__dict__, False)
for meth in BUILTIN_ARRAY_METHOD: for meth in BUILTIN_ARRAY_METHOD:
self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
self.patch_method(Tensor, "detach", self.wrap_fn) self.patch_method(Tensor, "detach", self.wrap_fn)
...@@ -256,8 +258,8 @@ class Patcher: ...@@ -256,8 +258,8 @@ class Patcher:
self.patch_function(module.__dict__, k, self.wrap_fn) self.patch_function(module.__dict__, k, self.wrap_fn)
self.visited_frames_ids.add(id(module.__dict__)) self.visited_frames_ids.add(id(module.__dict__))
def auto_patch(self, frame_dict): def auto_patch(self, frame_dict, check_frame_id=True):
if id(frame_dict) not in self.visited_frames_ids: if id(frame_dict) not in self.visited_frames_ids or not check_frame_id:
for k, v in frame_dict.items(): for k, v in frame_dict.items():
if id(v) in self.patched_fn_ids: if id(v) in self.patched_fn_ids:
self.patch_function(frame_dict, k, self.wrap_fn) self.patch_function(frame_dict, k, self.wrap_fn)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册