提交 68d7320e 编写于 作者: M Megvii Engine Team

fix(mge/utils): fix using wrong function in register_hook_module

GitOrigin-RevId: 75097de1c1755ead246879cc3f276bbced4b5189
上级 3018ca51
...@@ -163,10 +163,10 @@ hook_modules = [ ...@@ -163,10 +163,10 @@ hook_modules = [
def register_hook_module(module): def register_hook_module(module):
if isinstance(module, (tuple, list)): if isinstance(module, (tuple, list)):
modules = list(module) modules = module
for module in modules: for module in modules:
register_hook_module(module) register_hook_module(module)
elif isinstance(module, M.Module): elif issubclass(module, M.Module):
hook_modules.append(module) hook_modules.append(module)
else: else:
raise TypeError("the param type should in [list,tuple,M.Module]") raise TypeError("the param type should in [list,tuple,M.Module]")
......
...@@ -10,7 +10,11 @@ import megengine.functional as F ...@@ -10,7 +10,11 @@ import megengine.functional as F
import megengine.hub as hub import megengine.hub as hub
import megengine.module as M import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats from megengine.utils.module_stats import (
hook_modules,
module_stats,
register_hook_module,
)
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -75,6 +79,7 @@ def test_getattribute_param(): ...@@ -75,6 +79,7 @@ def test_getattribute_param():
self.conv1 = M.Conv2d( self.conv1 = M.Conv2d(
3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True 3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
) )
self.conv1.reset_parameters()
self.bn1 = M.BatchNorm2d(self.in_channels) self.bn1 = M.BatchNorm2d(self.in_channels)
def forward(self, input): def forward(self, input):
...@@ -90,8 +95,10 @@ def test_getattribute_param(): ...@@ -90,8 +95,10 @@ def test_getattribute_param():
def get_name(obj): def get_name(obj):
return obj["name"] return obj["name"]
param_name = list(map(get_name, params)) param_names = list(map(get_name, params))
assert "conv1-w" in param_name and "conv1-b" in param_name assert "conv1-w" in param_names and "conv1-b" in param_names
conv1_b_param = params[param_names.index("conv1-b")]
assert int(conv1_b_param["mean"]) == 0 and int(conv1_b_param["std"]) == 0
class TestNet0(M.Module): class TestNet0(M.Module):
...@@ -493,3 +500,10 @@ def cal_pool_stats(module, inputs, outputs): ...@@ -493,3 +500,10 @@ def cal_pool_stats(module, inputs, outputs):
np.prod(outputs[0].shape) * (module.kernel_size ** 2), np.prod(outputs[0].shape) * (module.kernel_size ** 2),
np.prod(outputs[0].shape), np.prod(outputs[0].shape),
) )
def test_register_hook_module():
modules = [TestNet0, TestNet1, TestNet2, FakeNet, BasicBlock, ResNet]
register_hook_module(modules)
for module in modules:
assert module in hook_modules
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册