提交 f2691566 编写于 作者: M Megvii Engine Team

feat(traced_module): add pytree

GitOrigin-RevId: 6c6e53521c71474c67590e0a94723a1d6be89218
上级 bee305be
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import builtins
import collections import collections
from typing import Callable, List from typing import Callable, List
...@@ -19,6 +19,7 @@ from ...module import Module ...@@ -19,6 +19,7 @@ from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .module_tracer import active_module_tracer from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import TreeDef
class Expr: class Expr:
...@@ -28,9 +29,22 @@ class Expr: ...@@ -28,9 +29,22 @@ class Expr:
inputs = None # type: List[Node] inputs = None # type: List[Node]
outputs = None # type: List[Node] outputs = None # type: List[Node]
const_val = None # type: List[Any]
def add_input(self, node): arg_def = None # type: TreeDef
self.inputs.append(node)
def add_inputs(self, vals):
if not isinstance(vals, collections.abc.Sequence):
vals = (vals,)
for val in vals:
node = NodeMixin.get(val, None)
if isinstance(node, (TensorNode, ModuleNode)):
if node not in self.inputs:
self.inputs.append(node)
else:
assert node is None
assert type(val) in builtins.__dict__.values()
idx = len(self.inputs) + len(self.const_val)
self.const_val.append((idx, val))
def add_outputs(self, outputs): def add_outputs(self, outputs):
self.outputs = [] self.outputs = []
...@@ -38,50 +52,31 @@ class Expr: ...@@ -38,50 +52,31 @@ class Expr:
outputs = (outputs,) outputs = (outputs,)
for i in outputs: for i in outputs:
assert isinstance(i, RawTensor)
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) self.outputs.append(NodeMixin.get_wrapped_type(i)(self))
for i, node in zip(outputs, self.outputs,): for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node) NodeMixin.wrap_safe(i, node)
@classmethod def unflatten_args(self, inputs):
def get_args_node(cls, arg): if self.arg_def is not None:
""" inputs = list(inputs)
Create nodes by ``arg``, which may be a container. for idx, val in self.const_val:
Return the same structure with arg. inputs.insert(idx, val)
args, kwargs = self.arg_def.unflatten(inputs)
If ``arg`` was not Tensor or Module, it will be stored as const. return args, kwargs
:param arg: tensor, module or const.
"""
if isinstance(arg, (RawTensor, Module)):
if not NodeMixin.get(arg, None):
NodeMixin.wrap_safe(arg, Constant.make(arg))
return NodeMixin.get(arg)
elif isinstance(arg, collections.abc.Sequence):
seq_cls = type(arg)
return seq_cls([Expr.get_args_node(a) for a in arg])
else: else:
# TODO: assert arg type return inputs, {}
return arg # as const
@classmethod @property
def get_arg_value(cls, inp_node, node2value): def kwargs(self):
""" _, kwargs = self.unflatten_args(self.inputs)
Get values from node2value by inp_node, which may be a container. return kwargs
Return the same structure with inp_node.
@property
If ``inp_node`` was not in node2value, it is a const. def args(self):
args, _ = self.unflatten_args(self.inputs)
:param inp_node: nodes. return args
:param node2value: dict from node to tensor and module.
"""
if inp_node in node2value:
return node2value[inp_node]
elif isinstance(inp_node, collections.abc.Sequence):
seq_cls = type(inp_node)
return seq_cls([Expr.get_arg_value(i, node2value) for i in inp_node])
else:
return inp_node
# expr: None (i.e. fake expression which is used to mark input) # expr: None (i.e. fake expression which is used to mark input)
...@@ -144,16 +139,8 @@ class CallMethod(Expr): ...@@ -144,16 +139,8 @@ class CallMethod(Expr):
self.inputs = [ self.inputs = [
module, module,
] ]
self.const_val = []
self.method = method self.method = method
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node, arg_name=None):
if arg_name == "self": # FIXME: <XP>
return
self.inputs.append(node)
if arg_name is not None:
self.arg_names.append(arg_name)
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
...@@ -162,19 +149,22 @@ class CallMethod(Expr): ...@@ -162,19 +149,22 @@ class CallMethod(Expr):
return expr return expr
def interpret(self, *inputs): def interpret(self, *inputs):
mod = inputs[0] args, kwargs = self.unflatten_args(inputs)
args = inputs[1:] obj = args[0]
outputs = getattr(mod, self.method)(*args, **self.kwargs) args = args[1:]
outputs = getattr(obj, self.method)(*args, **kwargs)
if isinstance(outputs, RawTensor): if isinstance(outputs, RawTensor):
outputs = (outputs,) outputs = (outputs,)
return outputs return outputs
def __repr__(self): def __repr__(self):
return "{} = CallMethod({}, {})({})".format( args = ", ".join(str(i) for i in self.args[1:])
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}.{}({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.inputs[0], self.inputs[0],
self.method, self.method,
", ".join(str(i) for i in self.inputs[1:]), ", ".join([args, kwargs]),
) )
...@@ -227,13 +217,8 @@ class CallFunction(Expr): ...@@ -227,13 +217,8 @@ class CallFunction(Expr):
def __init__(self, func): def __init__(self, func):
assert isinstance(func, Callable) assert isinstance(func, Callable)
self.func = func self.func = func
self.const_val = []
self.inputs = [] self.inputs = []
self.arg_names = []
self.kwargs = {} # const kwargs
def add_input(self, node, arg_name):
self.inputs.append(node)
self.arg_names.append(arg_name)
@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
...@@ -242,18 +227,20 @@ class CallFunction(Expr): ...@@ -242,18 +227,20 @@ class CallFunction(Expr):
return expr return expr
def interpret(self, *inputs): def interpret(self, *inputs):
inp_dict = dict([(name, node) for node, name in zip(inputs, self.arg_names)]) args, kwargs = self.unflatten_args(inputs)
outputs = self.func(**inp_dict, **self.kwargs) outputs = self.func(*args, **kwargs)
outputs = ( outputs = (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
) )
return outputs return outputs
def __repr__(self): def __repr__(self):
args = ", ".join(str(i) for i in self.args)
kwargs = ", ".join("{}={}".format(k, v) for k, v in self.kwargs.items())
return "{} = {}({})".format( return "{} = {}({})".format(
", ".join(str(i) for i in self.outputs), ", ".join(str(i) for i in self.outputs),
self.func.__module__ + "." + self.func.__name__, self.func.__module__ + "." + self.func.__name__,
", ".join(str(i) for i in self.inputs), ", ".join([args, kwargs]),
) )
......
...@@ -15,6 +15,72 @@ from ...module import Module ...@@ -15,6 +15,72 @@ from ...module import Module
_active_module_tracer = None _active_module_tracer = None
BUILTIN_ARRAY_METHOD = [
"__lt__",
"__le__",
"__gt__",
"__ge__",
"__eq__",
"__ne__",
"__neg__",
"__pos__",
"__abs__",
"__invert__",
"__round__",
"__floor__",
"__ceil__",
"__add__",
"__sub__",
"__mul__",
"__matmul__",
"__truediv__",
"__floordiv__",
"__mod__",
"__pow__",
"__lshift__",
"__rshift__",
"__and__",
"__or__",
"__xor__",
"__radd__",
"__rsub__",
"__rmul__",
"__rmatmul__",
"__rtruediv__",
"__rfloordiv__",
"__rmod__",
"__rpow__",
"__rlshift__",
"__rrshift__",
"__rand__",
"__ror__",
"__rxor__",
"__iadd__",
"__isub__",
"__imul__",
"__imatmul__",
"__itruediv__",
"__ifloordiv__",
"__imod__",
"__ipow__",
"__ilshift__",
"__irshift__",
"__iand__",
"__ior__",
"__ixor__",
"T",
"astype",
"reshape",
"_broadcast",
"transpose",
"flatten",
"sum",
"prod",
"min",
"max",
"mean",
]
def active_module_tracer(): def active_module_tracer():
return _active_module_tracer return _active_module_tracer
...@@ -108,9 +174,8 @@ class Patcher: ...@@ -108,9 +174,8 @@ class Patcher:
self.wrap_fn = wrap_fn self.wrap_fn = wrap_fn
for module in self._builtin_modules: for module in self._builtin_modules:
self.patch_module(module) self.patch_module(module)
for meth in BUILTIN_ARRAY_METHOD:
for cls in self._builtin_methods: self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
self.patch_cls(cls)
for i, j in self._builtin_functions: for i, j in self._builtin_functions:
if id(i) not in self.visited_frames_ids: if id(i) not in self.visited_frames_ids:
......
...@@ -13,6 +13,7 @@ import numpy ...@@ -13,6 +13,7 @@ import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module from ...module import Module
from ...tensor import Tensor from ...tensor import Tensor
from .pytree import TreeDef
class Node: class Node:
...@@ -58,6 +59,7 @@ class ModuleNode(Node): ...@@ -58,6 +59,7 @@ class ModuleNode(Node):
module_type = Module # type: Type[Module] module_type = Module # type: Type[Module]
graph = None graph = None
attr_type_map = None # type: Dict[str, Type[Any]] attr_type_map = None # type: Dict[str, Type[Any]]
arg_def = None # type: TreeDef
def __repr__(self): def __repr__(self):
if self._name is None: if self._name is None:
......
from typing import Callable, NamedTuple
SUPPORTED_TYPE = {}
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
def register_supported_type(type, flatten, unflatten):
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(
dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
)
register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),
lambda x, aux_data: slice(x[0], x[1], x[2]),
)
def tree_flatten(
values, leaf_type: Callable = lambda x: type(x), is_leaf: Callable = lambda x: True
):
if type(values) not in SUPPORTED_TYPE:
assert is_leaf(values)
return [values,], LeafDef(leaf_type(values))
rst = []
children_defs = []
children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
for v in children_values:
v_list, treedef = tree_flatten(v, leaf_type)
rst.extend(v_list)
children_defs.append(treedef)
return rst, TreeDef(type(values), aux_data, children_defs)
class TreeDef:
def __init__(self, type, aux_data, children_defs):
self.type = type
self.aux_data = aux_data
self.children_defs = children_defs
self.num_leaves = sum(ch.num_leaves for ch in children_defs)
def unflatten(self, leaves):
assert len(leaves) == self.num_leaves
start = 0
children = []
for ch in self.children_defs:
children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
start += ch.num_leaves
return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
def __eq__(self, other):
return (
self.type == other.type
and self.aux_data == other.aux_data
and self.num_leaves == other.num_leaves
and self.children_defs == other.children_defs
)
def __repr__(self):
return "{}[{}]".format(self.type.__name__, self.children_defs)
class LeafDef(TreeDef):
def __init__(self, type):
super().__init__(type, None, [])
self.num_leaves = 1
def unflatten(self, leaves):
assert len(leaves) == 1
assert isinstance(leaves[0], self.type), self.type
return leaves[0]
def __repr__(self):
return "Leaf({})".format(self.type.__name__)
...@@ -9,9 +9,11 @@ ...@@ -9,9 +9,11 @@
import collections import collections
import copy import copy
import functools import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type from typing import List, Type
from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import ( from ...core._imperative_rt.core2 import (
is_tracing_module, is_tracing_module,
set_module_tracing, set_module_tracing,
...@@ -28,6 +30,16 @@ from .module_tracer import ( ...@@ -28,6 +30,16 @@ from .module_tracer import (
set_active_module_tracer, set_active_module_tracer,
) )
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import tree_flatten
def _leaf_type(node):
if isinstance(node, RawTensor):
return (Tensor, TensorNode)
elif isinstance(node, (NodeMixin, Module)):
return (Module, ModuleNode, NodeMixin)
else:
return type(node)
class InternalGraph: class InternalGraph:
...@@ -65,9 +77,7 @@ class InternalGraph: ...@@ -65,9 +77,7 @@ class InternalGraph:
for n, v in zip(self._inputs, inputs): for n, v in zip(self._inputs, inputs):
node2value[n] = v node2value[n] = v
for expr in self._exprs: for expr in self._exprs:
values = expr.interpret( values = expr.interpret(*list(node2value[i] for i in expr.inputs))
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs)
)
for n, v in zip(expr.outputs, values): for n, v in zip(expr.outputs, values):
node2value[n] = v node2value[n] = v
return list(node2value[i] for i in self._outputs) return list(node2value[i] for i in self._outputs)
...@@ -80,37 +90,39 @@ class InternalGraph: ...@@ -80,37 +90,39 @@ class InternalGraph:
) )
def _get_meth_name(obj, func):
for cls in type(obj).mro():
for k, v in cls.__dict__.items():
if v == func:
return k
return None
def _wrapped_function(orig_func): def _wrapped_function(orig_func):
@functools.wraps(orig_func) @functools.wraps(orig_func)
def wrapped_fn(*inputs, **kwargs): def wrapped_fn(*args, **kwargs):
if is_tracing_module(): if is_tracing_module():
unset_module_tracing() unset_module_tracing()
const_kwargs = {} inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type)
arg_names = orig_func.__code__.co_varnames for i in inputs:
if orig_func.__qualname__.split(".").__len__() > 1: if not NodeMixin.get(i, None):
# FIXME: a robust way to distinguish method and function. <XP> if isinstance(i, (RawTensor, NodeMixin)):
NodeMixin.wrap_safe(i, Constant.make(i))
meth_name = _get_meth_name(args[0], wrapped_fn)
if meth_name:
self = inputs[0] self = inputs[0]
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) call_node = CallMethod.make(NodeMixin.get(self), meth_name)
else: else:
call_node = CallFunction.make(orig_func) call_node = CallFunction.make(orig_func)
def add_input(inp, varname=None): call_node.add_inputs(inputs)
node = Expr.get_args_node(inp)
if node is not None: call_node.arg_def = tree_def
call_node.add_input(node, varname) outputs = orig_func(*args, **kwargs)
else:
const_kwargs[varname] = inp
for ind, inp in enumerate(inputs):
add_input(inp, arg_names[ind])
for k, v in kwargs.items():
add_input(v, k)
call_node.kwargs = const_kwargs
outputs = orig_func(*inputs, **kwargs)
call_node.add_outputs(outputs) call_node.add_outputs(outputs)
set_module_tracing() set_module_tracing()
return outputs return outputs
return orig_func(*inputs, **kwargs) return orig_func(*args, **kwargs)
return wrapped_fn return wrapped_fn
...@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin): ...@@ -120,14 +132,14 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
_body = None # type: InternalGraph _body = None # type: InternalGraph
_is_builtin = None # type: bool _is_builtin = None # type: bool
_arg_def = None # type: TreeDef
__builder_attributes__ = [ __builder_attributes__ = [
"_mod", "_mod",
"_body", "_body",
"_NodeMixin__node", "_NodeMixin__node",
"_is_builtin", "_is_builtin",
"_is_traced", "_is_traced",
"build", "_arg_def" "build",
] ]
def __init__(self, mod): def __init__(self, mod):
...@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin): ...@@ -146,6 +158,7 @@ class TracedModuleBuilder(NodeMixin):
node = NodeMixin.get(self) node = NodeMixin.get(self)
node.graph = self._body node.graph = self._body
node.attr_type_map = {} node.attr_type_map = {}
node.arg_def = self._arg_def
traced_module = TracedModule(node) traced_module = TracedModule(node)
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__: if k not in TracedModuleBuilder.__builder_attributes__:
...@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin): ...@@ -155,32 +168,34 @@ class TracedModuleBuilder(NodeMixin):
traced_module.m_node.attr_type_map[k] = type(v) traced_module.m_node.attr_type_map[k] = type(v)
return traced_module return traced_module
def __call__(self, *inputs, **kwargs): def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module) assert isinstance(self._mod, Module)
for arg in args:
assert isinstance(arg, RawTensor)
for k, v in kwargs.items():
assert isinstance(v, RawTensor)
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def mark_constant(x): def mark_constant(x):
node = NodeMixin.get(x, None) node = NodeMixin.get(x, None)
if node is None: # capture as constant if node is None: # capture as constant
NodeMixin.wrap(x, lambda: Constant.make(x)) NodeMixin.wrap(x, lambda: Constant.make(x))
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
if self._arg_def is None:
self._arg_def = tree_def
assert self._arg_def == tree_def
for i in inputs: for i in inputs:
mark_constant(i) mark_constant(i)
for k, v in kwargs.items():
mark_constant(v)
callnode = CallMethod.make(NodeMixin.get(self)) callnode = CallMethod.make(NodeMixin.get(self))
def add_input(x): callnode.add_inputs(inputs)
callnode.add_input(NodeMixin.get(x))
for i in inputs: callnode.arg_def = tree_def
add_input(i)
for k, v in kwargs.items():
add_input(v)
if self._is_builtin or self._is_traced: if self._is_builtin or self._is_traced:
unset_module_tracing() unset_module_tracing()
outputs = self._mod(*inputs, **kwargs) outputs = self._mod(*args, **kwargs)
set_module_tracing() set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
...@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin): ...@@ -193,23 +208,21 @@ class TracedModuleBuilder(NodeMixin):
) )
# prepare args and kwargs for inner graph # prepare args and kwargs for inner graph
def wrap(x): def wrap(x):
# wrapped = copy.copy(x) # FIXME wrapped = copy.copy(x) # FIXME
wrapped = x # FIXME: <XP>
NodeMixin.wrap( NodeMixin.wrap(
wrapped, wrapped,
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
) )
return wrapped return wrapped
args = [] args = [self]
for i in inputs: for i in inputs[1:]:
args.append(wrap(i)) args.append(wrap(i))
for k, v in kwargs.items(): args, kwargs = tree_def.unflatten(args)
kwargs[k] = wrap(v)
active_module_tracer().patcher.auto_patch( active_module_tracer().patcher.auto_patch(
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
) )
outputs = type(self._mod).forward(self, *args, **kwargs) outputs = type(self._mod).forward(*args, **kwargs)
for i in ( for i in (
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
...@@ -269,8 +282,10 @@ class TracedModule(Module): ...@@ -269,8 +282,10 @@ class TracedModule(Module):
super(TracedModule, self).__init__() super(TracedModule, self).__init__()
self.m_node = node self.m_node = node
def forward(self, *inputs): def forward(self, *args, **kwargs):
rst = self.m_node.graph.interpret(self, *inputs) inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type)
assert treedef == self.m_node.arg_def
rst = self.m_node.graph.interpret(*inputs)
if len(rst) == 1: if len(rst) == 1:
rst = rst[0] rst = rst[0]
return rst return rst
...@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: ...@@ -345,7 +360,6 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
def _register_all_builtin_module(): def _register_all_builtin_module():
from inspect import getmembers, isclass
for sub_mod in [M, M.qat, M.quantized]: for sub_mod in [M, M.qat, M.quantized]:
for m in getmembers(sub_mod): for m in getmembers(sub_mod):
...@@ -357,7 +371,7 @@ def _register_all_builtin_module(): ...@@ -357,7 +371,7 @@ def _register_all_builtin_module():
module_tracer.register_as_builtin(m[1]) module_tracer.register_as_builtin(m[1])
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
""" """
Traces module ``mod`` and returns corresponding TracedModule. Traces module ``mod`` and returns corresponding TracedModule.
...@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule ...@@ -375,15 +389,13 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule
builder = TracedModuleBuilder(mod) builder = TracedModuleBuilder(mod)
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
inputs, _ = tree_flatten((args, kwargs))
for _, i in enumerate(inputs): for _, i in enumerate(inputs):
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) NodeMixin.wrap_safe(
for k, v in kwargs.items(): i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) )
builder(*args, **kwargs)
builder(*inputs, **kwargs)
active_module_tracer().pop_scope() active_module_tracer().pop_scope()
return builder.build() return builder.build()
finally: finally:
set_active_module_tracer(None) set_active_module_tracer(None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册