Commit b5e46ae9 authored by Megvii Engine Team

feat(mge): restore Function

GitOrigin-RevId: dd455238bae9e937b70eaaa164ad215ef3126d5d
Parent: dc250745
@@ -6,11 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import atexit
+import ctypes
 import os
-import sys
 import platform
-import ctypes
-import atexit
+import sys

 if sys.platform == "win32":
     lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
@@ -71,14 +71,15 @@ if sys.platform == "win32":
     kernel32.SetErrorMode(old_error_mode)

-from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .core._imperative_rt.core2 import sync
+from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .device import *
 from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
 from .serialization import load, save
 from .tensor import Parameter, Tensor, tensor
+from .utils import comp_graph_tools as cgtools
+from .utils import persistent_cache
 from .version import __version__
-from .utils import persistent_cache, comp_graph_tools as cgtools

 _set_fork_exec_path_for_timed_func(
     sys.executable,
......
@@ -16,7 +16,7 @@ import numpy as np
 import megengine as mge

-from .._imperative_rt import core2
+from .._imperative_rt import core2, ops
 from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -211,3 +211,25 @@ class Grad:
     def __exit__(self, _1, _2, _3):
         del self._impl
+
+
+class Function(ops.PyOpBase):
+    def _default_rule(self, *args):
+        ret = self.forward(*args)
+        self.__single_output = isinstance(ret, core2.Tensor)
+        return ret
+
+    def _grad_rule(self, *args):
+        return self._default_rule(*args), self.backward
+
+    def __call__(self, *args):
+        ret = core2.apply(self, *args)
+        if self.__single_output:
+            (ret,) = ret
+        return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
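The hunk above defines the restored Function as a subclass of ops.PyOpBase inside megengine.core.autodiff.grad; the hunks below switch the quantization modules and the tests from the old megengine.core.tensor.function import path to this new location. For reference, here is a minimal usage sketch of the class; the Mul example, the random inputs, and the GradManager wiring are illustrative and not part of this commit.

# Illustrative sketch only: subclass Function, implement forward/backward,
# then call the instance like an op. Mul and the training wiring are made up here.
import numpy as np

import megengine.autodiff as ad
from megengine import tensor
from megengine.core.autodiff.grad import Function


class Mul(Function):
    def forward(self, a, b):
        # keep the inputs around for the backward pass
        self.a, self.b = a, b
        return a * b

    def backward(self, grad):
        # d(a*b)/da = b, d(a*b)/db = a
        return grad * self.b, grad * self.a


a = tensor(np.random.rand(3).astype("float32"))
b = tensor(np.random.rand(3).astype("float32"))
gm = ad.GradManager()
gm.attach([a, b])
with gm:
    out = Mul()(a, b)
    gm.backward(out)  # gradients are accumulated into a.grad and b.grad

Because forward returns a single tensor here, __call__ unwraps the one-element result; that is what the __single_output flag in the diff above is for.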
@@ -11,8 +11,8 @@ from typing import Iterable
 import numpy as np

 from .. import functional as F
+from ..core.autodiff.grad import Function
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
-from ..core.tensor.function import Function
 from ..module import Module
 from ..tensor import Parameter, Tensor
 from .utils import QuantMode, fake_quant_tensor, get_qparam_dict
......
@@ -12,7 +12,7 @@ from functools import partial
 import numpy as np

 from .. import functional as F
-from ..core.tensor.function import Function
+from ..core.autodiff.grad import Function
 from .fake_quant import _FakeQuantize
 from .observer import MinMaxObserver
 from .qconfig import QConfig
......
@@ -12,11 +12,11 @@ from typing import Dict
 import numpy as np

 from .. import functional as F
+from ..core.autodiff.grad import Function
 from ..core.ops import builtin
 from ..core.tensor import megbrain_graph
 from ..core.tensor.core import apply
 from ..core.tensor.dtype import _metadata_dict
-from ..core.tensor.function import Function
 from ..tensor import Tensor
......
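The three hunks above only change where the quantization code imports Function from. For intuition, fake quantization needs a custom backward because rounding has a zero gradient almost everywhere, and a straight-through estimator is the usual workaround. The sketch below is a simplified stand-in written against the new import path, not the actual fake_quant/TQT implementation.

# Simplified stand-in for illustration; MegEngine's real fake-quant logic is more involved.
import megengine.functional as F
from megengine.core.autodiff.grad import Function


class RoundSTE(Function):
    def forward(self, inp):
        # quantize (round) in the forward pass
        return F.round(inp)

    def backward(self, grad):
        # straight-through estimator: treat rounding as identity for the gradient
        return grad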
@@ -15,7 +15,7 @@ import megengine.optimizer as optimizer
 from megengine import Parameter
 from megengine import Tensor as tensor
 from megengine import tensor
-from megengine.core.tensor.function import Function
+from megengine.core.autodiff.grad import Function
 from megengine.module import Module
@@ -239,7 +239,7 @@ def test_none_in_out_grad():
         def backward(self, grad_a, grad_b):
             assert grad_b is None
-            return (grad_a, 0.0)
+            return (grad_a, None)

     class Simple(Module):
         def __init__(self, a, b):
......
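The test change above (returning None instead of 0.0) exercises the rule that backward may return None for an input that should receive no gradient. A minimal sketch of that contract, with an illustrative class name:

# Illustrative only: the second input deliberately receives no gradient.
from megengine.core.autodiff.grad import Function


class FirstOnly(Function):
    def forward(self, a, b):
        return a, b

    def backward(self, grad_a, grad_b):
        # returning None means no gradient flows back to b
        return grad_a, None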
@@ -11,8 +11,8 @@ import pytest
 import megengine as mge
 from megengine import tensor
-from megengine.core.autodiff.grad import Grad
-from megengine.core.tensor.function import Function
+from megengine.core.autodiff.grad import Function, Grad
 from megengine.core.tensor.utils import make_shape_tuple
 from megengine.quantization.fake_quant import TQT_Function
 from megengine.quantization.internal_fake_quant import *
......