提交 f75261d7 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix module_status error

GitOrigin-RevId: 9e004d98a17e408fd63b162f4f8fe868aaad57bd
上级 9fd2e663
...@@ -443,7 +443,7 @@ def module_stats( ...@@ -443,7 +443,7 @@ def module_stats(
if isinstance(x, np.ndarray): if isinstance(x, np.ndarray):
return Tensor(x) return Tensor(x)
elif isinstance(x, collections.abc.Mapping): elif isinstance(x, collections.abc.Mapping):
return {k: load_tensor(x) for k, v in x.items()} return {k: load_tensor(v) for k, v in x.items()}
elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple
return type(x)(*(load_tensor(value) for value in x)) return type(x)(*(load_tensor(value) for value in x))
elif isinstance(x, collections.abc.Sequence): elif isinstance(x, collections.abc.Sequence):
......
import collections
import math import math
from copy import deepcopy from copy import deepcopy
...@@ -27,6 +28,31 @@ def test_module_stats(): ...@@ -27,6 +28,31 @@ def test_module_stats():
assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,) assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts,)
@pytest.mark.skipif(
use_symbolic_shape(), reason="This test do not support symbolic shape.",
)
def test_other_input_module_state():
a = [1, 2]
b = {"1": 1, "2": 2}
nt = collections.namedtuple("nt", ["n", "t"])
_nt = nt(n=1, t=2)
net = FakeNet()
net(a)
net(b)
net(_nt)
class FakeNet(M.Module):
def __init__(self):
super().__init__()
def forward(self, x):
assert isinstance(
x,
(np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
) or (isinstance(x, tuple) and hasattr(x, "_fields"))
class BasicBlock(M.Module): class BasicBlock(M.Module):
expansion = 1 expansion = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册