提交 60c7d62a 编写于 作者: M Megvii Engine Team

refactor(imperative): remove multidispatch, raw_tensor, register

GitOrigin-RevId: ca5a6ed8eb6c8089b758eb84bf26d6e928ea4d41
上级 b5e46ae9
......@@ -12,6 +12,7 @@ import itertools
import numpy as np
from .._imperative_rt import TensorAttr, imperative
from .._imperative_rt.core2 import apply
from ..ops.builtin import (
Broadcast,
Elemwise,
......@@ -25,37 +26,6 @@ from ..ops.builtin import (
Subtensor,
)
from ..ops.special import Const
from ..tensor.core import apply
from ..tensor.function import Function
@functools.singledispatch
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
assert 0
@builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad):
if isinstance(op, Reshape):
grad_fn = reshape_grad_fn
elif isinstance(op, Subtensor):
grad_fn = subtensor_grad_fn
elif isinstance(op, IndexingMultiAxisVec):
grad_fn = indexingMultiAxisVec_grad_fn
elif isinstance(op, Broadcast) or (
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
):
grad_fn = elemwise_add_grad_fn
elif isinstance(op, Reduce) and op.mode == Reduce.Mode.SUM:
grad_fn = reduce_sum_grad_fn
else:
grad_fn = default_grad_fn
return grad_fn(op, inputs, outputs, input_requires_grad)
@builtin_op_get_backward_fn.register(Function)
def _(op: Function, inputs, outputs, input_requires_grad):
return op.get_backward_fn(), [True,] * len(outputs)
def default_grad_fn(op, inputs, outputs, input_requires_grad):
......
......@@ -19,8 +19,6 @@ import megengine as mge
from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from . import builtin_op_utils
""" Some notes:
......@@ -48,146 +46,6 @@ def get_grad_managers():
return [_grad_manager_dict[key] for key in _grad_manager_dict]
def add(a, b):
(c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
return c
def get_tensor(x):
# use recursion to avoid infinite loop
if isinstance(x, Tensor):
return x
try:
x = x.__wrapped__
except AttributeError:
raise TypeError(type(x))
return get_tensor(x)
class clearable:
__cleared = False
def __bool__(self):
return not self.__cleared
def clear(self):
self.__dict__.clear()
self.__cleared = True
class OpNode(clearable):
""" OpNode saves all the information to form the computational graph.
"""
def __init__(self):
self.id = None
self.inputs = None # Could be VariableNode
self.outputs = None # Could be VariableNode
self.backward = None
self.has_grad_fn = None
self.backward_allow_noinput = False
class VariableNode(clearable):
""" VariableNode saves OpNode and callback.
FIXME!!! Explain manager and owner
"""
def __init__(self, manager, owner, opnode=None, callback=None):
# manager is Grad type
self.manager = weakref.ref(manager)
# owner is Tensor type
self.owner = weakref.ref(owner)
self.opnode = opnode
self.callback = callback
class Tracer(clearable, TensorBase):
def __init__(self, node=None):
""" type(node) is VariableNode
"""
self.node = node
@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
return False
@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
assert 0
@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
return default_has_grad_fn
@get_op_has_grad_fn.register(Function)
def _(op: Function):
return default_has_grad_fn
def default_has_grad_fn(opnode, reached):
for v in opnode.outputs:
if v() in reached:
return True
return False
@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
args = tuple(i if isinstance(i, Tracer) else None for i in args)
input_requires_grad = list(map(bool, args))
if not any(input_requires_grad):
return
ctx = get_context()
manager = None
assert len(ctx.inputs) == len(args)
for i, j in zip(ctx.inputs, args):
if j:
j = j.node
assert i is j.owner()
if manager is None:
manager = j.manager()
assert manager
else:
assert manager is j.manager()
if not manager._enabled:
return
# register backward method
# tuple of backward functions corresponding to dy / dx_i
# None means y is not a function of x_i
backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
op, ctx.inputs, ctx.outputs, input_requires_grad
)
assert len(ctx.outputs) == len(output_need_grad)
if not any(output_need_grad):
return
opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
if isinstance(op, RemoteSend):
manager.remote_send_cache.append(opnode)
opnode.backward = backward
outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
opnode.backward_allow_noinput = check_backward_allow_noinput(op)
opnode.has_grad_fn = get_op_has_grad_fn(op)
return tuple(outputs)
@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
return None
class Grad:
def __init__(self):
self._impl = core2.GradKey()
......
......@@ -8,9 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
# from .._imperative_rt.core2 import Tensor
from ..tensor.core import OpBase, TensorBase, apply
class Const:
def __init__(self, value=None, *, dtype=None, device=None):
......
......@@ -13,12 +13,9 @@ import sys
import typing
from abc import ABC
from .multipledispatch import Dispatcher
class OpBase(ABC):
def __call__(self, *args):
return apply(self, *args)
class OpBase:
pass
class TensorBase:
......@@ -27,22 +24,3 @@ class TensorBase:
class TensorWrapperBase:
pass
apply = Dispatcher("apply")
OpBase.apply = apply
@apply.register()
def _(op: OpBase, *args: TensorBase):
raise NotImplementedError
@apply.register()
def _(op: OpBase, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])
outputs = apply(op, *(i.__wrapped__ for i in args))
assert isinstance(outputs, tuple)
return tuple(map(Wrapper, outputs))
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..ops.builtin import OpDef
from .core import TensorBase, TensorWrapperBase, apply
class Function:
"""
Defines a block of operations with customizable differentiation.
The computation should be defined in ``forward`` method, with gradient
computation defined in ``backward`` method.
Each instance of ``Function`` should be used only once during forwardding.
Examples:
.. testcode::
class Sigmoid(Function):
def forward(self, x):
y = 1 / (1 + F.exp(-x))
self.y = y
return y
def backward(self, output_grads):
y = self.y
return output_grads * y * (1-y)
"""
def __init__(self, *args, **kwargs):
pass
def __call__(self, *args):
ret = apply(self, *args)
if type(ret) == tuple and len(ret) == 1:
return ret[0]
return ret
def forward(self, *args, **kwargs):
"""
Applies operations to ``inputs`` and returns results. It must be overriden by all subclasses.
:param input: input tensors.
:return: a tuple of Tensor or a single Tensor.
.. note::
This method should return a tuple of Tensor or a single Tensor representing the output
of the function.
"""
raise NotImplementedError
def backward(self, *output_grads):
"""
Compute the gradient of the forward function. It must be overriden by all subclasses.
:param output_grads: gradients of outputs that are returned by :meth:`~.function.Function.forward`.
.. note::
In case when some tensors of outputs are not related to loss function, the corresponding
values in ``output_grads`` would be ``None``.
.. note::
This method should return a tuple which containing the gradients of all inputs, in the same order
as the ``inputs`` argument of :meth:`~.function.Function.forward` . A ``Tensor`` could be returned
instead if there is only one input. If users want to stop the propagation of some gradients,
the corresponding returned values should be set ``None`` .
"""
raise NotImplementedError
def get_backward_fn(self):
if self.backward is None:
return None
def _backward(*output_grads):
if type(output_grads) is tuple:
_output_grads = [
TensorWrapper(i) if i is not None else i for i in output_grads
]
else:
_output_grads = (
TensorWrapper(output_grads)
if output_grads is not None
else output_grads,
)
ret = self.backward(*_output_grads)
if type(ret) is not tuple:
ret = (ret,)
ret = tuple(
i.__wrapped__ if isinstance(i, TensorWrapper) else i for i in ret
)
return ret
return _backward
Function.apply = Function.__call__
@apply.register()
def _(op: Function, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])
# compute the value for self define function
extra_data_dic = {}
for arg in args:
extra_data_dic[arg.__wrapped__] = arg.__wrapped__._extra_data
arg.__wrapped__._extra_data = {}
rets = op.forward(*args)
for arg in args:
arg.__wrapped__._extra_data = extra_data_dic[arg.__wrapped__]
# update the gradient information for self define function
inputs = tuple(map(lambda i: i.__wrapped__, args))
outputs = (
tuple(map(lambda i: i.__wrapped__, rets))
if type(rets) is tuple
else (rets.__wrapped__,)
)
for output in outputs:
if output not in inputs:
output._extra_data = {}
with push_context() as ctx:
ctx.inputs = inputs
ctx.outputs = outputs
for k in set().union(*(i._extra_data for i in inputs if isinstance(i, Tensor))):
ctx.key = k
data = tuple(
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in inputs
)
# data are instances of Tracer
# dispatched to apply.add@grad.py
rets = apply(op, *data)
if rets is not None:
assert len(outputs) == len(rets)
for t, i in zip(outputs, rets):
t._extra_data[k] = i
return tuple(map(Wrapper, outputs))
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
# This directory is a fork of multipledispatch.
#
# Repo: https://github.com/mrocklin/multipledispatch
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834
# Authors: Matthew Rocklin et al.
#
# The original LICENSE file is included in the ACKNOWLEDGEMENT file under
# MegEngine root directory.
from .core import dispatch
from .dispatcher import Dispatcher
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
from collections import OrderedDict
from .utils import _toposort, groupby
from .variadic import isvariadic
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
elif len(a) == len(b):
return all(map(issubclass, a, b))
else:
# len(a) > len(b)
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not (isvariadic(cur_a) or isvariadic(cur_b)):
if not issubclass(cur_a, cur_b):
return False
p1 += 1
p2 += 1
elif isvariadic(cur_a):
assert p1 == len(a) - 1
return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
elif isvariadic(cur_b):
assert p2 == len(b) - 1
if not issubclass(cur_a, cur_b):
return False
p1 += 1
return p2 == len(b) - 1 and p1 == len(a)
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
# Need to check for empty args
if not a:
return not b or isvariadic(b[0])
if not b:
return not a or isvariadic(a[0])
# Non-empty args check for mutual subclasses
if len(a) == len(b):
return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
else:
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
return False
if not (isvariadic(cur_a) or isvariadic(cur_b)):
p1 += 1
p2 += 1
elif isvariadic(cur_a):
p2 += 1
elif isvariadic(cur_b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a)
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return set(
(a, b)
for a in signatures
for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
)
def super_signature(signatures):
""" A signature that would break ambiguities """
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return supercedes(a, b) and (
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
)
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topoological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(lambda x: x[0], edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = OrderedDict((k, [b for a, b in v]) for k, v in edges.items())
return _toposort(edges)
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import inspect
import sys
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn
global_namespace = dict()
def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity.
Examples
--------
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = dict()
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
"""
namespace = kwargs.get("namespace", global_namespace)
types = tuple(types)
def _df(func):
name = func.__name__
if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get(
name, MethodDispatcher(name),
)
else:
if name not in namespace:
namespace[name] = Dispatcher(name)
dispatcher = namespace[name]
dispatcher.add(types, func)
return dispatcher
return _df
def ismethod(func):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if hasattr(inspect, "signature"):
signature = inspect.signature(func)
return signature.parameters.get("self", None) is not None
else:
if sys.version_info.major < 3:
spec = inspect.getargspec(func)
else:
spec = inspect.getfullargspec(func)
return spec and spec.args and spec.args[0] == "self"
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import copy
import inspect
import itertools as itl
from warnings import warn
from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher
from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature
from .utils import expand_tuples, parse_union
from .variadic import Variadic, isvariadic
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
def variadic_signature_matches_iter(types, full_signature):
"""
Check if a set of input types matches a variadic signature.
Notes
-----
The algorithm is as follows:
Initialize the current signature to the first in the sequence
For each type in `types`:
If the current signature is variadic
If the type matches the signature
yield True
Else
Try to get the next signature
If no signatures are left we can't possibly have a match
so yield False
Else
yield True if the type matches the current signature
Get the next signature
"""
sigiter = iter(full_signature)
sig = next(sigiter)
for typ in types:
matches = issubclass(typ, sig)
yield matches
if not isvariadic(sig):
# we're not matching a variadic argument, so move to the next
# element in the signature
sig = next(sigiter)
else:
try:
sig = next(sigiter)
except StopIteration:
assert isvariadic(sig)
yield True
else:
# We have signature items left over, so all of our arguments
# haven't matched
yield False
def variadic_signature_matches(types, full_signature):
# No arguments always matches a variadic signature
assert full_signature
return all(variadic_signature_matches_iter(types, full_signature))
def get_func_signature(function):
sig = inspect.signature(function)
types = []
for p in sig.parameters.values():
ann = p.annotation
ann = parse_union(ann) or ann
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
types.append(ann)
if p.kind == inspect.Parameter.VAR_POSITIONAL:
types.append([ann])
return tuple(types)
class Frame:
__slots__ = "args", "types", "mro", "mro_offset"
class Dispatcher(CDispatcher):
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__ = "__name__", "name", "funcs", "_ordering", "doc"
def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc
def register(self, *types, **kwargs):
""" register dispatcher with new implementation
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _df(func):
self.add(types, func, **kwargs)
return func
return _df
def add(self, signature, func):
""" Add new types/method pair to dispatcher
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
signature = get_func_signature(func)
# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
new_signature = []
for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
str_sig = ", ".join(
c.__name__ if isinstance(c, type) else str(c) for c in signature
)
raise TypeError(
"Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" % (typ, str_sig, self.name)
)
# handle variadic signatures
if isinstance(typ, list):
if index != len(signature):
raise TypeError("Variadic signature must be the last element")
if len(typ) != 1:
raise TypeError(
"Variadic signature must contain exactly one element. "
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
l = self.funcs.setdefault(tuple(new_signature), [])
for i in l:
if i is func:
raise ValueError("already registered")
l.append(func)
self.enable(func)
self.clear_cache()
try:
del self._ordering
except AttributeError:
pass
@property
def ordering(self):
try:
return self._ordering
except AttributeError:
return self.reorder()
def reorder(self, on_ambiguity=ambiguity_warn):
self._ordering = od = ordering(self.funcs)
amb = ambiguities(self.funcs)
if amb:
on_ambiguity(self, amb)
return od
def __str__(self):
return "<dispatched %s>" % self.name
__repr__ = __str__
def dispatch(self, *types):
"""
Deterimine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``multipledispatch.conflict`` - module to determine resolution order
"""
if types in self.funcs:
return self.funcs[types][-1]
for f in self.dispatch_iter(*types):
return f
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if (
len(signature) == n
and all(map(issubclass, types, signature))
or len(signature)
and isvariadic(signature[-1])
and variadic_signature_matches(types, signature)
):
yield from self.funcs[signature][::-1]
def __getstate__(self):
return {"name": self.name, "funcs": self.funcs}
def __setstate__(self, d):
self.name = d["name"]
self.funcs = d["funcs"]
self._ordering = ordering(self.funcs)
self._cache = dict()
@property
def __doc__(self):
docs = ["Multiply dispatched method: %s" % self.name]
if self.doc:
docs.append(self.doc)
other = []
for sig in self.ordering[::-1]:
funcs = self.funcs[sig]
s = "Inputs: <%s>\n" % str_signature(sig)
sep = "-" * len(s) + "\n"
for i, func in enumerate(funcs):
s += sep
if len(funcs) > 1:
s += "[Handler %d]\n\n" % (i + 1)
if i:
s += "\n\n"
if func.__doc__:
s += func.__doc__.strip()
else:
s += repr(func) + "\n"
docs.append(s)
return "\n\n".join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
print(self._help(*args))
def _source(self, *args):
func = self.dispatch(*map(type, args))
if not func:
raise TypeError("No function found")
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))
def source(func):
s = "File: %s\n\n" % inspect.getsourcefile(func)
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__ = ("obj", "cls")
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
"Could not find signature for %s: <%s>"
% (self.name, str_signature(types))
)
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return ", ".join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += "\n\n".join(
[
"@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name
for s in amb
]
)
return text
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
import sys
import typing
from collections import OrderedDict
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
def expand_tuples(L):
"""
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""
Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = OrderedDict()
for key in d:
for val in d[key]:
result[val] = result.get(val, tuple()) + (key,)
return result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d = OrderedDict()
for item in seq:
key = func(item)
if key not in d:
d[key] = list()
d[key].append(item)
return d
def typename(type):
"""
Get the name of `type`.
Parameters
----------
type : Union[Type, Tuple[Type]]
Returns
-------
str
The name of `type` or a tuple of the names of the types in `type`.
Examples
--------
>>> typename(int)
'int'
>>> typename((int, float))
'(int, float)'
"""
try:
return type.__name__
except AttributeError:
if len(type) == 1:
return typename(*type)
return "(%s)" % ", ".join(map(typename, type))
# parse typing.Union
def parse_union(ann):
if hasattr(typing, "UnionMeta"):
if type(ann) is not typing.UnionMeta:
return
return ann.__union_params__
elif hasattr(typing, "_Union"):
if type(ann) is not typing._Union:
return
return ann.__args__
elif hasattr(typing, "_GenericAlias"):
if type(ann) is not typing._GenericAlias:
if type(ann) is not typing.Union:
return
else:
if ann.__origin__ is not typing.Union:
return
return ann.__args__
elif hasattr(typing, "Union"):
if typing.get_origin(ann) is not typing.Union:
return
return typing.get_args(ann)
else:
raise NotImplementedError("unsupported Python version")
# Copyright (c) 2014 Matthew Rocklin
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# a. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# b. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# c. Neither the name of multipledispatch nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
#
# --------------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# --------------------------------------------------------------------------------------
from .utils import typename
class VariadicSignatureType(type):
# checking if subclass is a subclass of self
def __subclasscheck__(self, subclass):
other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
return subclass is self or all(
issubclass(other, self.variadic_type) for other in other_type
)
def __eq__(self, other):
"""
Return True if other has the same variadic type
Parameters
----------
other : object (type)
The object (type) to check
Returns
-------
bool
Whether or not `other` is equal to `self`
"""
return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type)
def __hash__(self):
return hash((type(self), frozenset(self.variadic_type)))
def isvariadic(obj):
"""
Check whether the type `obj` is variadic.
Parameters
----------
obj : type
The type to check
Returns
-------
bool
Whether or not `obj` is variadic
Examples
--------
>>> isvariadic(int)
False
>>> isvariadic(Variadic[int])
True
"""
return isinstance(obj, VariadicSignatureType)
class VariadicSignatureMeta(type):
"""
A metaclass that overrides ``__getitem__`` on the class. This is used to
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def __getitem__(self, variadic_type):
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
raise ValueError(
"Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]"
)
if not isinstance(variadic_type, tuple):
variadic_type = (variadic_type,)
return VariadicSignatureType(
"Variadic[%s]" % typename(variadic_type),
(),
dict(variadic_type=variadic_type, __slots__=()),
)
class Variadic(metaclass=VariadicSignatureMeta):
"""
A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
--------
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
>>> issubclass(int, Variadic[int])
True
>>> issubclass(int, Variadic[(int, str)])
True
>>> issubclass(str, Variadic[(int, str)])
True
>>> issubclass(float, Variadic[(int, str)])
False
"""
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import numpy as np
from ..._imperative_rt import CompNode, DeviceTensorND
from ..._imperative_rt.imperative import (
_drop,
_get_dev_tensor,
_swap_in,
_swap_out,
apply_op,
delete,
get_device,
get_dtype,
get_shape,
get_value,
put,
)
from ..._wrap import device as as_device
from ...ops.builtin import Copy, OpDef, TypeCvt
from ...ops.special import Const
from ..core import OpBase, TensorBase, apply
class RawTensor(TensorBase):
_init_cb = None
_del_cb = None
_handle = None
def __init__(self, handle=None, isscalar=False):
self._handle = handle
self._isscalar = isscalar
if handle is not None:
if self._init_cb:
self._init_cb()
@property
def dtype(self):
return get_dtype(self._handle)
@property
def device(self):
return as_device(get_device(self._handle))
@property
def shape(self):
if self._isscalar:
return ()
return get_shape(self._handle)
def numpy(self):
ret = get_value(self._handle)
if self._isscalar:
ret = ret.squeeze()
return ret
def _dev_tensor(self):
return _get_dev_tensor(self._handle)
def _drop(self):
_drop(self._handle)
def _swap_in(self):
_swap_in(self._handle)
def _swap_out(self):
_swap_out(self._handle)
def __repr__(self):
return "{}({}, device='{}')".format(
type(self).__qualname__, repr(self.numpy()), self.device
)
def __del__(self):
if self._handle is not None:
if self._del_cb:
self._del_cb()
delete(self._handle)
@apply.register()
def _(op: OpDef, *args: RawTensor):
outputs = apply_op(op, tuple(i._handle for i in args))
return tuple(map(RawTensor, outputs))
@apply.register()
def _(op: Const, *args: RawTensor):
dtype = op.dtype
device = as_device(op.device).to_c()
return (as_raw_tensor(op.value, dtype=dtype, device=device),)
@functools.singledispatch
def as_raw_tensor(obj, dtype=None, device=None):
obj = np.asarray(obj, dtype=dtype)
if obj.dtype == np.float64:
obj = obj.astype(np.float32)
if obj.dtype == np.int64:
obj = obj.astype(np.int32)
return as_raw_tensor(obj, device=device)
@as_raw_tensor.register(DeviceTensorND)
def _(data: DeviceTensorND):
return RawTensor(put(data))
@as_raw_tensor.register(np.ndarray)
def _(array: np.ndarray, dtype=None, device=None):
device = None if device is None else as_device(device).to_c()
if 0 in array.strides:
array = array.squeeze().reshape(array.shape)
return RawTensor(put(array, dtype=dtype, device=device), isscalar=(array.ndim == 0))
@as_raw_tensor.register(RawTensor)
def _(tensor: RawTensor, dtype=None, device=None):
if dtype is not None:
dtype = np.dtype(dtype)
if dtype != tensor.dtype:
(tensor,) = apply(TypeCvt(dtype=dtype), tensor)
if device is not None:
device = as_device(device)
if device != tensor.device:
(tensor,) = apply(Copy(comp_node=device.to_c()), tensor)
return tensor
......@@ -9,14 +9,7 @@
from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
Tracer,
check_backward_allow_noinput,
get_grad_managers,
get_op_has_grad_fn,
tracer_apply,
)
from ..core.autodiff.grad import get_grad_managers
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..device import get_default_device
from ..tensor import Tensor
......@@ -236,7 +229,7 @@ def remote_recv(
device = get_default_device()
# dummy input
if inp == None:
inp = tensor([0], device=device)
inp = Tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
......
......@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
outputs = apply(op, inp)
for s, x in zip(shapes, outputs):
if not s:
x._isscalar = True
x.setscalar()
return outputs
......
......@@ -10,7 +10,7 @@
from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor, apply
from ..core._imperative_rt.core2 import apply
from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm
......
......@@ -12,10 +12,10 @@ from typing import Dict
import numpy as np
from .. import functional as F
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function
from ..core.ops import builtin
from ..core.tensor import megbrain_graph
from ..core.tensor.core import apply
from ..core.tensor.dtype import _metadata_dict
from ..tensor import Tensor
......
......@@ -3,7 +3,7 @@ import sys
import pytest
from megengine.core._imperative_rt.imperative import sync
from megengine.core._imperative_rt.core2 import sync
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
......
......@@ -4,7 +4,6 @@ import megengine as mge
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.core.tensor.raw_tensor import RawTensor
from megengine.module import Module
......
......@@ -13,7 +13,6 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.dtype import (
_metadata_dict,
convert_from_qint4,
......
from megengine.core.tensor.multipledispatch import Dispatcher
def test_register_many():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return log[-1]
assert f(0) == "b"
assert log == ["b"]
def test_return_not_implemented():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return NotImplemented
assert f(0) == "a"
assert log == ["b", "a"]
def test_super():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return f.super(x)
assert f(0) == "a"
assert log == ["b", "a"]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册