diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index fe46c0ea64d5401b75339eb3fa5f51cae2160cec..bc86fb42b2f7dd84fc17a147dc744c1a00e60e21 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -437,7 +437,21 @@ def module_stats( has_inputs = True if not isinstance(inputs, (tuple, list)): inputs = [inputs] - inputs = [Tensor(input, dtype=np.float32) for input in inputs] + + def load_tensor(x): + if isinstance(x, np.ndarray): + return Tensor(x) + elif isinstance(x, collections.abc.Mapping): + return {k: load_tensor(x) for k, v in x.items()} + elif isinstance(x, tuple) and hasattr(x, "_fields"): # nametuple + return type(x)(*(load_tensor(value) for value in x)) + elif isinstance(x, collections.abc.Sequence): + return [load_tensor(v) for v in x] + else: + return Tensor(x, dtype=np.float32) + + inputs = load_tensor(inputs) + else: if input_shapes: if not isinstance(input_shapes[0], tuple):