From b5e46ae92f91f8cf6de44296e358111543cebacc Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 25 Dec 2020 14:38:47 +0800
Subject: [PATCH] feat(mge): restore Function

GitOrigin-RevId: dd455238bae9e937b70eaaa164ad215ef3126d5d
---
 imperative/python/megengine/__init__.py       | 11 +++++----
 .../python/megengine/core/autodiff/grad.py    | 24 ++++++++++++++++++-
 .../megengine/quantization/fake_quant.py      |  2 +-
 .../quantization/internal_fake_quant.py       |  2 +-
 .../python/megengine/quantization/utils.py    |  2 +-
 .../python/test/unit/core/test_function.py    |  4 ++--
 .../test/unit/quantization/test_fake_quant.py |  3 +--
 7 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py
index 60311783a..204dac0f8 100644
--- a/imperative/python/megengine/__init__.py
+++ b/imperative/python/megengine/__init__.py
@@ -6,11 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import atexit
+import ctypes
 import os
-import sys
 import platform
-import ctypes
-import atexit
+import sys
 
 if sys.platform == "win32":
     lib_path = os.path.join(os.path.dirname(__file__), "core/lib")
@@ -71,14 +71,15 @@ if sys.platform == "win32":
         kernel32.SetErrorMode(old_error_mode)
 
 
-from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .core._imperative_rt.core2 import sync
+from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
 from .device import *
 from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
 from .serialization import load, save
 from .tensor import Parameter, Tensor, tensor
+from .utils import comp_graph_tools as cgtools
+from .utils import persistent_cache
 from .version import __version__
-from .utils import persistent_cache, comp_graph_tools as cgtools
 
 _set_fork_exec_path_for_timed_func(
     sys.executable,
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index 2ccfbb4fe..d783a67d2 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -16,7 +16,7 @@ import numpy as np
 
 import megengine as mge
 
-from .._imperative_rt import core2
+from .._imperative_rt import core2, ops
 from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -211,3 +211,25 @@ class Grad:
 
     def __exit__(self, _1, _2, _3):
         del self._impl
+
+
+class Function(ops.PyOpBase):
+    def _default_rule(self, *args):
+        ret = self.forward(*args)
+        self.__single_output = isinstance(ret, core2.Tensor)
+        return ret
+
+    def _grad_rule(self, *args):
+        return self._default_rule(*args), self.backward
+
+    def __call__(self, *args):
+        ret = core2.apply(self, *args)
+        if self.__single_output:
+            (ret,) = ret
+        return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py
index a5accd1dd..e20813a72 100644
--- a/imperative/python/megengine/quantization/fake_quant.py
+++ b/imperative/python/megengine/quantization/fake_quant.py
@@ -11,8 +11,8 @@ from typing import Iterable
 import numpy as np
 
 from .. import functional as F
+from ..core.autodiff.grad import Function
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
-from ..core.tensor.function import Function
 from ..module import Module
 from ..tensor import Parameter, Tensor
 from .utils import QuantMode, fake_quant_tensor, get_qparam_dict
diff --git a/imperative/python/megengine/quantization/internal_fake_quant.py b/imperative/python/megengine/quantization/internal_fake_quant.py
index 02d1d8976..ab980406e 100644
--- a/imperative/python/megengine/quantization/internal_fake_quant.py
+++ b/imperative/python/megengine/quantization/internal_fake_quant.py
@@ -12,7 +12,7 @@ from functools import partial
 import numpy as np
 
 from .. import functional as F
-from ..core.tensor.function import Function
+from ..core.autodiff.grad import Function
 from .fake_quant import _FakeQuantize
 from .observer import MinMaxObserver
 from .qconfig import QConfig
diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py
index 31c342e11..95d4db1bc 100644
--- a/imperative/python/megengine/quantization/utils.py
+++ b/imperative/python/megengine/quantization/utils.py
@@ -12,11 +12,11 @@ from typing import Dict
 import numpy as np
 
 from .. import functional as F
+from ..core.autodiff.grad import Function
 from ..core.ops import builtin
 from ..core.tensor import megbrain_graph
 from ..core.tensor.core import apply
 from ..core.tensor.dtype import _metadata_dict
-from ..core.tensor.function import Function
 from ..tensor import Tensor
 
 
diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py
index fc01faf91..accc75106 100644
--- a/imperative/python/test/unit/core/test_function.py
+++ b/imperative/python/test/unit/core/test_function.py
@@ -15,7 +15,7 @@ import megengine.optimizer as optimizer
 from megengine import Parameter
 from megengine import Tensor as tensor
 from megengine import tensor
-from megengine.core.tensor.function import Function
+from megengine.core.autodiff.grad import Function
 from megengine.module import Module
 
 
@@ -239,7 +239,7 @@ def test_none_in_out_grad():
 
         def backward(self, grad_a, grad_b):
             assert grad_b is None
-            return (grad_a, 0.0)
+            return (grad_a, None)
 
     class Simple(Module):
         def __init__(self, a, b):
diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py
index 60cda8c46..bf84c93df 100644
--- a/imperative/python/test/unit/quantization/test_fake_quant.py
+++ b/imperative/python/test/unit/quantization/test_fake_quant.py
@@ -11,8 +11,7 @@ import pytest
 
 import megengine as mge
 from megengine import tensor
-from megengine.core.autodiff.grad import Grad
-from megengine.core.tensor.function import Function
+from megengine.core.autodiff.grad import Function, Grad
 from megengine.core.tensor.utils import make_shape_tuple
 from megengine.quantization.fake_quant import TQT_Function
 from megengine.quantization.internal_fake_quant import *
--
GitLab
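
Usage note (not part of the patch): with Function restored under megengine.core.autodiff.grad, custom autodiff ops subclass it and implement forward and backward; __call__ dispatches through core2.apply, and backward may return None for an input that needs no gradient, as the updated test_none_in_out_grad asserts. The sketch below is a minimal, hypothetical example of that interface; the MulConstant class, the sample values, and the GradManager-based autodiff calls are illustrative assumptions, not code from this commit.

# Hypothetical sketch of the restored Function API (assumes MegEngine 1.x
# with the standard GradManager interface).
import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.core.autodiff.grad import Function


class MulConstant(Function):
    """Elementwise multiplication of a tensor by a Python scalar."""

    def __init__(self, constant):
        super().__init__()
        self.constant = constant

    def forward(self, x):
        # Invoked through core2.apply via Function.__call__; a single-Tensor
        # return value is unpacked automatically by the restored class.
        return x * self.constant

    def backward(self, grad_y):
        # d(c * x) / dx = c. Returning None here would mean "no gradient
        # for this input", matching the updated test_none_in_out_grad.
        return grad_y * self.constant


x = mge.tensor(np.arange(4, dtype="float32"))
gm = GradManager().attach([x])
with gm:
    y = MulConstant(2.0)(x)
    gm.backward(F.sum(y))

print(y.numpy())       # expected: [0. 2. 4. 6.]
print(x.grad.numpy())  # expected: [2. 2. 2. 2.]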