Commit a0b3a3c0 authored by Megvii Engine Team

feat(imperative): add TracedModule checker

GitOrigin-RevId: 12de7b278e28b7a3e37eb129c7f73c6660e8f300
Parent 19993070
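For context, a minimal usage sketch of the API this commit adds (the toy module and input shape below are illustrative, not part of the commit):

```python
import numpy as np
import megengine as mge
import megengine.module as M
from megengine.traced_module import trace_module, enable_expr_checker

class Net(M.Module):  # hypothetical toy module, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

# Verify each expr against eager execution while tracing; by default only
# the final outputs of the traced module are checked after tracing.
enable_expr_checker()
inp = mge.Tensor(np.random.random((1, 3, 32, 32)).astype("float32"))
traced = trace_module(Net(), inp)
```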
......@@ -10,6 +10,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .pytree import register_supported_type
from .tm_config import disable_default_checker, enable_expr_checker
from .traced_module import (
    TracedModule,
    _register_all_builtin_module,
......@@ -29,4 +30,6 @@ __all__ = [
"wrap",
"TracedModule",
"optimize",
"enable_expr_checker",
"disable_default_checker",
]
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import traceback
from typing import Sequence

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import ROIAlign, ROIPooling
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..tensor import Tensor
from .tm_config import _exclude_from_trace


class TracedModuleChecker:
    def __init__(self, tracer):
        self._active_node2values = []
        self.tracer = tracer
        self.node_without_tensor_info = {}

    def push_scope(self):
        self._active_node2values.append({})

    def pop_scope(self):
        self._active_node2values.pop()

    def current_node2values(self):
        return self._active_node2values[-1]

    def reset_checker(self):
        self._active_node2values = []

    def check_node_not_in_scope(self):
        # Report nodes that an expr references but whose values were never
        # recorded in any scope; returns True if any such node was seen.
        if self.node_without_tensor_info:
            for node, info in self.node_without_tensor_info.items():
                for expr in info[0]._exprs:
                    if node in expr.inputs or node in expr.outputs:
                        traceback.print_list(info[1])
                        raise ValueError(
                            "node({}) not in the graph:\n{}".format(node, info[0])
                        )
            return True
        else:
            return False

    def check_net_outputs(self, tm_res, gt_res):
        if isinstance(tm_res, Tensor):
            np.testing.assert_allclose(tm_res.numpy(), gt_res.numpy())
        elif isinstance(tm_res, Sequence):
            for i, j in zip(tm_res, gt_res):
                np.testing.assert_allclose(i.numpy(), j.numpy())
        else:
            for k in tm_res.__dict__.keys():
                np.testing.assert_allclose(
                    getattr(tm_res, k).numpy(), getattr(gt_res, k).numpy()
                )

    def record_nodemixin(self, node, value):
        self.current_node2values()[node] = value

    def record_node2value(self, node, value):
        # Snapshot the tensor through a Copy op (outside tracing) so that
        # later in-place updates cannot corrupt the recorded value.
        with _exclude_from_trace():
            self.current_node2values()[node] = apply(
                Copy(comp_node=value.device), value
            )[0]
            if isscalar(value):
                setscalar(self.current_node2values()[node])

    def check_apply_special_cases(self, opdef, num_outputs):
        # ROIAlign/ROIPooling in AVERAGE mode emit a trailing index output
        # that carries no comparable value, so exclude it from the check.
        indexs = list(range(num_outputs))
        if isinstance(opdef, ROIAlign) and opdef.mode == ROIAlign.Mode.AVERAGE:
            indexs.pop(-1)
        if isinstance(opdef, ROIPooling) and opdef.mode == ROIPooling.Mode.AVERAGE:
            indexs.pop(-1)
        return indexs

    def check_expr_results(self, expr_outputs, gt_outputs, indexs=None):
        expr_outputs = (
            (expr_outputs,) if not isinstance(expr_outputs, Sequence) else expr_outputs
        )
        gt_outputs = (
            (gt_outputs,) if not isinstance(gt_outputs, Sequence) else gt_outputs
        )
        if indexs is not None:
            for i in indexs:
                np.testing.assert_allclose(
                    expr_outputs[i].numpy(), gt_outputs[i].numpy()
                )
        else:
            np.testing.assert_allclose(expr_outputs, gt_outputs)

    def get_node2value(self, inputs, start_idx=0):
        inp_values = []
        has_node_not_in_scope = False
        for i in range(start_idx, len(inputs)):
            try:
                inp_values.append(self.current_node2values()[inputs[i]])
            except KeyError:
                has_node_not_in_scope = True
                self.node_without_tensor_info[inputs[i]] = [
                    self.tracer.current_scope(),
                    traceback.extract_stack(),
                ]
        return inp_values, has_node_not_in_scope

    def check_expr_interpret(self, expr, gt_outputs):
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as e:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from e

    def check_apply(self, expr, gt_outputs, opdef):
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs)
        if not has_node_not_in_scope:
            expr_res = expr.interpret(*ori_in)
            indexs = self.check_apply_special_cases(opdef, len(gt_outputs))
            try:
                self.check_expr_results(expr_res, gt_outputs, indexs=indexs)
            except Exception as e:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from e

    def check_builtin_module(self, module, expr, gt_outputs):
        ori_in, has_node_not_in_scope = self.get_node2value(expr.inputs, start_idx=1)
        if not has_node_not_in_scope:
            ori_in.insert(0, module)
            expr_res = expr.interpret(*ori_in)
            try:
                self.check_expr_results(expr_res, gt_outputs)
            except Exception as e:
                raise ValueError(
                    "Error occurred when checking expr: {}".format(expr)
                ) from e
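How the pieces fit: each traced (sub)module pushes a fresh node-to-value scope, expr outputs are recorded into the current scope, and inputs that were never recorded are collected for `check_node_not_in_scope`. A simplified standalone sketch of the scope mechanics (the dummy tracer and placeholder node are illustrative, not part of the commit):

```python
# Simplified illustration of the checker's scope mechanics.
checker = TracedModuleChecker(tracer=None)  # dummy tracer, illustration only

checker.push_scope()                 # entering a traced module
node = object()                      # stands in for a TensorNode
checker.record_nodemixin(node, 123)  # record a value for the node
assert checker.current_node2values()[node] == 123
checker.pop_scope()                  # leaving the module drops its scope
```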
......@@ -32,6 +32,7 @@ from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
from .serialization import _ModuleState
from .tm_config import _exclude_from_trace, _get_expr_checker
from .utils import _check_builtin_module_attr, _check_obj_attr, _convert_kwargs_to_args
......@@ -611,6 +612,8 @@ class Apply(Expr):
        inp_nodes = [NodeMixin.get(inputs[0])]
        for i in inputs[1:]:
            node = Constant.make(i)
            if _get_expr_checker():
                active_module_tracer().checker.record_node2value(node, Tensor(i))
            inp_nodes.append(node)
        apply_node = cls.make(opdef)
        for n in inp_nodes:
......@@ -624,11 +627,17 @@ class Apply(Expr):
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        outputs = list(map(Tensor, outputs))
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        if _get_expr_checker():
            with _exclude_from_trace():
                active_module_tracer().checker.check_apply(apply_node, outputs, opdef)
        return list(outputs)
......
......@@ -12,6 +12,7 @@ from .. import functional as F
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
from .checker import TracedModuleChecker
_active_module_tracer = None
......@@ -128,6 +129,7 @@ class module_tracer:
    def __init__(self, wrap_fn):
        self._active_scopes = []
        self.checker = TracedModuleChecker(self)
        self.patcher = Patcher(wrap_fn)

    @classmethod
......@@ -142,9 +144,11 @@ class module_tracer:
    def push_scope(self, scope):
        self._active_scopes.append(scope)
        self.checker.push_scope()

    def pop_scope(self):
        self._active_scopes.pop()
        self.checker.pop_scope()

    def current_scope(self):
        if self._active_scopes:
......
......@@ -18,6 +18,8 @@ from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..quantization.utils import QParams
from ..tensor import Tensor
from .module_tracer import active_module_tracer
from .tm_config import _get_expr_checker
from .utils import _check_obj_attr
logger = get_logger(__name__)
......@@ -343,6 +345,11 @@ class NodeMixin(abc.ABC):
                if isinstance(value, NodeMixin):
                    value._record_wrapped_nodes(node)
                setattr(value, "_NodeMixin__node", node)
                if _get_expr_checker():
                    if isinstance(value, RawTensor):
                        active_module_tracer().checker.record_node2value(node, value)
                    if isinstance(value, NodeMixin):
                        active_module_tracer().checker.record_nodemixin(node, value)
            else:
                assert callable(node)
                n = node()
......@@ -352,6 +359,11 @@ class NodeMixin(abc.ABC):
                if isinstance(value, NodeMixin):
                    value._record_wrapped_nodes(n)
                setattr(value, "_NodeMixin__node", n)
                if _get_expr_checker():
                    if isinstance(value, RawTensor):
                        active_module_tracer().checker.record_node2value(n, value)
                    if isinstance(value, NodeMixin):
                        active_module_tracer().checker.record_nodemixin(n, value)
    @classmethod
    def wrap_safe(cls, value, node):
......@@ -359,6 +371,11 @@ class NodeMixin(abc.ABC):
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        setattr(value, "_NodeMixin__node", node)
        if _get_expr_checker():
            if isinstance(value, RawTensor):
                active_module_tracer().checker.record_node2value(node, value)
            if isinstance(value, NodeMixin):
                active_module_tracer().checker.record_nodemixin(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)
......
......@@ -212,7 +212,11 @@ def tree_flatten(
    to reconstruct the pytree.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values), values
        assert is_leaf(
            values
        ), 'doesn\'t support {} type, use the "register_supported_type" method to register a self-defined type'.format(
            values
        )
        node = LeafDef(leaf_type(values))
        if is_const_leaf(values):
            node.const_val = values
......
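Both new error messages point users at `register_supported_type`, which teaches `tree_flatten` how to traverse a custom container. A hedged sketch follows; the flatten/unflatten signature is assumed from common pytree designs, so check megengine/traced_module/pytree.py for the exact API:

```python
from megengine.traced_module import register_supported_type

class Pair:  # hypothetical user-defined container
    def __init__(self, first, second):
        self.first = first
        self.second = second

# Assumed signature: (type, flatten, unflatten). flatten returns the child
# values plus auxiliary data; unflatten rebuilds the container from them.
register_supported_type(
    Pair,
    lambda p: ([p.first, p.second], None),
    lambda values, aux: Pair(values[0], values[1]),
)
```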
import contextlib

from ..core._imperative_rt.core2 import (
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)

_enable_expr_checker = False
_enable_default_checker = True


def _get_expr_checker():
    return _enable_expr_checker


def _get_default_checker():
    return _enable_default_checker


def enable_expr_checker():
    r"""Call this function to check the result of each expr during tracing."""
    global _enable_expr_checker, _enable_default_checker
    _enable_expr_checker = True
    # Each expr is already verified during tracing, so the final-output
    # check would be redundant.
    _enable_default_checker = False


def disable_default_checker():
    r"""Call this function to disable checking the final output of the model after tracing."""
    global _enable_default_checker
    _enable_default_checker = False


_enable_graph_surgery_mode = False


def _graph_surgery_mode():
    return _enable_graph_surgery_mode


def _set_graph_surgery_mode(mode: bool):
    global _enable_graph_surgery_mode
    pre_mode = _enable_graph_surgery_mode
    _enable_graph_surgery_mode = mode
    return pre_mode


@contextlib.contextmanager
def _exclude_from_trace():
    # Temporarily switch module tracing off so that tensor ops inside the
    # block are not recorded as exprs; restore the previous state on exit.
    is_tracing = is_tracing_module()
    if is_tracing:
        unset_module_tracing()
    try:
        yield
    finally:
        if is_tracing:
            set_module_tracing()
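Taken together, `enable_expr_checker()` switches on per-expr verification and switches the redundant final-output check off, while `disable_default_checker()` only skips the final check. A short illustrative sketch of the knobs:

```python
from megengine.traced_module import (
    disable_default_checker,
    enable_expr_checker,
)

# Default: trace_module replays the inputs through the traced module
# afterwards and compares its outputs against the eager results.

# Option 1: verify every expr during tracing instead.
enable_expr_checker()

# Option 2: skip the final-output comparison altogether.
disable_default_checker()
```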
......@@ -36,11 +36,14 @@ from .. import get_logger
from .. import module as M
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
    apply,
    is_tracing_module,
    set_module_tracing,
    unset_module_tracing,
)
from ..core._trace_option import set_symbolic_shape
from ..core.ops.builtin import Copy
from ..core.tensor.utils import isscalar, setscalar
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
......@@ -98,6 +101,13 @@ from .serialization import (
    load_call_tensor_method_expr,
    load_functional,
)
from .tm_config import (
    _exclude_from_trace,
    _get_default_checker,
    _get_expr_checker,
    _graph_surgery_mode,
    _set_graph_surgery_mode,
)
from .utils import (
    _check_builtin_module_attr,
    _check_obj_attr,
......@@ -117,26 +127,14 @@ def _is_builtin_name(name: str) -> bool:
def _is_leaf(node):
    assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
    assert isinstance(
        node, RawTensor
    ), 'doesn\'t support {} in return values, use Tensor or use the "register_supported_type" method to register a self-defined type'.format(
        type(node)
    )
    return isinstance(node, RawTensor)
_enable_graph_surgery_mode = False


def _graph_surgery_mode():
    return _enable_graph_surgery_mode


def _set_graph_surgery_mode(mode: bool):
    global _enable_graph_surgery_mode
    pre_mode = _enable_graph_surgery_mode
    _enable_graph_surgery_mode = mode
    return pre_mode
def _node_to_tensor(*args, **kwargs):
    tensors = []
    nodes, tree_def = tree_flatten((args, kwargs))
......@@ -1295,7 +1293,12 @@ def _wrapped_function(orig_func):
            return orig_func(*args, **kwargs)
        if isinstance(args[1], RawTensor):
            node = NodeMixin.get(inputs[1])
            inputs[1] = copy.copy(inputs[1])
            is_scalar = isscalar(inputs[1])
            inputs[1] = apply(
                Copy(comp_node=inputs[1].device), Tensor(inputs[1])
            )[0]
            if is_scalar:
                setscalar(inputs[1])
            # copy inputs[1] so that `tensor` and `Tensor(tensor)` do not share
            # the same m_tensor; otherwise they would end up with the same
            # _NodeMixin__node during tracing.
            NodeMixin.wrap_safe(inputs[1], node)
......@@ -1319,6 +1322,13 @@ def _wrapped_function(orig_func):
            else:
                outputs = None
            call_node.add_outputs(outputs)
            if _get_expr_checker():
                with _exclude_from_trace():
                    active_module_tracer().checker.check_expr_interpret(
                        call_node, outputs
                    )
            set_module_tracing()
            return rst
        return orig_func(*args, **kwargs)
......@@ -1500,6 +1510,12 @@ class TracedModuleBuilder(NodeMixin):
            unset_module_tracing()
            rst = self._mod(*args, **kwargs)
            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
            if _get_expr_checker():
                with _exclude_from_trace():
                    tmp = self.build()
                    active_module_tracer().checker.check_builtin_module(
                        tmp, callnode, outputs
                    )
            set_module_tracing()
            if self._is_builtin:
                self._body = None
......@@ -1674,7 +1690,9 @@ class TracedModuleBuilder(NodeMixin):
                if not isinstance(mod_attr, (List, Dict, QATModule)):
                    assert mod_attr is wrapped._mod
                else:
                    assert mod_attr is wrapped
                    assert (
                        mod_attr is wrapped
                    ), "TracedModule does not support modifying attributes, please check your code."

            if isinstance(wrapped, (NodeMixin, RawTensor)):
                NodeMixin.wrap(
......@@ -2469,11 +2487,23 @@ def trace_module(
qualname="{}.[{}]".format(net_name, "arg_{}".format(_)),
),
)
builder(*args, **kwargs)
rst = builder(*copy.deepcopy(args), **copy.deepcopy(kwargs))
active_module_tracer().pop_scope()
traced_mod = builder.build()
traced_mod.argspec = forward_argspec
traced_mod.graph._reset_ids()
has_expr_not_check = False
if _get_expr_checker():
has_expr_not_check = (
active_module_tracer().checker.check_node_not_in_scope()
)
if _get_default_checker() or has_expr_not_check:
with _exclude_from_trace():
tm_res = traced_mod(*args, **kwargs)
tm_res, _ = tree_flatten(tm_res, is_leaf=_is_leaf)
rst, _ = tree_flatten(rst, is_leaf=_is_leaf)
active_module_tracer().checker.check_net_outputs(tm_res, rst)
return traced_mod
finally:
set_symbolic_shape(use_sym_shape)
......
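`trace_module` now feeds deep copies of the inputs to the builder so that the untouched `args`/`kwargs` can be replayed through the traced module for the final comparison against `rst`. When that replay is undesirable, it can be disabled up front; a hedged example reusing the toy `Net` from the first sketch (the reproducibility scenario is illustrative, not from this commit):

```python
from megengine.traced_module import trace_module, disable_default_checker

# Skip the post-tracing replay-and-compare step, e.g. for a module whose
# outputs are not expected to be reproducible across two runs.
disable_default_checker()
traced = trace_module(Net(), inp)  # `Net` and `inp` as in the first sketch
```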
......@@ -5,16 +5,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import inspect
from collections.abc import MutableMapping, MutableSequence
from inspect import FullArgSpec
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from .. import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from ..tensor import Tensor
logger = get_logger(__name__)
......
......@@ -109,6 +109,7 @@ def build_observered_net(net: M.Module, observer_cls):
    )
    Q.enable_observer(qat_net)
    inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
    qat_net.eval()
    qat_net(inp)
    Q.disable_observer(qat_net)
    return qat_net
......@@ -116,6 +117,7 @@ def build_observered_net(net: M.Module, observer_cls):
def build_fakequanted_net(net: QATModule, fakequant_cls):
    qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls))
    qat_net.eval()
    return qat_net
......@@ -162,6 +164,7 @@ def test_load_param():
    def _check_module(build_func: Callable):
        net = build_func()
        net.eval()
        buffer = io.BytesIO()
        mge.save(net.state_dict(), buffer)
        buffer.seek(0)
......@@ -185,6 +188,7 @@ def test_load_param():
def test_qualname():
def _check_qualname(net):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
net.eval()
traced_net = trace_module(net, inp)
base_qualname = traced_net.graph.qualname
for node in traced_net.graph.nodes():
......