提交 b5e46ae9 编写于 作者: M Megvii Engine Team

feat(mge): restore Function

GitOrigin-RevId: dd455238bae9e937b70eaaa164ad215ef3126d5d
上级 dc250745
...@@ -6,11 +6,11 @@ ...@@ -6,11 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import atexit
import ctypes
import os import os
import sys
import platform import platform
import ctypes import sys
import atexit
if sys.platform == "win32": if sys.platform == "win32":
lib_path = os.path.join(os.path.dirname(__file__), "core/lib") lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
...@@ -71,14 +71,15 @@ if sys.platform == "win32": ...@@ -71,14 +71,15 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .core._imperative_rt.core2 import sync from .core._imperative_rt.core2 import sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save from .serialization import load, save
from .tensor import Parameter, Tensor, tensor from .tensor import Parameter, Tensor, tensor
from .utils import comp_graph_tools as cgtools
from .utils import persistent_cache
from .version import __version__ from .version import __version__
from .utils import persistent_cache, comp_graph_tools as cgtools
_set_fork_exec_path_for_timed_func( _set_fork_exec_path_for_timed_func(
sys.executable, sys.executable,
......
...@@ -16,7 +16,7 @@ import numpy as np ...@@ -16,7 +16,7 @@ import numpy as np
import megengine as mge import megengine as mge
from .._imperative_rt import core2 from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply from ..tensor.core import TensorBase, TensorWrapperBase, apply
...@@ -211,3 +211,25 @@ class Grad: ...@@ -211,3 +211,25 @@ class Grad:
def __exit__(self, _1, _2, _3): def __exit__(self, _1, _2, _3):
del self._impl del self._impl
class Function(ops.PyOpBase):
def _default_rule(self, *args):
ret = self.forward(*args)
self.__single_output = isinstance(ret, core2.Tensor)
return ret
def _grad_rule(self, *args):
return self._default_rule(*args), self.backward
def __call__(self, *args):
ret = core2.apply(self, *args)
if self.__single_output:
(ret,) = ret
return ret
def __getstate__(self):
return self.__dict__
def __setstate__(self, state):
self.__dict__.update(state)
...@@ -11,8 +11,8 @@ from typing import Iterable ...@@ -11,8 +11,8 @@ from typing import Iterable
import numpy as np import numpy as np
from .. import functional as F from .. import functional as F
from ..core.autodiff.grad import Function
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.function import Function
from ..module import Module from ..module import Module
from ..tensor import Parameter, Tensor from ..tensor import Parameter, Tensor
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict from .utils import QuantMode, fake_quant_tensor, get_qparam_dict
......
...@@ -12,7 +12,7 @@ from functools import partial ...@@ -12,7 +12,7 @@ from functools import partial
import numpy as np import numpy as np
from .. import functional as F from .. import functional as F
from ..core.tensor.function import Function from ..core.autodiff.grad import Function
from .fake_quant import _FakeQuantize from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver from .observer import MinMaxObserver
from .qconfig import QConfig from .qconfig import QConfig
......
...@@ -12,11 +12,11 @@ from typing import Dict ...@@ -12,11 +12,11 @@ from typing import Dict
import numpy as np import numpy as np
from .. import functional as F from .. import functional as F
from ..core.autodiff.grad import Function
from ..core.ops import builtin from ..core.ops import builtin
from ..core.tensor import megbrain_graph from ..core.tensor import megbrain_graph
from ..core.tensor.core import apply from ..core.tensor.core import apply
from ..core.tensor.dtype import _metadata_dict from ..core.tensor.dtype import _metadata_dict
from ..core.tensor.function import Function
from ..tensor import Tensor from ..tensor import Tensor
......
...@@ -15,7 +15,7 @@ import megengine.optimizer as optimizer ...@@ -15,7 +15,7 @@ import megengine.optimizer as optimizer
from megengine import Parameter from megengine import Parameter
from megengine import Tensor as tensor from megengine import Tensor as tensor
from megengine import tensor from megengine import tensor
from megengine.core.tensor.function import Function from megengine.core.autodiff.grad import Function
from megengine.module import Module from megengine.module import Module
...@@ -239,7 +239,7 @@ def test_none_in_out_grad(): ...@@ -239,7 +239,7 @@ def test_none_in_out_grad():
def backward(self, grad_a, grad_b): def backward(self, grad_a, grad_b):
assert grad_b is None assert grad_b is None
return (grad_a, 0.0) return (grad_a, None)
class Simple(Module): class Simple(Module):
def __init__(self, a, b): def __init__(self, a, b):
......
...@@ -11,8 +11,7 @@ import pytest ...@@ -11,8 +11,7 @@ import pytest
import megengine as mge import megengine as mge
from megengine import tensor from megengine import tensor
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Function, Grad
from megengine.core.tensor.function import Function
from megengine.core.tensor.utils import make_shape_tuple from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.fake_quant import TQT_Function from megengine.quantization.fake_quant import TQT_Function
from megengine.quantization.internal_fake_quant import * from megengine.quantization.internal_fake_quant import *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册