提交 786c36ff 编写于 作者: M Megvii Engine Team

fix(mge/tools): rename `net_stats` in function and examples to match file name

GitOrigin-RevId: 82a1377d6688f915d4f4a32354a3c3f8db712f9f
上级 9c90ce8c
......@@ -31,7 +31,7 @@ def visualize(
):
r"""
Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats`
Can also record and print model's statistics like :func:`~.module_stats`
:param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log.
......
......@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20):
return total_param_size
def net_stats(
def module_stats(
model: m.Module,
input_size: int,
bar_length_max: int = 20,
......@@ -212,7 +212,7 @@ def net_stats(
else:
return 4
def net_stats_hook(module, input, output, name=""):
def module_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_fun = CALC_FLOPS.get(type(module))
......@@ -280,7 +280,7 @@ def net_stats(
for (name, module) in model.named_modules():
if isinstance(module, hook_modules):
hooks.append(
module.register_forward_hook(partial(net_stats_hook, name=name))
module.register_forward_hook(partial(module_stats_hook, name=name))
)
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册