提交 e5217db9 编写于 作者: M Megvii Engine Team

feat(traced_module): move traced_module out of the experimental folder

GitOrigin-RevId: 36c76b5277c64b15ef4a9aacdf7fa4e1eb936c10
上级 6d1a4f20
...@@ -130,4 +130,3 @@ import megengine.optimizer ...@@ -130,4 +130,3 @@ import megengine.optimizer
import megengine.quantization import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.experimental
...@@ -6,5 +6,4 @@ ...@@ -6,5 +6,4 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import traced_module
from .weight_scaler import get_scaled_model from .weight_scaler import get_scaled_model
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core._imperative_rt.core2 import set_cpp_apply_module_trace from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from .traced_module import ( from .traced_module import (
TracedModule, TracedModule,
_register_all_builtin_module, _register_all_builtin_module,
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -14,13 +13,13 @@ import inspect ...@@ -14,13 +13,13 @@ import inspect
import re import re
from typing import Callable, Dict, List from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef from ..core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.builtin import FakeQuant from ..core.ops.builtin import FakeQuant
from ...core.ops.special import Const from ..core.ops.special import Const
from ...module import Module from ..module import Module
from ...tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy from copy import deepcopy
from typing import Union from typing import Union
from ...core.tensor.dtype import QuantDtypeMeta from ..core.tensor.dtype import QuantDtypeMeta
from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ...quantization.utils import QParams, QuantMode, fake_quant_tensor from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin): class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -8,11 +7,11 @@ ...@@ -8,11 +7,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
from ... import Tensor from .. import Tensor
from ... import functional as F from .. import functional as F
from ...core.tensor.array_method import ArrayMethodMixin from ..core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ..module import Module
from ...module.qat import QATModule from ..module.qat import QATModule
_active_module_tracer = None _active_module_tracer = None
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -12,9 +11,9 @@ from typing import Any, Dict, List, Tuple, Type ...@@ -12,9 +11,9 @@ from typing import Any, Dict, List, Tuple, Type
import numpy import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module from ..module import Module
from ...tensor import Tensor from ..tensor import Tensor
class Node: class Node:
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -13,13 +12,13 @@ from typing import Callable, NamedTuple ...@@ -13,13 +12,13 @@ from typing import Callable, NamedTuple
import numpy as np import numpy as np
from ...core._imperative_rt.common import CompNode from ..core._imperative_rt.common import CompNode
from ...core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ...core._wrap import Device from ..core._wrap import Device
from ...core.tensor.dtype import QuantDtypeMeta from ..core.tensor.dtype import QuantDtypeMeta
from ...module import Module from ..module import Module
from ...quantization.utils import LSQParams, QParams, QuantMode from ..quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode from .node import ModuleNode, Node, NodeMixin, TensorNode
......
...@@ -7,9 +7,9 @@ ...@@ -7,9 +7,9 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict from typing import Dict
from ...core._imperative_rt import OpDef from ..core._imperative_rt import OpDef
from ...core.ops import builtin from ..core.ops import builtin
from ...version import __version__ from ..version import __version__
OPDEF_PARAM_LOADER = {} OPDEF_PARAM_LOADER = {}
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -22,21 +21,21 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni ...@@ -22,21 +21,21 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
from megengine import tensor from megengine import tensor
from ... import functional as F from .. import functional as F
from ... import get_logger from .. import get_logger
from ... import module as M from .. import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
is_tracing_module, is_tracing_module,
set_module_tracing, set_module_tracing,
unset_module_tracing, unset_module_tracing,
) )
from ...core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin from ..core.tensor.array_method import ArrayMethodMixin
from ...module import Module from ..module import Module
from ...module.qat import QATModule from ..module.qat import QATModule
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ...quantization.observer import ( from ..quantization.observer import (
ExponentialMovingAverageObserver, ExponentialMovingAverageObserver,
HistogramObserver, HistogramObserver,
MinMaxObserver, MinMaxObserver,
...@@ -45,7 +44,7 @@ from ...quantization.observer import ( ...@@ -45,7 +44,7 @@ from ...quantization.observer import (
SyncExponentialMovingAverageObserver, SyncExponentialMovingAverageObserver,
SyncMinMaxObserver, SyncMinMaxObserver,
) )
from ...tensor import Tensor from ..tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import ( from .module_tracer import (
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
# #
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
...@@ -10,7 +9,7 @@ import copy ...@@ -10,7 +9,7 @@ import copy
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence from typing import Dict, Iterable, List, Optional, Sequence
from ...module import Module from ..module import Module
def replace_container_with_module_container(container): def replace_container_with_module_container(container):
......
...@@ -15,9 +15,9 @@ import megengine as mge ...@@ -15,9 +15,9 @@ import megengine as mge
import megengine.autodiff as ad import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
from megengine import Tensor from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Linear, Module from megengine.module import Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.traced_module import trace_module
batch_size = 64 batch_size = 64
data_shape = (batch_size, 2) data_shape = (batch_size, 2)
......
...@@ -16,10 +16,10 @@ import megengine.autodiff as ad ...@@ -16,10 +16,10 @@ import megengine.autodiff as ad
import megengine.functional as F import megengine.functional as F
import megengine.optimizer as optim import megengine.optimizer as optim
from megengine import Tensor from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace from megengine.jit import trace
from megengine.module import Linear, Module from megengine.module import Linear, Module
from megengine.optimizer import SGD from megengine.optimizer import SGD
from megengine.traced_module import trace_module
batch_size = 64 batch_size = 64
data_shape = (batch_size, 2) data_shape = (batch_size, 2)
......
...@@ -19,8 +19,8 @@ import megengine.module as M ...@@ -19,8 +19,8 @@ import megengine.module as M
import megengine.optimizer as optim import megengine.optimizer as optim
from megengine import tensor from megengine import tensor
from megengine.autodiff import GradManager from megengine.autodiff import GradManager
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace from megengine.jit import trace
from megengine.traced_module import trace_module
@contextlib.contextmanager @contextlib.contextmanager
......
...@@ -15,10 +15,7 @@ import numpy as np ...@@ -15,10 +15,7 @@ import numpy as np
import megengine as mge import megengine as mge
from megengine import Parameter, Tensor from megengine import Parameter, Tensor
from megengine.core.ops import builtin from megengine.core.ops import builtin
from megengine.experimental.traced_module.serialization import ( from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state
get_opdef_state,
load_opdef_from_state,
)
def test_tensor_serialization(): def test_tensor_serialization():
......
...@@ -15,7 +15,6 @@ import pytest ...@@ -15,7 +15,6 @@ import pytest
import megengine as mge import megengine as mge
import megengine.functional as F import megengine.functional as F
from megengine import Parameter, Tensor, tensor from megengine import Parameter, Tensor, tensor
from megengine.experimental.traced_module import TracedModule, trace_module
from megengine.module import ( from megengine.module import (
BatchNorm1d, BatchNorm1d,
BatchNorm2d, BatchNorm2d,
...@@ -30,6 +29,7 @@ from megengine.module import ( ...@@ -30,6 +29,7 @@ from megengine.module import (
) )
from megengine.module.module import _access_structure from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat from megengine.quantization.quantize import quantize, quantize_qat
from megengine.traced_module import TracedModule, trace_module
from megengine.utils.module_utils import get_expand_structure, set_expand_structure from megengine.utils.module_utils import get_expand_structure, set_expand_structure
......
...@@ -7,8 +7,8 @@ import megengine.functional as F ...@@ -7,8 +7,8 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True) set_symbolic_shape(True)
......
...@@ -12,9 +12,9 @@ import numpy as np ...@@ -12,9 +12,9 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace from megengine.jit import trace
from megengine.module import Module from megengine.module import Module
from megengine.traced_module import trace_module
class MyBlock(Module): class MyBlock(Module):
......
...@@ -9,8 +9,8 @@ import numpy as np ...@@ -9,8 +9,8 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine.experimental.traced_module import trace_module from megengine.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr from megengine.traced_module.expr import CallFunction, GetAttr
class MyBlock(M.Module): class MyBlock(M.Module):
......
...@@ -12,8 +12,8 @@ import numpy as np ...@@ -12,8 +12,8 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module from megengine.module import Module
from megengine.traced_module import trace_module
class MyBlock(Module): class MyBlock(Module):
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.experimental.traced_module import TracedModule, trace_module from megengine.traced_module import TracedModule, trace_module
class MyModule1(M.Module): class MyModule1(M.Module):
......
...@@ -8,8 +8,8 @@ import megengine.functional as F ...@@ -8,8 +8,8 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True) set_symbolic_shape(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册