提交 38b7cfde 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix module states input is dict or others

GitOrigin-RevId: f9701b6134bf663345260e03f7f8a213a8fcb050
上级 16131359
......@@ -437,7 +437,21 @@ def module_stats(
has_inputs = True
if not isinstance(inputs, (tuple, list)):
inputs = [inputs]
inputs = [Tensor(input, dtype=np.float32) for input in inputs]
def load_tensor(x):
if isinstance(x, np.ndarray):
return Tensor(x)
elif isinstance(x, collections.abc.Mapping):
return {k: load_tensor(x) for k, v in x.items()}
elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple
return type(x)(*(load_tensor(value) for value in x))
elif isinstance(x, collections.abc.Sequence):
return [load_tensor(v) for v in x]
else:
return Tensor(x, dtype=np.float32)
inputs = load_tensor(inputs)
else:
if input_shapes:
if not isinstance(input_shapes[0], tuple):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册