提交 5a38ad39 编写于 作者: M Megvii Engine Team

feat(mge/utils): add get/set_expand_structure to deal with complex key

GitOrigin-RevId: 4d1b952068ffda21189f315ad70888dee80bc65f
上级 fad5bc74
......@@ -21,9 +21,9 @@ from ..utils.naming import auto_naming
logger = get_logger(__name__)
def _expand_structure(key, obj):
def _expand_structure(prefix, obj):
if isinstance(obj, (Tensor, Module)):
return [(key, obj)]
return [(prefix, obj)]
elif isinstance(obj, (list, tuple, dict)):
ret = []
if isinstance(obj, dict):
......@@ -37,12 +37,32 @@ def _expand_structure(key, obj):
"keys for Tensor and Module must be str, error key: {}".format(k)
)
for kt, vt in sub_ret:
ret.extend([(key + "." + kt, vt)])
ret.extend([(prefix + "." + kt, vt)])
return ret
else:
return []
def _access_structure(obj, key, callback=None):
key_list = key.split(".")
cur = obj
parent = None
for k in key_list:
parent = cur
if isinstance(cur, (Tensor, Module)):
cur = getattr(cur, k)
elif isinstance(cur, (list, tuple)):
k = int(k)
cur = cur[k]
elif isinstance(cur, dict):
cur = cur[k]
else:
raise ValueError(
"Unsupport value type {} to access attribute".format(type(cur))
)
return callback(parent, k, cur)
def _is_parameter(obj):
return isinstance(obj, Parameter)
......
......@@ -18,9 +18,9 @@ class Sequential(Module):
Alternatively, an ordered dict of modules can also be passed in.
To make it easier to understand, here is a small example:
Examples:
.. testcode::
import numpy as np
......
......@@ -7,7 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import copy, deepcopy
from functools import partial
from typing import Callable, Dict, Tuple
from typing import Callable
import numpy as np
......@@ -19,6 +19,7 @@ from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig
......@@ -79,11 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None):
module._flatten(with_key=True, with_parent=True, predicate=is_qat)
):
new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
set_expand_structure(parent, key, new_mod)
return module
......@@ -126,11 +123,7 @@ def quantize_qat(
continue
new_mod = convert_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)
set_expand_structure(parent, key, new_mod)
propagate_qconfig(module, qconfig)
return module
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import Iterable
from ..module import Sequential
from ..module.module import Module, _access_structure
from ..tensor import Tensor
def get_expand_structure(obj: Module, key: str):
"""
Gets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""
def f(_, __, cur):
return cur
return _access_structure(obj, key, callback=f)
def set_expand_structure(obj: Module, key: str, value):
"""
Sets Module's attribute compatible with complex key from Module's :meth:`~.named_children`.
Supports handling structure containing list or dict.
"""
def f(parent, key, cur):
if isinstance(parent, (Tensor, Module)):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
if isinstance(cur, Sequential):
parent[int(key)] = value
else:
setattr(parent, key, value)
else:
parent[key] = value
_access_structure(obj, key, callback=f)
......@@ -6,8 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import tempfile
from collections import OrderedDict
from io import BytesIO
......@@ -29,7 +27,9 @@ from megengine.module import (
Sequential,
Softmax,
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.utils.module_utils import get_expand_structure, set_expand_structure
class MLP(Module):
......@@ -45,146 +45,6 @@ class MLP(Module):
return x
def has_gpu(num=1):
try:
mgb.comp_node("gpu{}".format(num - 1))
except mgb.MegBrainError:
return False
return True
def randomNp(*args):
for arg in args:
assert isinstance(arg, int)
return np.random.random(args)
def randomTorch(*args):
import torch # pylint: disable=import-outside-toplevel
for arg in args:
assert isinstance(arg, int)
return torch.tensor(randomNp(*args), dtype=torch.float32)
def graph_mode(*modes):
if not set(modes).issubset({"eager", "static"}):
raise ValueError("graph mode must be in (eager, static)")
def decorator(func):
def wrapper(*args, **kwargs):
if "eager" in set(modes):
func(*args, **kwargs)
if "static" in set(modes):
with Graph() as cg:
cg.set_option("eager_evaluation", False)
func(*args, **kwargs)
return wrapper
return decorator
def _default_compare_fn(x, y):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
def opr_test(
cases,
func,
mode=("eager", "static", "dynamic_shape"),
compare_fn=_default_compare_fn,
ref_fn=None,
**kwargs
):
"""
mode: the list of test mode which are eager, static and dynamic_shape
will test all the cases if None.
func: the function to run opr.
compare_fn: the function to compare the result and expected, use np.testing.assert_allclose if None.
ref_fn: the function to generate expected data, should assign output if None.
cases: the list which have dict element, the list length should be 2 for dynamic shape test.
and the dict should have input,
and should have output if ref_fn is None.
should use list for multiple inputs and outputs for each case.
kwargs: The additional kwargs for opr func.
simple examples:
dtype = np.float32
cases = [{"input": [10, 20]}, {"input": [20, 30]}]
opr_test(cases,
F.eye,
ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
dtype=dtype)
"""
def check_results(results, expected):
if not isinstance(results, Tuple):
results = (results,)
for r, e in zip(results, expected):
compare_fn(r, e)
def get_trace_fn(func, enabled, symbolic):
jit.trace.enabled = enabled
return jit.trace(func, symbolic=symbolic)
def get_param(cases, idx):
case = cases[idx]
inp = case.get("input", None)
outp = case.get("output", None)
if inp is None:
raise ValueError("the test case should have input")
if not isinstance(inp, List):
inp = (inp,)
else:
inp = tuple(inp)
if ref_fn is not None and callable(ref_fn):
outp = ref_fn(*inp)
if outp is None:
raise ValueError("the test case should have output or reference function")
if not isinstance(outp, List):
outp = (outp,)
else:
outp = tuple(outp)
return inp, outp
if not set(mode).issubset({"eager", "static", "dynamic_shape"}):
raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")
if len(cases) == 0:
raise ValueError("should give one case at least")
if "dynamic_shape" in set(mode):
if len(cases) != 2:
raise ValueError("should give 2 cases for dynamic shape test")
if not callable(func):
raise ValueError("the input func should be callable")
inp, outp = get_param(cases, 0)
def run(*args, **kwargs):
return func(*args, **kwargs)
if "eager" in set(mode):
f = get_trace_fn(run, False, False)
results = f(*inp, **kwargs)
check_results(results, outp)
if "static" in set(mode) or "dynamic_shape" in set(mode):
f = get_trace_fn(run, True, True)
results = f(*inp, **kwargs)
check_results(results, outp)
if "dynamic_shape" in set(mode):
inp, outp = get_param(cases, 1)
results = f(*inp, **kwargs)
check_results(results, outp)
class MyModule(Module):
class InnerModule(Module):
def __init__(self):
......@@ -306,13 +166,13 @@ def test_module_api_hooks():
post_hook_num = 0
hooks = []
def pre_hook(module, inputs):
def pre_hook(_, inputs):
nonlocal pre_hook_num
pre_hook_num += 1
modified_inputs = tuple(inp + 1 for inp in inputs)
return modified_inputs
def post_hook(module, inputs, outputs):
def post_hook(_, __, outputs):
nonlocal post_hook_num
post_hook_num += 1
outputs += 1
......@@ -376,7 +236,7 @@ class MyModule2(Module):
def test_expand_structure():
m = MyModule2()
assert list(m.named_modules()) == [
rst = [
("", m),
("a.0", m.a[0]),
("a.1.x", m.a[1]["x"]),
......@@ -387,6 +247,16 @@ def test_expand_structure():
("a.2.0.bn", m.a[2][0].bn),
("bn", m.bn),
]
assert list(m.named_modules()) == rst
for item in rst[1:]:
assert get_expand_structure(m, item[0]) == item[1]
for item in reversed(rst[1:]):
if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)):
continue
set_expand_structure(m, item[0], "TEST_VALUE")
assert get_expand_structure(m, item[0]) == "TEST_VALUE"
def test_flatten_others():
......@@ -603,21 +473,6 @@ def test_pickle_module():
np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
@pytest.mark.skip(reason="under development")
def test_dump_model():
data_shape = (2, 28)
data = Tensor(np.random.random(data_shape))
mlp = MLP()
pred = mlp(data)
f = tempfile.NamedTemporaryFile(delete=False)
f_name = f.name
try:
mge.dump(pred, f_name)
finally:
f.close()
os.unlink(f_name)
def test_load_quantized():
from megengine.core.tensor import dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册