diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 60311783a3e3e10a0fd4bdd47317805f57944720..204dac0f8445defc85ab583b181c20a92d79e7af 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -6,11 +6,11 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import atexit +import ctypes import os -import sys import platform -import ctypes -import atexit +import sys if sys.platform == "win32": lib_path = os.path.join(os.path.dirname(__file__), "core/lib") @@ -71,14 +71,15 @@ if sys.platform == "win32": kernel32.SetErrorMode(old_error_mode) -from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.core2 import sync +from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .device import * from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .serialization import load, save from .tensor import Parameter, Tensor, tensor +from .utils import comp_graph_tools as cgtools +from .utils import persistent_cache from .version import __version__ -from .utils import persistent_cache, comp_graph_tools as cgtools _set_fork_exec_path_for_timed_func( sys.executable, diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 2ccfbb4fefd078dc46bc5d20d43e9adb1181bd76..d783a67d2d9d4d4b63b81ef01c67b573b0fddade 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -16,7 +16,7 @@ import numpy as np import megengine as mge -from .._imperative_rt import core2 +from .._imperative_rt import core2, ops from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.special import Const from ..tensor.core import TensorBase, TensorWrapperBase, apply @@ -211,3 +211,25 @@ class Grad: def __exit__(self, _1, _2, _3): del self._impl + + +class Function(ops.PyOpBase): + def _default_rule(self, *args): + ret = self.forward(*args) + self.__single_output = isinstance(ret, core2.Tensor) + return ret + + def _grad_rule(self, *args): + return self._default_rule(*args), self.backward + + def __call__(self, *args): + ret = core2.apply(self, *args) + if self.__single_output: + (ret,) = ret + return ret + + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index a5accd1dd914544d5f0a11e4f9eb4aa0dda020e4..e20813a72b98775b02f68970e1b59ffde8bf14a1 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -11,8 +11,8 @@ from typing import Iterable import numpy as np from .. import functional as F +from ..core.autodiff.grad import Function from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype -from ..core.tensor.function import Function from ..module import Module from ..tensor import Parameter, Tensor from .utils import QuantMode, fake_quant_tensor, get_qparam_dict diff --git a/imperative/python/megengine/quantization/internal_fake_quant.py b/imperative/python/megengine/quantization/internal_fake_quant.py index 02d1d89767eab1ba801c075a7cea3e53edcaed39..ab980406eed865239213d696b8e216c9bda8b1ce 100644 --- a/imperative/python/megengine/quantization/internal_fake_quant.py +++ b/imperative/python/megengine/quantization/internal_fake_quant.py @@ -12,7 +12,7 @@ from functools import partial import numpy as np from .. import functional as F -from ..core.tensor.function import Function +from ..core.autodiff.grad import Function from .fake_quant import _FakeQuantize from .observer import MinMaxObserver from .qconfig import QConfig diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 31c342e1115f4ff68f4e0ece0c2d405020249eb4..95d4db1bcd093a573b12d4ed12913888cd3a816b 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -12,11 +12,11 @@ from typing import Dict import numpy as np from .. import functional as F +from ..core.autodiff.grad import Function from ..core.ops import builtin from ..core.tensor import megbrain_graph from ..core.tensor.core import apply from ..core.tensor.dtype import _metadata_dict -from ..core.tensor.function import Function from ..tensor import Tensor diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py index fc01faf912d434360179067a6b5bed645e54e274..accc751067e25eb62fbf40e39eda431d8a66deea 100644 --- a/imperative/python/test/unit/core/test_function.py +++ b/imperative/python/test/unit/core/test_function.py @@ -15,7 +15,7 @@ import megengine.optimizer as optimizer from megengine import Parameter from megengine import Tensor as tensor from megengine import tensor -from megengine.core.tensor.function import Function +from megengine.core.autodiff.grad import Function from megengine.module import Module @@ -239,7 +239,7 @@ def test_none_in_out_grad(): def backward(self, grad_a, grad_b): assert grad_b is None - return (grad_a, 0.0) + return (grad_a, None) class Simple(Module): def __init__(self, a, b): diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index 60cda8c46ad048a8e672f1f5213a68155d9682a4..bf84c93df7ad8a368fb093e18043f288034af8f5 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -11,8 +11,7 @@ import pytest import megengine as mge from megengine import tensor -from megengine.core.autodiff.grad import Grad -from megengine.core.tensor.function import Function +from megengine.core.autodiff.grad import Function, Grad from megengine.core.tensor.utils import make_shape_tuple from megengine.quantization.fake_quant import TQT_Function from megengine.quantization.internal_fake_quant import *