diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py index e21f48575081308dd67bc7df01615746100f39ea..9ae1e9f6998ce38f5944bbc45978b40819050291 100644 --- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py +++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py @@ -12,6 +12,7 @@ import itertools import numpy as np from .._imperative_rt import TensorAttr, imperative +from .._imperative_rt.core2 import apply from ..ops.builtin import ( Broadcast, Elemwise, @@ -25,37 +26,6 @@ from ..ops.builtin import ( Subtensor, ) from ..ops.special import Const -from ..tensor.core import apply -from ..tensor.function import Function - - -@functools.singledispatch -def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): - assert 0 - - -@builtin_op_get_backward_fn.register(OpDef) -def _(op: OpDef, inputs, outputs, input_requires_grad): - if isinstance(op, Reshape): - grad_fn = reshape_grad_fn - elif isinstance(op, Subtensor): - grad_fn = subtensor_grad_fn - elif isinstance(op, IndexingMultiAxisVec): - grad_fn = indexingMultiAxisVec_grad_fn - elif isinstance(op, Broadcast) or ( - isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD - ): - grad_fn = elemwise_add_grad_fn - elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM: - grad_fn = reduce_sum_grad_fn - else: - grad_fn = default_grad_fn - return grad_fn(op, inputs, outputs, input_requires_grad) - - -@builtin_op_get_backward_fn.register(Function) -def _(op: Function, inputs, outputs, input_requires_grad): - return op.get_backward_fn(), [True,] * len(outputs) def default_grad_fn(op, inputs, outputs, input_requires_grad): diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index d783a67d2d9d4d4b63b81ef01c67b573b0fddade..791fb927ec6be5a7f8891e81ded415e411e65022 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -19,8 +19,6 @@ import megengine as mge from .._imperative_rt import core2, ops from ..ops.builtin import Elemwise, OpDef, RemoteSend from ..ops.special import Const -from ..tensor.core import TensorBase, TensorWrapperBase, apply -from ..tensor.function import Function from . import builtin_op_utils """ Some notes: @@ -48,146 +46,6 @@ def get_grad_managers(): return [_grad_manager_dict[key] for key in _grad_manager_dict] -def add(a, b): - (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b) - return c - - -def get_tensor(x): - # use recursion to avoid infinite loop - if isinstance(x, Tensor): - return x - try: - x = x.__wrapped__ - except AttributeError: - raise TypeError(type(x)) - return get_tensor(x) - - -class clearable: - __cleared = False - - def __bool__(self): - return not self.__cleared - - def clear(self): - self.__dict__.clear() - self.__cleared = True - - -class OpNode(clearable): - """ OpNode saves all the information to form the computational graph. - """ - - def __init__(self): - self.id = None - self.inputs = None # Could be VariableNode - self.outputs = None # Could be VariableNode - self.backward = None - self.has_grad_fn = None - self.backward_allow_noinput = False - - -class VariableNode(clearable): - """ VariableNode saves OpNode and callback. - FIXME!!! Explain manager and owner - """ - - def __init__(self, manager, owner, opnode=None, callback=None): - # manager is Grad type - self.manager = weakref.ref(manager) - # owner is Tensor type - self.owner = weakref.ref(owner) - self.opnode = opnode - self.callback = callback - - -class Tracer(clearable, TensorBase): - def __init__(self, node=None): - """ type(node) is VariableNode - """ - self.node = node - - -@functools.singledispatch -def check_backward_allow_noinput(op: OpDef): - return False - - -@functools.singledispatch -def get_op_has_grad_fn(op: OpDef): - assert 0 - - -@get_op_has_grad_fn.register(OpDef) -def _(op: OpDef): - return default_has_grad_fn - - -@get_op_has_grad_fn.register(Function) -def _(op: Function): - return default_has_grad_fn - - -def default_has_grad_fn(opnode, reached): - for v in opnode.outputs: - if v() in reached: - return True - return False - - -@apply.register() -def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): - args = tuple(i if isinstance(i, Tracer) else None for i in args) - input_requires_grad = list(map(bool, args)) - if not any(input_requires_grad): - return - - ctx = get_context() - manager = None - assert len(ctx.inputs) == len(args) - for i, j in zip(ctx.inputs, args): - if j: - j = j.node - assert i is j.owner() - if manager is None: - manager = j.manager() - assert manager - else: - assert manager is j.manager() - - if not manager._enabled: - return - - # register backward method - # tuple of backward functions corresponding to dy / dx_i - # None means y is not a function of x_i - backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn( - op, ctx.inputs, ctx.outputs, input_requires_grad - ) - assert len(ctx.outputs) == len(output_need_grad) - if not any(output_need_grad): - return - - opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs) - if isinstance(op, RemoteSend): - manager.remote_send_cache.append(opnode) - opnode.backward = backward - - outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)] - - opnode.backward_allow_noinput = check_backward_allow_noinput(op) - - opnode.has_grad_fn = get_op_has_grad_fn(op) - - return tuple(outputs) - - -@apply.register() -def _(op: Const, *_: typing.Optional[Tracer]): - return None - - class Grad: def __init__(self): self._impl = core2.GradKey() diff --git a/imperative/python/megengine/core/ops/special.py b/imperative/python/megengine/core/ops/special.py index 4b2de494bdfaaa3c3871388534f550f024b45f04..54013280fb469e5d1010979c81f27a3cfb041f31 100644 --- a/imperative/python/megengine/core/ops/special.py +++ b/imperative/python/megengine/core/ops/special.py @@ -8,9 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import numpy as np -# from .._imperative_rt.core2 import Tensor -from ..tensor.core import OpBase, TensorBase, apply - class Const: def __init__(self, value=None, *, dtype=None, device=None): diff --git a/imperative/python/megengine/core/tensor/core.py b/imperative/python/megengine/core/tensor/core.py index 0c1bcee79cafd1b8b6ed8d1b033504ee06cefdef..a74616cf40b66c5c0019bbef54b44218b34f7508 100644 --- a/imperative/python/megengine/core/tensor/core.py +++ b/imperative/python/megengine/core/tensor/core.py @@ -13,12 +13,9 @@ import sys import typing from abc import ABC -from .multipledispatch import Dispatcher - -class OpBase(ABC): - def __call__(self, *args): - return apply(self, *args) +class OpBase: + pass class TensorBase: @@ -27,22 +24,3 @@ class TensorBase: class TensorWrapperBase: pass - - -apply = Dispatcher("apply") - -OpBase.apply = apply - - -@apply.register() -def _(op: OpBase, *args: TensorBase): - raise NotImplementedError - - -@apply.register() -def _(op: OpBase, *args: TensorWrapperBase): - assert args - Wrapper = type(args[0]) - outputs = apply(op, *(i.__wrapped__ for i in args)) - assert isinstance(outputs, tuple) - return tuple(map(Wrapper, outputs)) diff --git a/imperative/python/megengine/core/tensor/function.py b/imperative/python/megengine/core/tensor/function.py deleted file mode 100644 index d7b6b8cf78c348f3c63b50bf0d89b8e444a589d6..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/function.py +++ /dev/null @@ -1,154 +0,0 @@ -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..ops.builtin import OpDef -from .core import TensorBase, TensorWrapperBase, apply - - -class Function: - """ - Defines a block of operations with customizable differentiation. - - The computation should be defined in ``forward`` method, with gradient - computation defined in ``backward`` method. - - Each instance of ``Function`` should be used only once during forwardding. - - Examples: - - .. testcode:: - - class Sigmoid(Function): - def forward(self, x): - y = 1 / (1 + F.exp(-x)) - self.y = y - return y - - def backward(self, output_grads): - y = self.y - return output_grads * y * (1-y) - - """ - - def __init__(self, *args, **kwargs): - pass - - def __call__(self, *args): - ret = apply(self, *args) - if type(ret) == tuple and len(ret) == 1: - return ret[0] - return ret - - def forward(self, *args, **kwargs): - """ - Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses. - - :param input: input tensors. - :return: a tuple of Tensor or a single Tensor. - - .. note:: - - This method should return a tuple of Tensor or a single Tensor representing the output - of the function. - """ - raise NotImplementedError - - def backward(self, *output_grads): - """ - Compute the gradient of the forward function. It must be overriden by all subclasses. - - :param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`. - - .. note:: - - In case when some tensors of outputs are not related to loss function, the corresponding - values in ``output_grads`` would be ``None``. - - .. note:: - - This method should return a tuple which containing the gradients of all inputs, in the same order - as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned - instead if there is only one input. If users want to stop the propagation of some gradients, - the corresponding returned values should be set ``None`` . - - """ - raise NotImplementedError - - def get_backward_fn(self): - if self.backward is None: - return None - - def _backward(*output_grads): - if type(output_grads) is tuple: - _output_grads = [ - TensorWrapper(i) if i is not None else i for i in output_grads - ] - else: - _output_grads = ( - TensorWrapper(output_grads) - if output_grads is not None - else output_grads, - ) - ret = self.backward(*_output_grads) - if type(ret) is not tuple: - ret = (ret,) - ret = tuple( - i.__wrapped__ if isinstance(i, TensorWrapper) else i for i in ret - ) - return ret - - return _backward - - -Function.apply = Function.__call__ - - -@apply.register() -def _(op: Function, *args: TensorWrapperBase): - assert args - Wrapper = type(args[0]) - - # compute the value for self define function - extra_data_dic = {} - for arg in args: - extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data - arg.__wrapped__._extra_data = {} - - rets = op.forward(*args) - - for arg in args: - arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__] - - # update the gradient information for self define function - inputs = tuple(map(lambda i: i.__wrapped__, args)) - outputs = ( - tuple(map(lambda i: i.__wrapped__, rets)) - if type(rets) is tuple - else (rets.__wrapped__,) - ) - - for output in outputs: - if output not in inputs: - output._extra_data = {} - - with push_context() as ctx: - ctx.inputs = inputs - ctx.outputs = outputs - for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))): - ctx.key = k - data = tuple( - i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs - ) - # data are instances of Tracer - # dispatched to apply.add@grad.py - rets = apply(op, *data) - if rets is not None: - assert len(outputs) == len(rets) - for t, i in zip(outputs, rets): - t._extra_data[k] = i - - return tuple(map(Wrapper, outputs)) diff --git a/imperative/python/megengine/core/tensor/multipledispatch/__init__.py b/imperative/python/megengine/core/tensor/multipledispatch/__init__.py deleted file mode 100644 index 84d1bdd628d3c257fe323d29f71b8a9e99f2113f..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -# This directory is a fork of multipledispatch. -# -# Repo: https://github.com/mrocklin/multipledispatch -# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 -# Authors: Matthew Rocklin et al. -# -# The original LICENSE file is included in the ACKNOWLEDGEMENT file under -# MegEngine root directory. - -from .core import dispatch -from .dispatcher import Dispatcher diff --git a/imperative/python/megengine/core/tensor/multipledispatch/conflict.py b/imperative/python/megengine/core/tensor/multipledispatch/conflict.py deleted file mode 100644 index ec852aa767fdcacecdabf6434bd80bcf4867f243..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/conflict.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -from collections import OrderedDict - -from .utils import _toposort, groupby -from .variadic import isvariadic - - -class AmbiguityWarning(Warning): - pass - - -def supercedes(a, b): - """ A is consistent and strictly more specific than B """ - if len(a) < len(b): - # only case is if a is empty and b is variadic - return not a and len(b) == 1 and isvariadic(b[-1]) - elif len(a) == len(b): - return all(map(issubclass, a, b)) - else: - # len(a) > len(b) - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not (isvariadic(cur_a) or isvariadic(cur_b)): - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - assert p1 == len(a) - 1 - return p2 == len(b) - 1 and issubclass(cur_a, cur_b) - elif isvariadic(cur_b): - assert p2 == len(b) - 1 - if not issubclass(cur_a, cur_b): - return False - p1 += 1 - return p2 == len(b) - 1 and p1 == len(a) - - -def consistent(a, b): - """ It is possible for an argument list to satisfy both A and B """ - - # Need to check for empty args - if not a: - return not b or isvariadic(b[0]) - if not b: - return not a or isvariadic(a[0]) - - # Non-empty args check for mutual subclasses - if len(a) == len(b): - return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) - else: - p1 = 0 - p2 = 0 - while p1 < len(a) and p2 < len(b): - cur_a = a[p1] - cur_b = b[p2] - if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): - return False - if not (isvariadic(cur_a) or isvariadic(cur_b)): - p1 += 1 - p2 += 1 - elif isvariadic(cur_a): - p2 += 1 - elif isvariadic(cur_b): - p1 += 1 - # We only need to check for variadic ends - # Variadic types are guaranteed to be the last element - return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) - - -def ambiguous(a, b): - """ A is consistent with B but neither is strictly more specific """ - return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) - - -def ambiguities(signatures): - """ All signature pairs such that A is ambiguous with B """ - signatures = list(map(tuple, signatures)) - return set( - (a, b) - for a in signatures - for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) - ) - - -def super_signature(signatures): - """ A signature that would break ambiguities """ - n = len(signatures[0]) - assert all(len(s) == n for s in signatures) - - return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] - - -def edge(a, b, tie_breaker=hash): - """ A should be checked before B - - Tie broken by tie_breaker, defaults to ``hash`` - """ - # A either supercedes B and B does not supercede A or if B does then call - # tie_breaker - return supercedes(a, b) and ( - not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) - ) - - -def ordering(signatures): - """ A sane ordering of signatures to check, first to last - - Topoological sort of edges as given by ``edge`` and ``supercedes`` - """ - signatures = list(map(tuple, signatures)) - edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] - edges = groupby(lambda x: x[0], edges) - for s in signatures: - if s not in edges: - edges[s] = [] - edges = OrderedDict((k, [b for a, b in v]) for k, v in edges.items()) - return _toposort(edges) diff --git a/imperative/python/megengine/core/tensor/multipledispatch/core.py b/imperative/python/megengine/core/tensor/multipledispatch/core.py deleted file mode 100644 index 13e59b87b59349d495e30126d9bb0f7c11071c72..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/core.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -import inspect -import sys - -from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn - -global_namespace = dict() - - -def dispatch(*types, **kwargs): - """ Dispatch function on the types of the inputs - - Supports dispatch on all non-keyword arguments. - - Collects implementations based on the function name. Ignores namespaces. - - If ambiguous type signatures occur a warning is raised when the function is - defined suggesting the additional method to break the ambiguity. - - Examples - -------- - - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - - >>> f(3) - 4 - >>> f(3.0) - 2.0 - - Specify an isolated namespace with the namespace keyword argument - - >>> my_namespace = dict() - >>> @dispatch(int, namespace=my_namespace) - ... def foo(x): - ... return x + 1 - - Dispatch on instance methods within classes - - >>> class MyClass(object): - ... @dispatch(list) - ... def __init__(self, data): - ... self.data = data - ... @dispatch(int) - ... def __init__(self, datum): - ... self.data = [datum] - """ - namespace = kwargs.get("namespace", global_namespace) - - types = tuple(types) - - def _df(func): - name = func.__name__ - - if ismethod(func): - dispatcher = inspect.currentframe().f_back.f_locals.get( - name, MethodDispatcher(name), - ) - else: - if name not in namespace: - namespace[name] = Dispatcher(name) - dispatcher = namespace[name] - - dispatcher.add(types, func) - return dispatcher - - return _df - - -def ismethod(func): - """ Is func a method? - - Note that this has to work as the method is defined but before the class is - defined. At this stage methods look like functions. - """ - if hasattr(inspect, "signature"): - signature = inspect.signature(func) - return signature.parameters.get("self", None) is not None - else: - if sys.version_info.major < 3: - spec = inspect.getargspec(func) - else: - spec = inspect.getfullargspec(func) - return spec and spec.args and spec.args[0] == "self" diff --git a/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py b/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py deleted file mode 100644 index 3660449f9c548358c5c45cc42991e1d68e1d68a4..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py +++ /dev/null @@ -1,445 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -import copy -import inspect -import itertools as itl -from warnings import warn - -from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher -from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature -from .utils import expand_tuples, parse_union -from .variadic import Variadic, isvariadic - - -def ambiguity_warn(dispatcher, ambiguities): - """ Raise warning when ambiguity is detected - - Parameters - ---------- - dispatcher : Dispatcher - The dispatcher on which the ambiguity was detected - ambiguities : set - Set of type signature pairs that are ambiguous within this dispatcher - - See Also: - Dispatcher.add - warning_text - """ - warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) - - -def variadic_signature_matches_iter(types, full_signature): - """ - Check if a set of input types matches a variadic signature. - - Notes - ----- - The algorithm is as follows: - - Initialize the current signature to the first in the sequence - - For each type in `types`: - If the current signature is variadic - If the type matches the signature - yield True - Else - Try to get the next signature - If no signatures are left we can't possibly have a match - so yield False - Else - yield True if the type matches the current signature - Get the next signature - """ - sigiter = iter(full_signature) - sig = next(sigiter) - for typ in types: - matches = issubclass(typ, sig) - yield matches - if not isvariadic(sig): - # we're not matching a variadic argument, so move to the next - # element in the signature - sig = next(sigiter) - else: - try: - sig = next(sigiter) - except StopIteration: - assert isvariadic(sig) - yield True - else: - # We have signature items left over, so all of our arguments - # haven't matched - yield False - - -def variadic_signature_matches(types, full_signature): - # No arguments always matches a variadic signature - assert full_signature - return all(variadic_signature_matches_iter(types, full_signature)) - - -def get_func_signature(function): - sig = inspect.signature(function) - types = [] - for p in sig.parameters.values(): - ann = p.annotation - ann = parse_union(ann) or ann - if p.kind in ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): - types.append(ann) - if p.kind == inspect.Parameter.VAR_POSITIONAL: - types.append([ann]) - return tuple(types) - - -class Frame: - __slots__ = "args", "types", "mro", "mro_offset" - - -class Dispatcher(CDispatcher): - """ Dispatch methods based on type signature - - Use ``dispatch`` to add implementations - - Examples - -------- - - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def f(x): - ... return x + 1 - - >>> @dispatch(float) - ... def f(x): - ... return x - 1 - - >>> f(3) - 4 - >>> f(3.0) - 2.0 - """ - - __slots__ = "__name__", "name", "funcs", "_ordering", "doc" - - def __init__(self, name, doc=None): - self.name = self.__name__ = name - self.funcs = {} - self.doc = doc - - def register(self, *types, **kwargs): - """ register dispatcher with new implementation - - >>> f = Dispatcher('f') - >>> @f.register(int) - ... def inc(x): - ... return x + 1 - - >>> @f.register(float) - ... def dec(x): - ... return x - 1 - - >>> @f.register(list) - ... @f.register(tuple) - ... def reverse(x): - ... return x[::-1] - - >>> f(1) - 2 - - >>> f(1.0) - 0.0 - - >>> f([1, 2, 3]) - [3, 2, 1] - """ - - def _df(func): - self.add(types, func, **kwargs) - return func - - return _df - - def add(self, signature, func): - """ Add new types/method pair to dispatcher - - >>> D = Dispatcher('add') - >>> D.add((int, int), lambda x, y: x + y) - >>> D.add((float, float), lambda x, y: x + y) - - >>> D(1, 2) - 3 - >>> D(1, 2.0) - Traceback (most recent call last): - ... - NotImplementedError: Could not find signature for add: - - When ``add`` detects a warning it calls the ``on_ambiguity`` callback - with a dispatcher/itself, and a set of ambiguous type signature pairs - as inputs. See ``ambiguity_warn`` for an example. - """ - # Handle annotations - if not signature: - signature = get_func_signature(func) - - # Handle union types - if any(isinstance(typ, tuple) for typ in signature): - for typs in expand_tuples(signature): - self.add(typs, func) - return - - new_signature = [] - - for index, typ in enumerate(signature, start=1): - if not isinstance(typ, (type, list)): - str_sig = ", ".join( - c.__name__ if isinstance(c, type) else str(c) for c in signature - ) - raise TypeError( - "Tried to dispatch on non-type: %s\n" - "In signature: <%s>\n" - "In function: %s" % (typ, str_sig, self.name) - ) - - # handle variadic signatures - if isinstance(typ, list): - if index != len(signature): - raise TypeError("Variadic signature must be the last element") - - if len(typ) != 1: - raise TypeError( - "Variadic signature must contain exactly one element. " - "To use a variadic union type place the desired types " - "inside of a tuple, e.g., [(int, str)]" - ) - new_signature.append(Variadic[typ[0]]) - else: - new_signature.append(typ) - - l = self.funcs.setdefault(tuple(new_signature), []) - for i in l: - if i is func: - raise ValueError("already registered") - l.append(func) - self.enable(func) - self.clear_cache() - - try: - del self._ordering - except AttributeError: - pass - - @property - def ordering(self): - try: - return self._ordering - except AttributeError: - return self.reorder() - - def reorder(self, on_ambiguity=ambiguity_warn): - self._ordering = od = ordering(self.funcs) - amb = ambiguities(self.funcs) - if amb: - on_ambiguity(self, amb) - return od - - def __str__(self): - return "" % self.name - - __repr__ = __str__ - - def dispatch(self, *types): - """ - Deterimine appropriate implementation for this type signature - - This method is internal. Users should call this object as a function. - Implementation resolution occurs within the ``__call__`` method. - - >>> from multipledispatch import dispatch - >>> @dispatch(int) - ... def inc(x): - ... return x + 1 - - >>> implementation = inc.dispatch(int) - >>> implementation(3) - 4 - - >>> print(inc.dispatch(float)) - None - - See Also: - ``multipledispatch.conflict`` - module to determine resolution order - """ - - if types in self.funcs: - return self.funcs[types][-1] - - for f in self.dispatch_iter(*types): - return f - - def dispatch_iter(self, *types): - - n = len(types) - for signature in self.ordering: - if ( - len(signature) == n - and all(map(issubclass, types, signature)) - or len(signature) - and isvariadic(signature[-1]) - and variadic_signature_matches(types, signature) - ): - yield from self.funcs[signature][::-1] - - def __getstate__(self): - return {"name": self.name, "funcs": self.funcs} - - def __setstate__(self, d): - self.name = d["name"] - self.funcs = d["funcs"] - self._ordering = ordering(self.funcs) - self._cache = dict() - - @property - def __doc__(self): - docs = ["Multiply dispatched method: %s" % self.name] - - if self.doc: - docs.append(self.doc) - - other = [] - for sig in self.ordering[::-1]: - funcs = self.funcs[sig] - s = "Inputs: <%s>\n" % str_signature(sig) - sep = "-" * len(s) + "\n" - for i, func in enumerate(funcs): - s += sep - if len(funcs) > 1: - s += "[Handler %d]\n\n" % (i + 1) - if i: - s += "\n\n" - if func.__doc__: - s += func.__doc__.strip() - else: - s += repr(func) + "\n" - docs.append(s) - - return "\n\n".join(docs) - - def _help(self, *args): - return self.dispatch(*map(type, args)).__doc__ - - def help(self, *args, **kwargs): - """ Print docstring for the function corresponding to inputs """ - print(self._help(*args)) - - def _source(self, *args): - func = self.dispatch(*map(type, args)) - if not func: - raise TypeError("No function found") - return source(func) - - def source(self, *args, **kwargs): - """ Print source code for the function corresponding to inputs """ - print(self._source(*args)) - - -def source(func): - s = "File: %s\n\n" % inspect.getsourcefile(func) - s = s + inspect.getsource(func) - return s - - -class MethodDispatcher(Dispatcher): - """ Dispatch methods based on type signature - - See Also: - Dispatcher - """ - - __slots__ = ("obj", "cls") - - @classmethod - def get_func_params(cls, func): - if hasattr(inspect, "signature"): - sig = inspect.signature(func) - return itl.islice(sig.parameters.values(), 1, None) - - def __get__(self, instance, owner): - self.obj = instance - self.cls = owner - return self - - def __call__(self, *args, **kwargs): - types = tuple([type(arg) for arg in args]) - func = self.dispatch(*types) - if not func: - raise NotImplementedError( - "Could not find signature for %s: <%s>" - % (self.name, str_signature(types)) - ) - return func(self.obj, *args, **kwargs) - - -def str_signature(sig): - """ String representation of type signature - - >>> str_signature((int, float)) - 'int, float' - """ - return ", ".join(cls.__name__ for cls in sig) - - -def warning_text(name, amb): - """ The text for ambiguity warnings """ - text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) - text += "The following signatures may result in ambiguous behavior:\n" - for pair in amb: - text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" - text += "\n\nConsider making the following additions:\n\n" - text += "\n\n".join( - [ - "@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name - for s in amb - ] - ) - return text diff --git a/imperative/python/megengine/core/tensor/multipledispatch/utils.py b/imperative/python/megengine/core/tensor/multipledispatch/utils.py deleted file mode 100644 index 469ef2774de09cf9af8eb00f4f9db4d22fef1875..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -import sys -import typing -from collections import OrderedDict - - -def raises(err, lamda): - try: - lamda() - return False - except err: - return True - - -def expand_tuples(L): - """ - - >>> expand_tuples([1, (2, 3)]) - [(1, 2), (1, 3)] - - >>> expand_tuples([1, 2]) - [(1, 2)] - """ - if not L: - return [()] - elif not isinstance(L[0], tuple): - rest = expand_tuples(L[1:]) - return [(L[0],) + t for t in rest] - else: - rest = expand_tuples(L[1:]) - return [(item,) + t for t in rest for item in L[0]] - - -# Taken from theano/theano/gof/sched.py -# Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) - - inputs: - edges - a dict of the form {a: {b, c}} where b and c depend on a - outputs: - L - an ordered list of nodes that satisfy the dependencies of edges - - >>> _toposort({1: (2, 3), 2: (3, )}) - [1, 2, 3] - - Closely follows the wikipedia page [2] - - [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", - Communications of the ACM - [2] http://en.wikipedia.org/wiki/Toposort#Algorithms - """ - incoming_edges = reverse_dict(edges) - incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) - S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) - L = [] - - while S: - n, _ = S.popitem() - L.append(n) - for m in edges.get(n, ()): - assert n in incoming_edges[m] - incoming_edges[m].remove(n) - if not incoming_edges[m]: - S[m] = None - if any(incoming_edges.get(v, None) for v in edges): - raise ValueError("Input has cycles") - return L - - -def reverse_dict(d): - """ - Reverses direction of dependence dict - - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} - >>> reverse_dict(d) # doctest: +SKIP - {1: ('a',), 2: ('a', 'b'), 3: ('b',)} - - :note: dict order are not deterministic. As we iterate on the - input dict, it make the output of this function depend on the - dict order. So this function output order should be considered - as undeterministic. - - """ - result = OrderedDict() - for key in d: - for val in d[key]: - result[val] = result.get(val, tuple()) + (key,) - return result - - -# Taken from toolz -# Avoids licensing issues because this version was authored by Matthew Rocklin -def groupby(func, seq): - """ Group a collection by a key function - - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] - >>> groupby(len, names) # doctest: +SKIP - {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} - - >>> iseven = lambda x: x % 2 == 0 - >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP - {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} - - See Also: - ``countby`` - """ - - d = OrderedDict() - for item in seq: - key = func(item) - if key not in d: - d[key] = list() - d[key].append(item) - return d - - -def typename(type): - """ - Get the name of `type`. - - Parameters - ---------- - type : Union[Type, Tuple[Type]] - - Returns - ------- - str - The name of `type` or a tuple of the names of the types in `type`. - - Examples - -------- - >>> typename(int) - 'int' - >>> typename((int, float)) - '(int, float)' - """ - try: - return type.__name__ - except AttributeError: - if len(type) == 1: - return typename(*type) - return "(%s)" % ", ".join(map(typename, type)) - - -# parse typing.Union -def parse_union(ann): - if hasattr(typing, "UnionMeta"): - if type(ann) is not typing.UnionMeta: - return - return ann.__union_params__ - elif hasattr(typing, "_Union"): - if type(ann) is not typing._Union: - return - return ann.__args__ - elif hasattr(typing, "_GenericAlias"): - if type(ann) is not typing._GenericAlias: - if type(ann) is not typing.Union: - return - else: - if ann.__origin__ is not typing.Union: - return - return ann.__args__ - elif hasattr(typing, "Union"): - if typing.get_origin(ann) is not typing.Union: - return - return typing.get_args(ann) - else: - raise NotImplementedError("unsupported Python version") diff --git a/imperative/python/megengine/core/tensor/multipledispatch/variadic.py b/imperative/python/megengine/core/tensor/multipledispatch/variadic.py deleted file mode 100644 index 25a5700e417856be75dbf142af231e682e54687e..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/multipledispatch/variadic.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) 2014 Matthew Rocklin -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# a. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# b. Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# c. Neither the name of multipledispatch nor the names of its contributors -# may be used to endorse or promote products derived from this software -# without specific prior written permission. -# -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR -# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT -# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY -# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. -# -# -------------------------------------------------------------------------------------- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# -# This file has been modified by Megvii ("Megvii Modifications"). -# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved. -# -------------------------------------------------------------------------------------- - -from .utils import typename - - -class VariadicSignatureType(type): - # checking if subclass is a subclass of self - def __subclasscheck__(self, subclass): - other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) - return subclass is self or all( - issubclass(other, self.variadic_type) for other in other_type - ) - - def __eq__(self, other): - """ - Return True if other has the same variadic type - - Parameters - ---------- - other : object (type) - The object (type) to check - - Returns - ------- - bool - Whether or not `other` is equal to `self` - """ - return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) - - def __hash__(self): - return hash((type(self), frozenset(self.variadic_type))) - - -def isvariadic(obj): - """ - Check whether the type `obj` is variadic. - - Parameters - ---------- - obj : type - The type to check - - Returns - ------- - bool - Whether or not `obj` is variadic - - Examples - -------- - >>> isvariadic(int) - False - >>> isvariadic(Variadic[int]) - True - """ - return isinstance(obj, VariadicSignatureType) - - -class VariadicSignatureMeta(type): - """ - A metaclass that overrides ``__getitem__`` on the class. This is used to - generate a new type for Variadic signatures. See the Variadic class for - examples of how this behaves. - """ - - def __getitem__(self, variadic_type): - if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): - raise ValueError( - "Variadic types must be type or tuple of types" - " (Variadic[int] or Variadic[(int, float)]" - ) - - if not isinstance(variadic_type, tuple): - variadic_type = (variadic_type,) - return VariadicSignatureType( - "Variadic[%s]" % typename(variadic_type), - (), - dict(variadic_type=variadic_type, __slots__=()), - ) - - -class Variadic(metaclass=VariadicSignatureMeta): - """ - A class whose getitem method can be used to generate a new type - representing a specific variadic signature. - - Examples - -------- - >>> Variadic[int] # any number of int arguments - - >>> Variadic[(int, str)] # any number of one of int or str arguments - - >>> issubclass(int, Variadic[int]) - True - >>> issubclass(int, Variadic[(int, str)]) - True - >>> issubclass(str, Variadic[(int, str)]) - True - >>> issubclass(float, Variadic[(int, str)]) - False - """ diff --git a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py deleted file mode 100644 index 7e1b5834e770bf20bc1dee2b3aeb7b43ecc773ca..0000000000000000000000000000000000000000 --- a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py +++ /dev/null @@ -1,136 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import functools - -import numpy as np - -from ..._imperative_rt import CompNode, DeviceTensorND -from ..._imperative_rt.imperative import ( - _drop, - _get_dev_tensor, - _swap_in, - _swap_out, - apply_op, - delete, - get_device, - get_dtype, - get_shape, - get_value, - put, -) -from ..._wrap import device as as_device -from ...ops.builtin import Copy, OpDef, TypeCvt -from ...ops.special import Const -from ..core import OpBase, TensorBase, apply - - -class RawTensor(TensorBase): - - _init_cb = None - _del_cb = None - _handle = None - - def __init__(self, handle=None, isscalar=False): - self._handle = handle - self._isscalar = isscalar - if handle is not None: - if self._init_cb: - self._init_cb() - - @property - def dtype(self): - return get_dtype(self._handle) - - @property - def device(self): - return as_device(get_device(self._handle)) - - @property - def shape(self): - if self._isscalar: - return () - return get_shape(self._handle) - - def numpy(self): - ret = get_value(self._handle) - if self._isscalar: - ret = ret.squeeze() - return ret - - def _dev_tensor(self): - return _get_dev_tensor(self._handle) - - def _drop(self): - _drop(self._handle) - - def _swap_in(self): - _swap_in(self._handle) - - def _swap_out(self): - _swap_out(self._handle) - - def __repr__(self): - return "{}({}, device='{}')".format( - type(self).__qualname__, repr(self.numpy()), self.device - ) - - def __del__(self): - if self._handle is not None: - if self._del_cb: - self._del_cb() - delete(self._handle) - - -@apply.register() -def _(op: OpDef, *args: RawTensor): - outputs = apply_op(op, tuple(i._handle for i in args)) - return tuple(map(RawTensor, outputs)) - - -@apply.register() -def _(op: Const, *args: RawTensor): - dtype = op.dtype - device = as_device(op.device).to_c() - return (as_raw_tensor(op.value, dtype=dtype, device=device),) - - -@functools.singledispatch -def as_raw_tensor(obj, dtype=None, device=None): - obj = np.asarray(obj, dtype=dtype) - if obj.dtype == np.float64: - obj = obj.astype(np.float32) - if obj.dtype == np.int64: - obj = obj.astype(np.int32) - return as_raw_tensor(obj, device=device) - - -@as_raw_tensor.register(DeviceTensorND) -def _(data: DeviceTensorND): - return RawTensor(put(data)) - - -@as_raw_tensor.register(np.ndarray) -def _(array: np.ndarray, dtype=None, device=None): - device = None if device is None else as_device(device).to_c() - if 0 in array.strides: - array = array.squeeze().reshape(array.shape) - return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0)) - - -@as_raw_tensor.register(RawTensor) -def _(tensor: RawTensor, dtype=None, device=None): - if dtype is not None: - dtype = np.dtype(dtype) - if dtype != tensor.dtype: - (tensor,) = apply(TypeCvt(dtype=dtype), tensor) - if device is not None: - device = as_device(device) - if device != tensor.device: - (tensor,) = apply(Copy(comp_node=device.to_c()), tensor) - return tensor diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 1eb353966983f48ffc3cc7fb3717f2eda7a0443c..b758217b803254bcfd7c252d354569873f645a08 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -9,14 +9,7 @@ from typing import Optional, Tuple from ..core._imperative_rt.core2 import apply -from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn -from ..core.autodiff.grad import ( - Tracer, - check_backward_allow_noinput, - get_grad_managers, - get_op_has_grad_fn, - tracer_apply, -) +from ..core.autodiff.grad import get_grad_managers from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend from ..device import get_default_device from ..tensor import Tensor @@ -236,7 +229,7 @@ def remote_recv( device = get_default_device() # dummy input if inp == None: - inp = tensor([0], device=device) + inp = Tensor([0], device=device) tracer_set = get_client().check_remote_tracer(key) for grad_manager in get_grad_managers(): if grad_manager.name in tracer_set: diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 883470e34afe93c96c28cad61f286035528db57c..f94724a6e52026a9059d229032c149161d5f7995 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list): outputs = apply(op, inp) for s, x in zip(shapes, outputs): if not s: - x._isscalar = True + x.setscalar() return outputs diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index b9a5b4f75538bc758f167e191fc44bc89a80cf10..bfa3d9b867f69076574485ad67a8bb1b2f905c9f 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -10,7 +10,7 @@ from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt import CompNode -from ..core._imperative_rt.core2 import Tensor, apply +from ..core._imperative_rt.core2 import apply from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin from ..core.ops.builtin import BatchNorm diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 95d4db1bcd093a573b12d4ed12913888cd3a816b..c3e6c9849356567741bcc86718e693add6700737 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -12,10 +12,10 @@ from typing import Dict import numpy as np from .. import functional as F +from ..core._imperative_rt.core2 import apply from ..core.autodiff.grad import Function from ..core.ops import builtin from ..core.tensor import megbrain_graph -from ..core.tensor.core import apply from ..core.tensor.dtype import _metadata_dict from ..tensor import Tensor diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py index 180908ff9b36276f102b25ae3583db9fdd0f493e..ae47cb83144c79fa816a452c4d8b3d1c2496ee2c 100644 --- a/imperative/python/test/conftest.py +++ b/imperative/python/test/conftest.py @@ -3,7 +3,7 @@ import sys import pytest -from megengine.core._imperative_rt.imperative import sync +from megengine.core._imperative_rt.core2 import sync sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) diff --git a/imperative/python/test/integration/test_save_load.py b/imperative/python/test/integration/test_save_load.py index f18cc8c429f842f41a72b3a040f6f5f291eccae7..e8075028d7b030df35363e317f2baebe882fe4f2 100644 --- a/imperative/python/test/integration/test_save_load.py +++ b/imperative/python/test/integration/test_save_load.py @@ -4,7 +4,6 @@ import megengine as mge import megengine.autodiff as ad import megengine.optimizer as optimizer from megengine import Parameter, tensor -from megengine.core.tensor.raw_tensor import RawTensor from megengine.module import Module diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index b9a17972fcc8bc25ae53407d68cd2b59c10c8a10..902ef6f0f37a509909c2db7177d95db76ecc62c5 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -13,7 +13,6 @@ import pytest import megengine.core.tensor.megbrain_graph as G from megengine.core.ops import builtin as ops -from megengine.core.tensor.core import apply from megengine.core.tensor.dtype import ( _metadata_dict, convert_from_qint4, diff --git a/imperative/python/test/unit/test_dispatch.py b/imperative/python/test/unit/test_dispatch.py deleted file mode 100644 index ce74341da20bcce063aa2d197d390729b3f31ccf..0000000000000000000000000000000000000000 --- a/imperative/python/test/unit/test_dispatch.py +++ /dev/null @@ -1,58 +0,0 @@ -from megengine.core.tensor.multipledispatch import Dispatcher - - -def test_register_many(): - f = Dispatcher("f") - - log = [] - - @f.register() - def _(x: int): - log.append("a") - return log[-1] - - @f.register() - def _(x: int): - log.append("b") - return log[-1] - - assert f(0) == "b" - assert log == ["b"] - - -def test_return_not_implemented(): - f = Dispatcher("f") - - log = [] - - @f.register() - def _(x: int): - log.append("a") - return log[-1] - - @f.register() - def _(x: int): - log.append("b") - return NotImplemented - - assert f(0) == "a" - assert log == ["b", "a"] - - -def test_super(): - f = Dispatcher("f") - - log = [] - - @f.register() - def _(x: int): - log.append("a") - return log[-1] - - @f.register() - def _(x: int): - log.append("b") - return f.super(x) - - assert f(0) == "a" - assert log == ["b", "a"]