提交 f2f33565 编写于 作者: M Megvii Engine Team

fix(mge): fix some minor problems

GitOrigin-RevId: 43abda2ab9cad1b39830e00af774b1368436cb1d
上级 7591718d
...@@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True): ...@@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True):
update_existing: whether to update existing loggers update_existing: whether to update existing loggers
""" """
global _default_level # pylint: disable=global-statement global _default_level # pylint: disable=global-statement
origin_level = _default_level
_default_level = level _default_level = level
if update_existing: if update_existing:
for i in _all_loggers: for i in _all_loggers:
i.setLevel(level) i.setLevel(level)
return origin_level
_logger = get_logger(__name__) _logger = get_logger(__name__)
......
...@@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module): ...@@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module):
track_running_stats=True, track_running_stats=True,
**kwargs **kwargs
): ):
super().__init__() super().__init__(**kwargs)
self.conv = Conv2d( self.conv = Conv2d(
in_channels, in_channels,
out_channels, out_channels,
......
...@@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None): ...@@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None):
parent = None parent = None
for k in key_list: for k in key_list:
parent = cur parent = cur
if isinstance(cur, (Tensor, Module)): if isinstance(cur, (list, tuple)):
cur = getattr(cur, k)
elif isinstance(cur, (list, tuple)):
k = int(k) k = int(k)
cur = cur[k] cur = cur[k]
elif isinstance(cur, dict): elif isinstance(cur, dict):
cur = cur[k] cur = cur[k]
else: else:
raise ValueError( cur = getattr(cur, k)
"Unsupport value type {} to access attribute".format(type(cur))
)
return callback(parent, k, cur) return callback(parent, k, cur)
...@@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta): ...@@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta):
v._name = k v._name = k
elif v._name != k: elif v._name != k:
logger.warning( logger.warning(
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
v._name, k, v._name type(v), type(self), k, v._name
) )
) )
super().__setattr__(name, value) super().__setattr__(name, value)
......
...@@ -111,10 +111,8 @@ class QParams: ...@@ -111,10 +111,8 @@ class QParams:
return "QParams({})".format(content) return "QParams({})".format(content)
class LSQParams: class LSQParams(QParams):
r"""To standardize LSQ's qparams format. If custom r"""LSQ qparams with extra grad_scale slot."""
qparams is needed, inherit this class and add custom ``__slots__``.
"""
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"
...@@ -126,30 +124,9 @@ class LSQParams: ...@@ -126,30 +124,9 @@ class LSQParams:
zero_point: Tensor, zero_point: Tensor,
grad_scale: Tensor, grad_scale: Tensor,
): ):
self.mode = mode super().__init__(mode, dtype_meta, scale, zero_point)
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
self.grad_scale = grad_scale self.grad_scale = grad_scale
def update(self, lsqparams: "LSQParams"):
for key in self.__slots__:
setattr(self, key, getattr(lsqparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "LSQParams({})".format(content)
class QParamsModuleMixin(abc.ABC): class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self): def get_quantized_dtype(self):
......
...@@ -642,7 +642,6 @@ class InternalGraph: ...@@ -642,7 +642,6 @@ class InternalGraph:
Returns: Returns:
A :class:`~.TracedModule.NodeFilterType`. A :class:`~.TracedModule.NodeFilterType`.
""" """
assert issubclass(module_cls, Module)
return self.nodes(recursive).type(module_cls) return self.nodes(recursive).type(module_cls)
def get_node_by_id(self, node_id: List[int] = None, recursive=True): def get_node_by_id(self, node_id: List[int] = None, recursive=True):
......
...@@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence): ...@@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence):
raise IndexError("list index out of range") raise IndexError("list index out of range")
return rst if len(rst) > 1 else rst[0] return rst if len(rst) > 1 else rst[0]
def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)
def __setitem__(self, idx: int, mod: Module): def __setitem__(self, idx: int, mod: Module):
if not isinstance(mod, Module): if not isinstance(mod, Module):
raise ValueError("invalid sub-module") raise ValueError("invalid sub-module")
...@@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping): ...@@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping):
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)
def __setitem__(self, key, value): def __setitem__(self, key, value):
if not isinstance(value, Module): if not isinstance(value, Module):
raise ValueError("invalid sub-module") raise ValueError("invalid sub-module")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册