提交 786c36ff 编写于 作者: M Megvii Engine Team

fix(mge/tools): rename `net_stats` in function and examples to match file name

GitOrigin-RevId: 82a1377d6688f915d4f4a32354a3c3f8db712f9f
上级 9c90ce8c
...@@ -31,7 +31,7 @@ def visualize( ...@@ -31,7 +31,7 @@ def visualize(
): ):
r""" r"""
Load megengine dumped model and visualize graph structure with tensorboard log files. Load megengine dumped model and visualize graph structure with tensorboard log files.
Can also record and print model's statistics like :func:`~.net_stats` Can also record and print model's statistics like :func:`~.module_stats`
:param model_path: dir path for megengine dumped model. :param model_path: dir path for megengine dumped model.
:param log_path: dir path for tensorboard graph log. :param log_path: dir path for tensorboard graph log.
......
...@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20): ...@@ -187,7 +187,7 @@ def print_params_stats(params, bar_length_max=20):
return total_param_size return total_param_size
def net_stats( def module_stats(
model: m.Module, model: m.Module,
input_size: int, input_size: int,
bar_length_max: int = 20, bar_length_max: int = 20,
...@@ -212,7 +212,7 @@ def net_stats( ...@@ -212,7 +212,7 @@ def net_stats(
else: else:
return 4 return 4
def net_stats_hook(module, input, output, name=""): def module_stats_hook(module, input, output, name=""):
class_name = str(module.__class__).split(".")[-1].split("'")[0] class_name = str(module.__class__).split(".")[-1].split("'")[0]
flops_fun = CALC_FLOPS.get(type(module)) flops_fun = CALC_FLOPS.get(type(module))
...@@ -280,7 +280,7 @@ def net_stats( ...@@ -280,7 +280,7 @@ def net_stats(
for (name, module) in model.named_modules(): for (name, module) in model.named_modules():
if isinstance(module, hook_modules): if isinstance(module, hook_modules):
hooks.append( hooks.append(
module.register_forward_hook(partial(net_stats_hook, name=name)) module.register_forward_hook(partial(module_stats_hook, name=name))
) )
inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size] inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册