diff --git a/imperative/python/megengine/logger.py b/imperative/python/megengine/logger.py index f60fefe85967c5bbefbe47105216a0cc0ea064a3..9c3182af950e2e662888500647a675599f98a575 100644 --- a/imperative/python/megengine/logger.py +++ b/imperative/python/megengine/logger.py @@ -158,10 +158,12 @@ def set_log_level(level, update_existing=True): update_existing: whether to update existing loggers """ global _default_level # pylint: disable=global-statement + origin_level = _default_level _default_level = level if update_existing: for i in _all_loggers: i.setLevel(level) + return origin_level _logger = get_logger(__name__) diff --git a/imperative/python/megengine/module/conv_bn.py b/imperative/python/megengine/module/conv_bn.py index fdaabaa592f222a9ffc8bc29898d7a626045bf74..5d87a688f8ef5a0016001edce6e07529a3244ab3 100644 --- a/imperative/python/megengine/module/conv_bn.py +++ b/imperative/python/megengine/module/conv_bn.py @@ -32,7 +32,7 @@ class _ConvBnActivation2d(Module): track_running_stats=True, **kwargs ): - super().__init__() + super().__init__(**kwargs) self.conv = Conv2d( in_channels, out_channels, diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 593aab602dae56f35ea76ff8e4ba8d50c354aa53..e47359382505268f7047e2edc039eb4bf1be4e87 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -49,17 +49,13 @@ def _access_structure(obj, key, callback=None): parent = None for k in key_list: parent = cur - if isinstance(cur, (Tensor, Module)): - cur = getattr(cur, k) - elif isinstance(cur, (list, tuple)): + if isinstance(cur, (list, tuple)): k = int(k) cur = cur[k] elif isinstance(cur, dict): cur = cur[k] else: - raise ValueError( - "Unsupport value type {} to access attribute".format(type(cur)) - ) + cur = getattr(cur, k) return callback(parent, k, cur) @@ -650,8 +646,8 @@ class Module(metaclass=ABCMeta): v._name = k elif v._name != k: logger.warning( - "try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( - v._name, k, v._name + "try setting the submodule `{}` to `{}`'s new attribute `{}`, its name `{}` will remain unchanged".format( + type(v), type(self), k, v._name ) ) super().__setattr__(name, value) diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index e48edc992f7e657028083c924af9896947e7a93a..14c5c9252a1cc47495f8989737654c1f7e7bb136 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -111,10 +111,8 @@ class QParams: return "QParams({})".format(content) -class LSQParams: - r"""To standardize LSQ's qparams format. If custom - qparams is needed, inherit this class and add custom ``__slots__``. - """ +class LSQParams(QParams): + r"""LSQ qparams with extra grad_scale slot.""" __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" @@ -126,30 +124,9 @@ class LSQParams: zero_point: Tensor, grad_scale: Tensor, ): - self.mode = mode - self.dtype_meta = dtype_meta - self.scale = scale - self.zero_point = zero_point + super().__init__(mode, dtype_meta, scale, zero_point) self.grad_scale = grad_scale - def update(self, lsqparams: "LSQParams"): - for key in self.__slots__: - setattr(self, key, getattr(lsqparams, key)) - - def __eq__(self, other): - if len(self.__slots__) != len(other.__slots__): - return False - for key in self.__slots__: - if not hasattr(other, key) or getattr(self, key) != getattr(other, key): - return False - return True - - def __repr__(self): - content = ", ".join( - ["{}={}".format(key, getattr(self, key)) for key in self.__slots__] - ) - return "LSQParams({})".format(content) - class QParamsModuleMixin(abc.ABC): def get_quantized_dtype(self): diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 4f06de2726cb26912a6f5b5b43d951d6d6b17fa6..1000be7be59d0b70b54d37ce296de60bfaf90758 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -642,7 +642,6 @@ class InternalGraph: Returns: A :class:`~.TracedModule.NodeFilterType`. """ - assert issubclass(module_cls, Module) return self.nodes(recursive).type(module_cls) def get_node_by_id(self, node_id: List[int] = None, recursive=True): diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index f9c3c46e59a0144bd5678be96e10ce0c4e0ce2e7..67bb06b658fab01b342dc49fd712b9dda8ade286 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -96,6 +96,12 @@ class _ModuleList(Module, MutableSequence): raise IndexError("list index out of range") return rst if len(rst) > 1 else rst[0] + def __setattr__(self, key, value): + # clear mod name to avoid warning in Module's setattr + if isinstance(value, Module): + value._name = None + super().__setattr__(key, value) + def __setitem__(self, idx: int, mod: Module): if not isinstance(mod, Module): raise ValueError("invalid sub-module") @@ -159,6 +165,12 @@ class _ModuleDict(Module, MutableMapping): def __getitem__(self, key): return getattr(self, key) + def __setattr__(self, key, value): + # clear mod name to avoid warning in Module's setattr + if isinstance(value, Module): + value._name = None + super().__setattr__(key, value) + def __setitem__(self, key, value): if not isinstance(value, Module): raise ValueError("invalid sub-module")