提交 e5217db9 编写于 作者: M Megvii Engine Team

feat(traced_module): move traced_module out of the experimental folder

GitOrigin-RevId: 36c76b5277c64b15ef4a9aacdf7fa4e1eb936c10
上级 6d1a4f20
......@@ -130,4 +130,3 @@ import megengine.optimizer
import megengine.quantization
import megengine.random
import megengine.utils
import megengine.experimental
......@@ -6,5 +6,4 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import traced_module
from .weight_scaler import get_scaled_model
......@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core._imperative_rt.core2 import set_cpp_apply_module_trace
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from .traced_module import (
TracedModule,
_register_all_builtin_module,
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -14,13 +13,13 @@ import inspect
import re
from typing import Callable, Dict, List
from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.builtin import FakeQuant
from ...core.ops.special import Const
from ...module import Module
from ...tensor import Parameter, Tensor
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ..core.ops.builtin import FakeQuant
from ..core.ops.special import Const
from ..module import Module
from ..tensor import Parameter, Tensor
from .module_tracer import active_module_tracer, module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode
from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
......
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import Union
from ...core.tensor.dtype import QuantDtypeMeta
from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ...quantization.utils import QParams, QuantMode, fake_quant_tensor
from ..core.tensor.dtype import QuantDtypeMeta
from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -8,11 +7,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from ... import Tensor
from ... import functional as F
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...module.qat import QATModule
from .. import Tensor
from .. import functional as F
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
_active_module_tracer = None
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -12,9 +11,9 @@ from typing import Any, Dict, List, Tuple, Type
import numpy
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..module import Module
from ..tensor import Tensor
class Node:
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -13,13 +12,13 @@ from typing import Callable, NamedTuple
import numpy as np
from ...core._imperative_rt.common import CompNode
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._wrap import Device
from ...core.tensor.dtype import QuantDtypeMeta
from ...module import Module
from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor
from ..core._imperative_rt.common import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._wrap import Device
from ..core.tensor.dtype import QuantDtypeMeta
from ..module import Module
from ..quantization.utils import LSQParams, QParams, QuantMode
from ..tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode
......
......@@ -7,9 +7,9 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Dict
from ...core._imperative_rt import OpDef
from ...core.ops import builtin
from ...version import __version__
from ..core._imperative_rt import OpDef
from ..core.ops import builtin
from ..version import __version__
OPDEF_PARAM_LOADER = {}
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -22,21 +21,21 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
from megengine import tensor
from ... import functional as F
from ... import get_logger
from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import (
from .. import functional as F
from .. import get_logger
from .. import module as M
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
is_tracing_module,
set_module_tracing,
unset_module_tracing,
)
from ...core._trace_option import set_symbolic_shape
from ...core.tensor.array_method import ArrayMethodMixin
from ...module import Module
from ...module.qat import QATModule
from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ...quantization.observer import (
from ..core._trace_option import set_symbolic_shape
from ..core.tensor.array_method import ArrayMethodMixin
from ..module import Module
from ..module.qat import QATModule
from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ..quantization.observer import (
ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver,
......@@ -45,7 +44,7 @@ from ...quantization.observer import (
SyncExponentialMovingAverageObserver,
SyncMinMaxObserver,
)
from ...tensor import Tensor
from ..tensor import Tensor
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
from .fake_quant import FakeQuantize as TM_FakeQuant
from .module_tracer import (
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -10,7 +9,7 @@ import copy
from collections.abc import MutableMapping, MutableSequence
from typing import Dict, Iterable, List, Optional, Sequence
from ...module import Module
from ..module import Module
def replace_container_with_module_container(container):
......
......@@ -15,9 +15,9 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Linear, Module
from megengine.optimizer import SGD
from megengine.traced_module import trace_module
batch_size = 64
data_shape = (batch_size, 2)
......
......@@ -16,10 +16,10 @@ import megengine.autodiff as ad
import megengine.functional as F
import megengine.optimizer as optim
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.optimizer import SGD
from megengine.traced_module import trace_module
batch_size = 64
data_shape = (batch_size, 2)
......
......@@ -19,8 +19,8 @@ import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.traced_module import trace_module
@contextlib.contextmanager
......
......@@ -15,10 +15,7 @@ import numpy as np
import megengine as mge
from megengine import Parameter, Tensor
from megengine.core.ops import builtin
from megengine.experimental.traced_module.serialization import (
get_opdef_state,
load_opdef_from_state,
)
from megengine.traced_module.serialization import get_opdef_state, load_opdef_from_state
def test_tensor_serialization():
......
......@@ -15,7 +15,6 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.experimental.traced_module import TracedModule, trace_module
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
......@@ -30,6 +29,7 @@ from megengine.module import (
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.traced_module import TracedModule, trace_module
from megengine.utils.module_utils import get_expand_structure, set_expand_structure
......
......@@ -7,8 +7,8 @@ import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True)
......
......@@ -12,9 +12,9 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.module import Module
from megengine.traced_module import trace_module
class MyBlock(Module):
......
......@@ -9,8 +9,8 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine.experimental.traced_module import trace_module
from megengine.experimental.traced_module.expr import CallFunction, GetAttr
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, GetAttr
class MyBlock(M.Module):
......
......@@ -12,8 +12,8 @@ import numpy as np
import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module
from megengine.module import Module
from megengine.traced_module import trace_module
class MyBlock(Module):
......
......@@ -2,7 +2,7 @@ import numpy as np
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import TracedModule, trace_module
from megengine.traced_module import TracedModule, trace_module
class MyModule1(M.Module):
......
......@@ -8,8 +8,8 @@ import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._trace_option import set_symbolic_shape
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.traced_module import trace_module
set_symbolic_shape(True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册