提交 b2944559 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(imperative/module): remove ``__getattribute__`` method in module

GitOrigin-RevId: 5ac525f010eb0d23d3f83788709fe1360f03e9f3
上级 77ead937
......@@ -609,14 +609,6 @@ class Module(metaclass=ABCMeta):
return set(loaded), set(skipped)
def __getattribute__(self, name: str):
value = super().__getattribute__(name)
if name == "__dict__":
return value
for prefix, variable in _expand_structure(name, value):
variable._name = prefix
return value
def __setattr__(self, name: str, value):
is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict))
if name != "_modules":
......@@ -631,6 +623,15 @@ class Module(metaclass=ABCMeta):
else:
if modules is not None and name in modules:
modules.remove(name)
for k, v in _expand_structure(name, value):
if not v._name:
v._name = k
else:
logger.warning(
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format(
v._name, k, v._name
)
)
super().__setattr__(name, value)
def __delattr__(self, name: str):
......
......@@ -368,10 +368,10 @@ class AssertModule(Module):
def test_assert_message():
m = AssertModule()
with pytest.raises(
AssertionError, match="keys for Tensor and Module must be str, error key: True"
):
m = AssertModule()
list(m._flatten())
......
......@@ -155,13 +155,13 @@ def test_with_submodule_in_container(symbolic):
m = Simple("simple")
ops = _dump_and_load(m, symbolic)
assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD"
assert ops[-1].name == "simple.l2.l2-1.ADD"
assert ops[-2].name == "simple.l2.l2-1.MatrixMul"
assert ops[-3].name == "simple.l1.1.ADD"
assert ops[-4].name == "simple.l1.1.MatrixMul"
assert ops[-5].name == "simple.l0.1.ADD"
assert ops[-6].name == "simple.l0.1.MatrixMul"
assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]"
assert ops[-1].name == "simple.l0.1.ADD[2]"
assert ops[-2].name == "simple.l0.1.MatrixMul[2]"
assert ops[-3].name == "simple.l0.1.ADD[1]"
assert ops[-4].name == "simple.l0.1.MatrixMul[1]"
assert ops[-5].name == "simple.l0.1.ADD[0]"
assert ops[-6].name == "simple.l0.1.MatrixMul[0]"
@pytest.mark.parametrize("symbolic", [False, True])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册