提交 f2f33565 编写于 作者: M Megvii Engine Team

fix(mge): fix some minor problems

GitOrigin-RevId: 43abda2ab9cad1b39830e00af774b1368436cb1d
上级 7591718d
......@@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True):
update_existing: whether to update existing loggers
"""
global _default_level # pylint: disable=global-statement
origin_level = _default_level
_default_level = level
if update_existing:
for i in _all_loggers:
i.setLevel(level)
return origin_level
_logger = get_logger(__name__)
......
......@@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module):
track_running_stats=True,
**kwargs
):
super().__init__()
super().__init__(**kwargs)
self.conv = Conv2d(
in_channels,
out_channels,
......
......@@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None):
parent = None
for k in key_list:
parent = cur
if isinstance(cur, (Tensor, Module)):
cur = getattr(cur, k)
elif isinstance(cur, (list, tuple)):
if isinstance(cur, (list, tuple)):
k = int(k)
cur = cur[k]
elif isinstance(cur, dict):
cur = cur[k]
else:
raise ValueError(
"Unsupport value type {} to access attribute".format(type(cur))
)
cur = getattr(cur, k)
return callback(parent, k, cur)
......@@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta):
v._name = k
elif v._name != k:
logger.warning(
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format(
v._name, k, v._name
"try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format(
type(v), type(self), k, v._name
)
)
super().__setattr__(name, value)
......
......@@ -111,10 +111,8 @@ class QParams:
return "QParams({})".format(content)
class LSQParams:
r"""To standardize LSQ's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""
class LSQParams(QParams):
r"""LSQ qparams with extra grad_scale slot."""
__slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"
......@@ -126,30 +124,9 @@ class LSQParams:
zero_point: Tensor,
grad_scale: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point
super().__init__(mode, dtype_meta, scale, zero_point)
self.grad_scale = grad_scale
def update(self, lsqparams: "LSQParams"):
for key in self.__slots__:
setattr(self, key, getattr(lsqparams, key))
def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True
def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "LSQParams({})".format(content)
class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self):
......
......@@ -642,7 +642,6 @@ class InternalGraph:
Returns:
A :class:`~.TracedModule.NodeFilterType`.
"""
assert issubclass(module_cls, Module)
return self.nodes(recursive).type(module_cls)
def get_node_by_id(self, node_id: List[int] = None, recursive=True):
......
......@@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence):
raise IndexError("list index out of range")
return rst if len(rst) > 1 else rst[0]
def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)
def __setitem__(self, idx: int, mod: Module):
if not isinstance(mod, Module):
raise ValueError("invalid sub-module")
......@@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping):
def __getitem__(self, key):
return getattr(self, key)
def __setattr__(self, key, value):
# clear mod name to avoid warning in Module's setattr
if isinstance(value, Module):
value._name = None
super().__setattr__(key, value)
def __setitem__(self, key, value):
if not isinstance(value, Module):
raise ValueError("invalid sub-module")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册