提交 01d2473c 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix TracedModule flatten

GitOrigin-RevId: 7b15fe492b4486d009d603227bec05485457a7da
上级 23c1fda7
......@@ -2078,9 +2078,7 @@ class TracedModule(Module):
for node, repl_node in repl_dict.items():
assert node in graph._inputs or node in graph._outputs
for i in node.users:
if i not in repl_node.users:
repl_node.users.append(i)
repl_node.users.extend(node.users)
rename_blacklist = list(chain(call.inputs, call.outputs))
......
......@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from collections import defaultdict
from itertools import chain
import numpy as np
......@@ -52,6 +53,25 @@ class MyModule(M.Module):
return x
class MyBlock1(M.Module):
def forward(self, a):
y = F.concat([a, a])
return a, y
class MyModule1(M.Module):
def __init__(self):
super().__init__()
self.block0 = MyBlock1()
self.block1 = MyBlock1()
def forward(self, a):
a, y1 = self.block0(a)
a = a + 1
a, y2 = self.block1(a)
return a, y1 + y2
class NewModule(M.Module):
def __init__(self, traced_module):
super(NewModule, self).__init__()
......@@ -64,6 +84,17 @@ class NewModule(M.Module):
return x
def _check_expr_users(traced_module):
node_user = defaultdict(list)
for expr in traced_module.graph._exprs:
for node in expr.inputs:
node_user[node].append(expr)
for node in traced_module.graph.nodes():
node.users.sort(key=lambda m: m._id)
node_user[node].sort(key=lambda m: m._id)
assert node.users == node_user[node]
def _init_cls(cls):
module = cls()
x = F.ones((1, 3, 3, 3))
......@@ -201,6 +232,10 @@ def test_flatten():
assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())
traced_module, x, expect = _init_cls(MyModule1)
traced_module = traced_module.flatten()
_check_expr_users(traced_module)
def test_id_and_name():
def _check_id(traced_module):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册