提交 97c90d91 编写于 作者: M Megvii Engine Team

feat(traced_module): add _exclude_from_trace

GitOrigin-RevId: 615b769a02779547cf778aebf431d84d7d625179
上级 30e565e5
......@@ -145,9 +145,8 @@ def _node_to_tensor(*args, **kwargs):
value = n.value
if value is None:
flag = _set_graph_surgery_mode(False)
unset_module_tracing()
value = F.zeros(shape=n._shape, dtype=n._dtype)
set_module_tracing()
with _exclude_from_trace():
value = F.zeros(shape=n._shape, dtype=n._dtype)
_set_graph_surgery_mode(flag)
orig_n = NodeMixin.get(value, None)
if orig_n is None or "setitem" not in orig_n._name:
......@@ -1274,8 +1273,10 @@ def _wrapped_function(orig_func):
@functools.wraps(orig_func)
def wrapped_fn(*args, **kwargs):
method_func = kwargs.pop("method_func", wrapped_fn)
if is_tracing_module():
unset_module_tracing()
if not is_tracing_module():
return orig_func(*args, **kwargs)
with _exclude_from_trace():
inputs, tree_def = tree_flatten((args, kwargs))
for i in inputs:
if not NodeMixin.get(i, None):
......@@ -1290,7 +1291,6 @@ def _wrapped_function(orig_func):
if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]):
# only trace Tensor.__new__() when there are tensors in args
set_module_tracing()
return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1])
......@@ -1327,9 +1327,7 @@ def _wrapped_function(orig_func):
call_node, outputs
)
set_module_tracing()
return rst
return orig_func(*args, **kwargs)
return wrapped_fn
......@@ -1339,8 +1337,8 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
_argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[TreeDef, TreeDef]
nodes = None
__builder_attributes__ = [
......@@ -1371,9 +1369,8 @@ class TracedModuleBuilder(NodeMixin):
else module_tracer.is_builtin(mod)
)
if isinstance(self._mod, QATModule):
unset_module_tracing()
self._check_qat_module(self._mod)
set_module_tracing()
with _exclude_from_trace():
self._check_qat_module(self._mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
......@@ -1458,18 +1455,17 @@ class TracedModuleBuilder(NodeMixin):
setattr(traced_module, k, v)
if isinstance(self._mod, QATModule):
unset_module_tracing()
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None
set_module_tracing()
with _exclude_from_trace():
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None
if self._is_top:
traced_module._update_ref()
......@@ -1505,16 +1501,14 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def
if self._is_builtin or tree_def in self._argdef_graph_map:
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
if _get_expr_checker():
with _exclude_from_trace():
with _exclude_from_trace():
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
if _get_expr_checker():
tmp = self.build()
active_module_tracer().checker.check_builtin_module(
tmp, callnode, outputs
)
set_module_tracing()
if self._is_builtin:
self._body = None
elif tree_def in self._argdef_graph_map:
......@@ -1640,16 +1634,17 @@ class TracedModuleBuilder(NodeMixin):
if isinstance(attr, (List, Dict)):
flag = _set_graph_surgery_mode(False)
unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr)
if m_container:
attr = m_container
if has_module and not m_container:
raise ValueError(
"Can not trace the module that uses the same container to store"
" Module and Non-Module objects."
with _exclude_from_trace():
has_module, m_container = replace_container_with_module_container(
attr
)
set_module_tracing()
if m_container:
attr = m_container
if has_module and not m_container:
raise ValueError(
"Can not trace the module that uses the same container to store"
" Module and Non-Module objects."
)
_set_graph_surgery_mode(flag)
if isinstance(attr, Module):
......
......@@ -5,8 +5,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
import contextlib
import inspect
from collections.abc import MutableMapping, MutableSequence
from inspect import FullArgSpec
......@@ -71,7 +71,9 @@ def _convert_kwargs_to_args(
arg_specs_args = arg_specs.args
arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else []
arg_specs_kwonlyargs = arg_specs.kwonlyargs
arg_specs_kwonlydefaults = arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
arg_specs_kwonlydefaults = (
arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
)
if is_bounded:
arg_specs_args = arg_specs.args[1:]
new_args = []
......@@ -104,17 +106,17 @@ def _convert_kwargs_to_args(
new_kwargs[kwarg_name] = kwargs[kwarg_name]
else:
if kwarg_name not in arg_specs_kwonlydefaults:
raise TypeError("{} missing required keyword-only argument: {}".format(
func_name, kwarg_name
))
raise TypeError(
"{} missing required keyword-only argument: {}".format(
func_name, kwarg_name
)
)
new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name]
for k, v in kwargs.items():
if k not in arg_specs_args and k not in arg_specs_kwonlyargs:
if arg_specs.varkw is None:
raise TypeError(
"{} got an unexpected keyword argument {}".format(
func_name, k
)
"{} got an unexpected keyword argument {}".format(func_name, k)
)
new_kwargs[k] = v
return tuple(new_args), new_kwargs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册