@@ -604,13 +604,17 @@ class TracedModuleBuilder(NodeMixin):
"_NodeMixin__node",
"_is_builtin",
"build",
"_record_wrapped_nodes",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
"__class__",
"__dict__",
]
def__init__(self,mod,is_top_module=False):
super(TracedModuleBuilder,self).__init__()
assertisinstance(mod,Module)
self._mod=mod
self._body=None
self._is_top=is_top_module
...
...
@@ -618,6 +622,13 @@ class TracedModuleBuilder(NodeMixin):
self._argdef_graph_map={}
self._argdef_outdef_map={}
self.nodes=set()
# The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__.
# modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__.
self.__class__=type(
"TracedModuleBuilder",
(TracedModuleBuilder,mod.__class__),
dict(TracedModuleBuilder.__dict__),
)
defbuild(self):
ifself._is_builtin:
...
...
@@ -631,8 +642,6 @@ class TracedModuleBuilder(NodeMixin):