提交 5474538f 编写于 作者: M Megvii Engine Team

refactor(mge/imperative): fork multipledispatch

GitOrigin-RevId: a7c25a4302badc71d3f885d3b114e256ee779e49
上级 ad9ac521
......@@ -209,6 +209,44 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
*********************************************************************************************************************************
multipledispatch
--------------------------------------------------------------------
Copyright (c) 2014 Matthew Rocklin
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
a. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
b. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
c. Neither the name of multipledispatch nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*********************************************************************************************************************************
*********************************************************************************************************************************
Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein:
--------------------------------------------------------------------
......
......@@ -343,7 +343,7 @@ def default_has_grad_fn(opnode, reached):
return False
@apply.add
@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
args = tuple(i if isinstance(i, Tracer) else None for i in args)
input_requires_grad = list(map(bool, args))
......@@ -385,6 +385,6 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
return tuple(outputs)
@apply.add
@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
return None
......@@ -19,7 +19,7 @@ from .._internal.helper import PodOpVisitor
OpBase.register(OpDef)
# forward to apply(OpDef, ...)
@apply.add
@apply.register()
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]):
return apply(op.to_c(), *args)
......
......@@ -13,7 +13,7 @@ import sys
import typing
from abc import ABC
import multipledispatch
from .multipledispatch import Dispatcher
class OpBase(ABC):
......@@ -29,84 +29,17 @@ class TensorWrapperBase:
pass
class Dispatcher(multipledispatch.Dispatcher):
def add(self, f, g=None):
if g is None:
super().add(get_signature(f), f)
else:
super().add(f, g)
return f
def __get__(self, instance, owner=None):
if instance is not None:
return self
return functools.partial(self, instance)
if sys.version_info < (3, 6):
def parse_union(ann):
if type(ann) is not typing.UnionMeta:
return
return ann.__union_params__
elif sys.version_info < (3, 7):
def parse_union(ann):
if type(ann) is not typing._Union:
return
return ann.__args__
elif sys.version_info < (3, 8):
def parse_union(ann):
if type(ann) is not typing._GenericAlias:
if type(ann) is not typing.Union:
return
else:
if ann.__origin__ is not typing.Union:
return
return ann.__args__
else:
def parse_union(ann):
if typing.get_origin(ann) is not typing.Union:
return
return typing.get_args(ann)
def get_signature(function, op_type=None):
sig = inspect.signature(function)
types = []
for p in sig.parameters.values():
ann = p.annotation
ann = parse_union(ann) or ann
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
types.append(ann)
if p.kind == inspect.Parameter.VAR_POSITIONAL:
types.append([ann])
return tuple(types)
apply = Dispatcher("apply")
OpBase.apply = apply
@apply.add
@apply.register()
def _(op: OpBase, *args: TensorBase):
raise NotImplementedError
@apply.add
@apply.register()
def _(op: OpBase, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])
......
......@@ -102,7 +102,7 @@ class Function:
Function.apply = Function.__call__
@apply.add
@apply.register()
def _(op: Function, *args: TensorWrapperBase):
assert args
Wrapper = type(args[0])
......@@ -148,11 +148,11 @@ def _(op: Function, *args: TensorWrapperBase):
return tuple(map(Wrapper, outputs))
@apply.add
@apply.register()
def _(op: Function, *args: Tensor):
raise NotImplementedError
@apply.add
@apply.register()
def _(op: Function, *args: RawTensor):
raise NotImplementedError
......@@ -111,7 +111,7 @@ def _unwrap(x):
return x._node
@apply.add
@apply.register()
def _(op: OpDef, *args: VarNode):
outputs = _imperative_rt.invoke_op(op, _unwrap(args))
return _wrap(outputs)
......
# This directory is a fork of multipledispatch.
#
# Repo: https://github.com/mrocklin/multipledispatch
# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834
# Authors: Matthew Rocklin et al.
#
# Refer to ACKNOWLEDGEMENT for copyright and liscense information
from .core import dispatch
from .dispatcher import Dispatcher
from .utils import _toposort, groupby
from .variadic import isvariadic
class AmbiguityWarning(Warning):
pass
def supercedes(a, b):
""" A is consistent and strictly more specific than B """
if len(a) < len(b):
# only case is if a is empty and b is variadic
return not a and len(b) == 1 and isvariadic(b[-1])
elif len(a) == len(b):
return all(map(issubclass, a, b))
else:
# len(a) > len(b)
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not (isvariadic(cur_a) or isvariadic(cur_b)):
if not issubclass(cur_a, cur_b):
return False
p1 += 1
p2 += 1
elif isvariadic(cur_a):
assert p1 == len(a) - 1
return p2 == len(b) - 1 and issubclass(cur_a, cur_b)
elif isvariadic(cur_b):
assert p2 == len(b) - 1
if not issubclass(cur_a, cur_b):
return False
p1 += 1
return p2 == len(b) - 1 and p1 == len(a)
def consistent(a, b):
""" It is possible for an argument list to satisfy both A and B """
# Need to check for empty args
if not a:
return not b or isvariadic(b[0])
if not b:
return not a or isvariadic(a[0])
# Non-empty args check for mutual subclasses
if len(a) == len(b):
return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b))
else:
p1 = 0
p2 = 0
while p1 < len(a) and p2 < len(b):
cur_a = a[p1]
cur_b = b[p2]
if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b):
return False
if not (isvariadic(cur_a) or isvariadic(cur_b)):
p1 += 1
p2 += 1
elif isvariadic(cur_a):
p2 += 1
elif isvariadic(cur_b):
p1 += 1
# We only need to check for variadic ends
# Variadic types are guaranteed to be the last element
return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a)
def ambiguous(a, b):
""" A is consistent with B but neither is strictly more specific """
return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a))
def ambiguities(signatures):
""" All signature pairs such that A is ambiguous with B """
signatures = list(map(tuple, signatures))
return set(
(a, b)
for a in signatures
for b in signatures
if hash(a) < hash(b)
and ambiguous(a, b)
and not any(supercedes(c, a) and supercedes(c, b) for c in signatures)
)
def super_signature(signatures):
""" A signature that would break ambiguities """
n = len(signatures[0])
assert all(len(s) == n for s in signatures)
return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)]
def edge(a, b, tie_breaker=hash):
""" A should be checked before B
Tie broken by tie_breaker, defaults to ``hash``
"""
# A either supercedes B and B does not supercede A or if B does then call
# tie_breaker
return supercedes(a, b) and (
not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)
)
def ordering(signatures):
""" A sane ordering of signatures to check, first to last
Topoological sort of edges as given by ``edge`` and ``supercedes``
"""
signatures = list(map(tuple, signatures))
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
edges = groupby(lambda x: x[0], edges)
for s in signatures:
if s not in edges:
edges[s] = []
edges = dict((k, [b for a, b in v]) for k, v in edges.items())
return _toposort(edges)
import inspect
import sys
from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn
global_namespace = dict()
def dispatch(*types, **kwargs):
""" Dispatch function on the types of the inputs
Supports dispatch on all non-keyword arguments.
Collects implementations based on the function name. Ignores namespaces.
If ambiguous type signatures occur a warning is raised when the function is
defined suggesting the additional method to break the ambiguity.
Examples
--------
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
Specify an isolated namespace with the namespace keyword argument
>>> my_namespace = dict()
>>> @dispatch(int, namespace=my_namespace)
... def foo(x):
... return x + 1
Dispatch on instance methods within classes
>>> class MyClass(object):
... @dispatch(list)
... def __init__(self, data):
... self.data = data
... @dispatch(int)
... def __init__(self, datum):
... self.data = [datum]
"""
namespace = kwargs.get("namespace", global_namespace)
types = tuple(types)
def _df(func):
name = func.__name__
if ismethod(func):
dispatcher = inspect.currentframe().f_back.f_locals.get(
name, MethodDispatcher(name),
)
else:
if name not in namespace:
namespace[name] = Dispatcher(name)
dispatcher = namespace[name]
dispatcher.add(types, func)
return dispatcher
return _df
def ismethod(func):
""" Is func a method?
Note that this has to work as the method is defined but before the class is
defined. At this stage methods look like functions.
"""
if hasattr(inspect, "signature"):
signature = inspect.signature(func)
return signature.parameters.get("self", None) is not None
else:
if sys.version_info.major < 3:
spec = inspect.getargspec(func)
else:
spec = inspect.getfullargspec(func)
return spec and spec.args and spec.args[0] == "self"
import copy
import inspect
import itertools as itl
from warnings import warn
from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher
from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature
from .utils import expand_tuples, parse_union
from .variadic import Variadic, isvariadic
def ambiguity_warn(dispatcher, ambiguities):
""" Raise warning when ambiguity is detected
Parameters
----------
dispatcher : Dispatcher
The dispatcher on which the ambiguity was detected
ambiguities : set
Set of type signature pairs that are ambiguous within this dispatcher
See Also:
Dispatcher.add
warning_text
"""
warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
def variadic_signature_matches_iter(types, full_signature):
"""Check if a set of input types matches a variadic signature.
Notes
-----
The algorithm is as follows:
Initialize the current signature to the first in the sequence
For each type in `types`:
If the current signature is variadic
If the type matches the signature
yield True
Else
Try to get the next signature
If no signatures are left we can't possibly have a match
so yield False
Else
yield True if the type matches the current signature
Get the next signature
"""
sigiter = iter(full_signature)
sig = next(sigiter)
for typ in types:
matches = issubclass(typ, sig)
yield matches
if not isvariadic(sig):
# we're not matching a variadic argument, so move to the next
# element in the signature
sig = next(sigiter)
else:
try:
sig = next(sigiter)
except StopIteration:
assert isvariadic(sig)
yield True
else:
# We have signature items left over, so all of our arguments
# haven't matched
yield False
def variadic_signature_matches(types, full_signature):
# No arguments always matches a variadic signature
assert full_signature
return all(variadic_signature_matches_iter(types, full_signature))
def get_func_signature(function):
sig = inspect.signature(function)
types = []
for p in sig.parameters.values():
ann = p.annotation
ann = parse_union(ann) or ann
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
types.append(ann)
if p.kind == inspect.Parameter.VAR_POSITIONAL:
types.append([ann])
return tuple(types)
class Frame:
__slots__ = "args", "types", "mro", "mro_offset"
class Dispatcher(CDispatcher):
""" Dispatch methods based on type signature
Use ``dispatch`` to add implementations
Examples
--------
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def f(x):
... return x + 1
>>> @dispatch(float)
... def f(x):
... return x - 1
>>> f(3)
4
>>> f(3.0)
2.0
"""
__slots__ = "__name__", "name", "funcs", "_ordering", "doc"
def __init__(self, name, doc=None):
self.name = self.__name__ = name
self.funcs = {}
self.doc = doc
def register(self, *types, **kwargs):
""" register dispatcher with new implementation
>>> f = Dispatcher('f')
>>> @f.register(int)
... def inc(x):
... return x + 1
>>> @f.register(float)
... def dec(x):
... return x - 1
>>> @f.register(list)
... @f.register(tuple)
... def reverse(x):
... return x[::-1]
>>> f(1)
2
>>> f(1.0)
0.0
>>> f([1, 2, 3])
[3, 2, 1]
"""
def _df(func):
self.add(types, func, **kwargs)
return func
return _df
def add(self, signature, func):
""" Add new types/method pair to dispatcher
>>> D = Dispatcher('add')
>>> D.add((int, int), lambda x, y: x + y)
>>> D.add((float, float), lambda x, y: x + y)
>>> D(1, 2)
3
>>> D(1, 2.0)
Traceback (most recent call last):
...
NotImplementedError: Could not find signature for add: <int, float>
When ``add`` detects a warning it calls the ``on_ambiguity`` callback
with a dispatcher/itself, and a set of ambiguous type signature pairs
as inputs. See ``ambiguity_warn`` for an example.
"""
# Handle annotations
if not signature:
signature = get_func_signature(func)
# Handle union types
if any(isinstance(typ, tuple) for typ in signature):
for typs in expand_tuples(signature):
self.add(typs, func)
return
new_signature = []
for index, typ in enumerate(signature, start=1):
if not isinstance(typ, (type, list)):
str_sig = ", ".join(
c.__name__ if isinstance(c, type) else str(c) for c in signature
)
raise TypeError(
"Tried to dispatch on non-type: %s\n"
"In signature: <%s>\n"
"In function: %s" % (typ, str_sig, self.name)
)
# handle variadic signatures
if isinstance(typ, list):
if index != len(signature):
raise TypeError("Variadic signature must be the last element")
if len(typ) != 1:
raise TypeError(
"Variadic signature must contain exactly one element. "
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
l = self.funcs.setdefault(tuple(new_signature), [])
for i in l:
if i is func:
raise ValueError("already registered")
l.append(func)
self.enable(func)
self.clear_cache()
try:
del self._ordering
except AttributeError:
pass
@property
def ordering(self):
try:
return self._ordering
except AttributeError:
return self.reorder()
def reorder(self, on_ambiguity=ambiguity_warn):
self._ordering = od = ordering(self.funcs)
amb = ambiguities(self.funcs)
if amb:
on_ambiguity(self, amb)
return od
def __str__(self):
return "<dispatched %s>" % self.name
__repr__ = __str__
def dispatch(self, *types):
"""Deterimine appropriate implementation for this type signature
This method is internal. Users should call this object as a function.
Implementation resolution occurs within the ``__call__`` method.
>>> from multipledispatch import dispatch
>>> @dispatch(int)
... def inc(x):
... return x + 1
>>> implementation = inc.dispatch(int)
>>> implementation(3)
4
>>> print(inc.dispatch(float))
None
See Also:
``multipledispatch.conflict`` - module to determine resolution order
"""
if types in self.funcs:
return self.funcs[types][-1]
for f in self.dispatch_iter(*types):
return f
def dispatch_iter(self, *types):
n = len(types)
for signature in self.ordering:
if (
len(signature) == n
and all(map(issubclass, types, signature))
or len(signature)
and isvariadic(signature[-1])
and variadic_signature_matches(types, signature)
):
yield from self.funcs[signature][::-1]
def __getstate__(self):
return {"name": self.name, "funcs": self.funcs}
def __setstate__(self, d):
self.name = d["name"]
self.funcs = d["funcs"]
self._ordering = ordering(self.funcs)
self._cache = dict()
@property
def __doc__(self):
docs = ["Multiply dispatched method: %s" % self.name]
if self.doc:
docs.append(self.doc)
other = []
for sig in self.ordering[::-1]:
funcs = self.funcs[sig]
s = "Inputs: <%s>\n" % str_signature(sig)
sep = "-" * len(s) + "\n"
for i, func in enumerate(funcs):
s += sep
if len(funcs) > 1:
s += "[Handler %d]\n\n" % (i + 1)
if i:
s += "\n\n"
if func.__doc__:
s += func.__doc__.strip()
else:
s += repr(func) + "\n"
docs.append(s)
return "\n\n".join(docs)
def _help(self, *args):
return self.dispatch(*map(type, args)).__doc__
def help(self, *args, **kwargs):
""" Print docstring for the function corresponding to inputs """
print(self._help(*args))
def _source(self, *args):
func = self.dispatch(*map(type, args))
if not func:
raise TypeError("No function found")
return source(func)
def source(self, *args, **kwargs):
""" Print source code for the function corresponding to inputs """
print(self._source(*args))
def source(func):
s = "File: %s\n\n" % inspect.getsourcefile(func)
s = s + inspect.getsource(func)
return s
class MethodDispatcher(Dispatcher):
""" Dispatch methods based on type signature
See Also:
Dispatcher
"""
__slots__ = ("obj", "cls")
@classmethod
def get_func_params(cls, func):
if hasattr(inspect, "signature"):
sig = inspect.signature(func)
return itl.islice(sig.parameters.values(), 1, None)
def __get__(self, instance, owner):
self.obj = instance
self.cls = owner
return self
def __call__(self, *args, **kwargs):
types = tuple([type(arg) for arg in args])
func = self.dispatch(*types)
if not func:
raise NotImplementedError(
"Could not find signature for %s: <%s>"
% (self.name, str_signature(types))
)
return func(self.obj, *args, **kwargs)
def str_signature(sig):
""" String representation of type signature
>>> str_signature((int, float))
'int, float'
"""
return ", ".join(cls.__name__ for cls in sig)
def warning_text(name, amb):
""" The text for ambiguity warnings """
text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
text += "The following signatures may result in ambiguous behavior:\n"
for pair in amb:
text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n"
text += "\n\nConsider making the following additions:\n\n"
text += "\n\n".join(
[
"@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name
for s in amb
]
)
return text
import sys
import typing
from collections import OrderedDict
def raises(err, lamda):
try:
lamda()
return False
except err:
return True
def expand_tuples(L):
"""
>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]
>>> expand_tuples([1, 2])
[(1, 2)]
"""
if not L:
return [()]
elif not isinstance(L[0], tuple):
rest = expand_tuples(L[1:])
return [(L[0],) + t for t in rest]
else:
rest = expand_tuples(L[1:])
return [(item,) + t for t in rest for item in L[0]]
# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
>>> _toposort({1: (2, 3), 2: (3, )})
[1, 2, 3]
Closely follows the wikipedia page [2]
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
"""
incoming_edges = reverse_dict(edges)
incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items())
S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
L = []
while S:
n, _ = S.popitem()
L.append(n)
for m in edges.get(n, ()):
assert n in incoming_edges[m]
incoming_edges[m].remove(n)
if not incoming_edges[m]:
S[m] = None
if any(incoming_edges.get(v, None) for v in edges):
raise ValueError("Input has cycles")
return L
def reverse_dict(d):
"""Reverses direction of dependence dict
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d) # doctest: +SKIP
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = OrderedDict()
for key in d:
for val in d[key]:
result[val] = result.get(val, tuple()) + (key,)
return result
# Taken from toolz
# Avoids licensing issues because this version was authored by Matthew Rocklin
def groupby(func, seq):
""" Group a collection by a key function
>>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
>>> groupby(len, names) # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
See Also:
``countby``
"""
d = OrderedDict()
for item in seq:
key = func(item)
if key not in d:
d[key] = list()
d[key].append(item)
return d
def typename(type):
"""Get the name of `type`.
Parameters
----------
type : Union[Type, Tuple[Type]]
Returns
-------
str
The name of `type` or a tuple of the names of the types in `type`.
Examples
--------
>>> typename(int)
'int'
>>> typename((int, float))
'(int, float)'
"""
try:
return type.__name__
except AttributeError:
if len(type) == 1:
return typename(*type)
return "(%s)" % ", ".join(map(typename, type))
# parse typing.Union
if sys.version_info < (3, 6):
def parse_union(ann):
if type(ann) is not typing.UnionMeta:
return
return ann.__union_params__
elif sys.version_info < (3, 7):
def parse_union(ann):
if type(ann) is not typing._Union:
return
return ann.__args__
elif sys.version_info < (3, 8):
def parse_union(ann):
if type(ann) is not typing._GenericAlias:
if type(ann) is not typing.Union:
return
else:
if ann.__origin__ is not typing.Union:
return
return ann.__args__
else:
def parse_union(ann):
if typing.get_origin(ann) is not typing.Union:
return
return typing.get_args(ann)
from .utils import typename
class VariadicSignatureType(type):
# checking if subclass is a subclass of self
def __subclasscheck__(self, subclass):
other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,)
return subclass is self or all(
issubclass(other, self.variadic_type) for other in other_type
)
def __eq__(self, other):
"""
Return True if other has the same variadic type
Parameters
----------
other : object (type)
The object (type) to check
Returns
-------
bool
Whether or not `other` is equal to `self`
"""
return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type)
def __hash__(self):
return hash((type(self), frozenset(self.variadic_type)))
def isvariadic(obj):
"""Check whether the type `obj` is variadic.
Parameters
----------
obj : type
The type to check
Returns
-------
bool
Whether or not `obj` is variadic
Examples
--------
>>> isvariadic(int)
False
>>> isvariadic(Variadic[int])
True
"""
return isinstance(obj, VariadicSignatureType)
class VariadicSignatureMeta(type):
"""A metaclass that overrides ``__getitem__`` on the class. This is used to
generate a new type for Variadic signatures. See the Variadic class for
examples of how this behaves.
"""
def __getitem__(self, variadic_type):
if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)):
raise ValueError(
"Variadic types must be type or tuple of types"
" (Variadic[int] or Variadic[(int, float)]"
)
if not isinstance(variadic_type, tuple):
variadic_type = (variadic_type,)
return VariadicSignatureType(
"Variadic[%s]" % typename(variadic_type),
(),
dict(variadic_type=variadic_type, __slots__=()),
)
class Variadic(metaclass=VariadicSignatureMeta):
"""A class whose getitem method can be used to generate a new type
representing a specific variadic signature.
Examples
--------
>>> Variadic[int] # any number of int arguments
<class 'multipledispatch.variadic.Variadic[int]'>
>>> Variadic[(int, str)] # any number of one of int or str arguments
<class 'multipledispatch.variadic.Variadic[(int, str)]'>
>>> issubclass(int, Variadic[int])
True
>>> issubclass(int, Variadic[(int, str)])
True
>>> issubclass(str, Variadic[(int, str)])
True
>>> issubclass(float, Variadic[(int, str)])
False
"""
......@@ -66,13 +66,13 @@ class RawTensor(TensorBase):
delete(self._handle)
@apply.add
@apply.register()
def _(op: OpDef, *args: RawTensor):
outputs = apply_op(op, tuple(i._handle for i in args))
return tuple(map(RawTensor, outputs))
@apply.add
@apply.register()
def _(op: Const, *args: RawTensor):
dtype = op.dtype
device = as_device(op.device).to_c()
......
......@@ -79,7 +79,7 @@ def get_context():
return _context
@apply.add
@apply.register()
def tensor_apply(op: OpBase, *args: Tensor):
data = tuple(i._data if isinstance(i, Tensor) else i for i in args)
# type(Tensor._data) is RawTensor
......
......@@ -46,7 +46,7 @@ __all__ = [
]
@apply.add
@apply.register()
def _(op: RemoteSend, *args: Tensor):
ret = tensor_apply(op, *args)
......
numpy>=1.18
multipledispatch==0.6.0
opencv-python
pyarrow
requests
......
#include "./dispatcher.h"
#include "./pyext17.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h"
#include <unordered_map>
#include <structmember.h>
namespace py = pybind11;
namespace pyx = pyext17;
namespace {
struct Handler {
PyObject* func; // borrowed
bool enabled;
Handler() = default;
Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
};
using FastSig = mgb::SmallVector<void*, 8>;
using MRO = std::vector<Handler*>;
struct Frame {
MRO* mro;
size_t mro_offset;
Frame() = default;
Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
};
struct FastSigHash {
size_t operator()(const FastSig& sig) const {
auto* ptr = &sig.front();
return mgb::XXHash()
.update(ptr, sig.size() * sizeof(FastSig::value_type))
.digest();
}
};
struct ObjectIdHash : std::hash<void*> {
size_t operator()(const py::handle& h) const {
return std::hash<void*>::operator()(h.ptr());
}
};
struct Dispatcher {
std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
std::vector<Frame> stack;
std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;
inline py::handle self() {
return pyx::wrap<Dispatcher>::pycast(this);
}
bool prepare_call(PyObject*const* args, Py_ssize_t nargs) {
FastSig sig(nargs);
for (Py_ssize_t i = 0; i < nargs; ++i) {
sig[i] = Py_TYPE(args[i]);
}
auto it = cache.find(sig);
if (it == cache.end()) {
if (auto mro = resolve(sig)) {
it = cache.emplace(std::move(sig), std::move(mro)).first;
} else {
return false;
}
}
stack.emplace_back(it->second.get());
return true;
}
template<typename T>
PyObject* do_call(T&& caller) {
auto& frame = stack.back();
auto& mro = *frame.mro;
auto& i = frame.mro_offset;
for (; i < mro.size(); ++i) {
if (mro[i]->enabled) {
auto ret = caller(mro[i]->func);
if (ret != Py_NotImplemented) {
stack.pop_back();
return ret;
}
Py_DECREF(ret);
}
}
PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
stack.pop_back();
return nullptr;
}
std::unique_ptr<MRO> resolve(const FastSig& sig) {
try {
py::tuple args(sig.size());
for (size_t i = 0; i < sig.size(); ++i) {
args[i] = (PyObject*)sig[i];
}
auto mro_iter = self().attr("dispatch_iter")(*args);
auto ret = std::make_unique<MRO>();
for (auto i : mro_iter) {
auto it = registry.find(py::reinterpret_borrow<py::object>(i));
if (it == registry.end()) {
PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
return nullptr;
}
ret->push_back(it->second.get());
}
return ret;
} catch (py::error_already_set& e) {
e.restore();
} catch (std::runtime_error& e) {
PyErr_SetString(PyExc_RuntimeError, e.what());
}
return nullptr;
}
public:
static constexpr auto tp_name = "Dispatcher";
PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) {
if (!prepare_call(args, nargs)) return nullptr;
return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
}
PyObject* tp_call(PyObject* args, PyObject* kwargs) {
if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);});
}
PyObject* super(PyObject*const* args, Py_ssize_t nargs) {
if (stack.empty()) {
PyErr_SetString(PyExc_RuntimeError, "super called at top level");
return nullptr;
}
stack.emplace_back(stack.back()).mro_offset++;
return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
}
void enable(PyObject* func) {
auto obj = py::reinterpret_borrow<py::object>(func);
auto it = registry.find(obj);
if (it != registry.end()) {
it->second->enabled = true;
} else {
registry.emplace(std::move(obj), std::make_unique<Handler>(func));
}
}
PyObject* disable(PyObject* func) {
auto obj = py::reinterpret_borrow<py::object>(func);
auto it = registry.find(obj);
if (it == registry.end()) {
PyErr_SetString(PyExc_ValueError, "function not registered");
return nullptr;
} else {
it->second->enabled = false;
}
Py_RETURN_NONE;
}
void clear_cache() {
cache.clear();
}
};
} // namespace
void init_dispatcher(py::module m) {
auto* dispatcher_type = pyx::wrap<Dispatcher>::type()
.def<&Dispatcher::enable>("enable")
.def<&Dispatcher::disable>("disable")
.def<&Dispatcher::clear_cache>("clear_cache")
.def<&Dispatcher::tp_vectorcall>("call")
.def<&Dispatcher::super>("super")
.finalize();
if (!dispatcher_type) throw py::error_already_set();
m.attr("Dispatcher") = dispatcher_type;
}
#pragma once
#include <pybind11/pybind11.h>
void init_dispatcher(pybind11::module);
......@@ -21,6 +21,8 @@
#include "./graph_rt.h"
#include "./ops.h"
#include "./dispatcher.h"
namespace py = pybind11;
#ifndef MODULE_NAME
......@@ -63,4 +65,6 @@ PYBIND11_MODULE(MODULE_NAME, m) {
from .graph import *
)",
py::getattr(m, "__dict__"));
init_dispatcher(submodule(m, "dispatcher"));
}
#pragma once
#include <stdexcept>
#include <vector>
#include <utility>
#include <Python.h>
namespace pyext17 {
#ifdef METH_FASTCALL
constexpr bool has_fastcall = true;
#else
constexpr bool has_fastcall = false;
#endif
template<typename... Args>
struct invocable_with {
template<typename T>
constexpr bool operator()(T&& lmb) {
return std::is_invocable_v<T, Args...>;
}
};
#define HAS_MEMBER_TYPE(T, U) invocable_with<T>{}([](auto&& x) -> typename std::decay_t<decltype(x)>::U {})
#define HAS_MEMBER(T, m) invocable_with<T>{}([](auto&& x) -> decltype(&std::decay_t<decltype(x)>::m) {})
inline PyObject* cvt_retval(PyObject* rv) {
return rv;
}
#define CVT_RET_PYOBJ(...) \
if constexpr (std::is_same_v<decltype(__VA_ARGS__), void>) { \
__VA_ARGS__; \
Py_RETURN_NONE; \
} else { \
return cvt_retval(__VA_ARGS__); \
}
template <typename T>
struct wrap {
private:
typedef wrap<T> wrap_t;
public:
PyObject_HEAD
std::aligned_storage_t<sizeof(T), alignof(T)> storage;
inline T* inst() {
return reinterpret_cast<T*>(&storage);
}
inline static PyObject* pycast(T* ptr) {
return (PyObject*)((char*)ptr - offsetof(wrap_t, storage));
}
private:
// method wrapper
enum struct meth_type {
noarg,
varkw,
fastcall,
singarg
};
template<auto f>
struct detect_meth_type {
static constexpr meth_type value = []() {
using F = decltype(f);
static_assert(std::is_member_function_pointer_v<F>);
if constexpr (std::is_invocable_v<F, T>) {
return meth_type::noarg;
} else if constexpr (std::is_invocable_v<F, T, PyObject*, PyObject*>) {
return meth_type::varkw;
} else if constexpr (std::is_invocable_v<F, T, PyObject*const*, Py_ssize_t>) {
return meth_type::fastcall;
} else if constexpr (std::is_invocable_v<F, T, PyObject*>) {
return meth_type::singarg;
} else {
static_assert(!std::is_same_v<F, F>);
}
}();
};
template<meth_type, auto f>
struct meth {};
template<auto f>
struct meth<meth_type::noarg, f> {
static constexpr int flags = METH_NOARGS;
static PyObject* impl(PyObject* self, PyObject*) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)());
}
};
template<auto f>
struct meth<meth_type::varkw, f> {
static constexpr int flags = METH_VARARGS | METH_KEYWORDS;
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, kwargs));
}
};
template<auto f>
struct meth<meth_type::fastcall, f> {
#ifdef METH_FASTCALL
static constexpr int flags = METH_FASTCALL;
static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(args, nargs));
}
#else
static constexpr int flags = METH_VARARGS;
static PyObject* impl(PyObject* self, PyObject* args) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
auto* arr = &PyTuple_GET_ITEM(args, 0);
auto size = PyTuple_GET_SIZE(args);
CVT_RET_PYOBJ((inst->*f)(arr, size));
}
#endif
};
template<auto f>
struct meth<meth_type::singarg, f> {
static constexpr int flags = METH_O;
static PyObject* impl(PyObject* self, PyObject* obj) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ((inst->*f)(obj));
}
};
template<auto f>
static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) {
using M = meth<detect_meth_type<f>::value, f>;
return {name, (PyCFunction)M::impl, M::flags, doc};
}
// polyfills
struct tp_new {
static constexpr bool provided = HAS_MEMBER(T, tp_new);
static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
static constexpr bool noarg = std::is_default_constructible_v<T>;
template<typename = void>
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
auto* self = type->tp_alloc(type, 0);
auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
if constexpr (varkw) {
new(ptr) T(args, kwargs);
} else {
new(ptr) T();
}
return self;
}
static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new;
else if constexpr (varkw || noarg) return impl<>;
else return nullptr;}();
};
struct tp_dealloc {
static constexpr bool provided = HAS_MEMBER(T, tp_dealloc);
template<typename = void>
static void impl(PyObject* self) {
reinterpret_cast<wrap_t*>(self)->inst()->~T();
Py_TYPE(self)->tp_free(self);
}
static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc;
else return impl<>;}();
};
struct tp_call {
static constexpr bool valid = HAS_MEMBER(T, tp_call);
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
template<typename = void>
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
}
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
else if constexpr (valid) return impl<>;
else return nullptr;}();
};
public:
class TypeBuilder {
std::vector<PyMethodDef> m_methods;
PyTypeObject m_type;
bool m_finalized = false;
bool m_ready = false;
void check_finalized() {
if (m_finalized) {
throw std::runtime_error("type is already finalized");
}
}
public:
TypeBuilder(const TypeBuilder&) = delete;
TypeBuilder& operator=(const TypeBuilder&) = delete;
TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} {
// static_assert(HAS_MEMBER(T, tp_name));
if constexpr (HAS_MEMBER(T, tp_name)) {
m_type.tp_name = T::tp_name;
}
m_type.tp_dealloc = tp_dealloc::value;
m_type.tp_call = tp_call::value;
m_type.tp_basicsize = sizeof(wrap_t);
m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
m_type.tp_new = tp_new::value;
}
PyTypeObject* operator->() {
return &m_type;
}
bool ready() const {
return m_ready;
}
PyObject* finalize() {
if (!m_finalized) {
if (m_methods.size()) {
m_methods.push_back({0});
if (m_type.tp_methods) {
PyErr_SetString(PyExc_SystemError, "tp_method is already set");
return nullptr;
}
m_type.tp_methods = &m_methods[0];
}
if (PyType_Ready(&m_type)) {
return nullptr;
}
m_ready = true;
}
return (PyObject*)&m_type;
}
template<auto f>
TypeBuilder& def(const char* name, const char* doc = nullptr) {
check_finalized();
m_methods.push_back(make_meth_def<f>(name, doc));
return *this;
}
};
static TypeBuilder& type() {
static TypeBuilder type_helper;
return type_helper;
}
};
} // namespace pyext17
#undef HAS_MEMBER_TYPE
#undef HAS_MEMBER
#undef CVT_RET_PYOBJ
from megengine.core.tensor.multipledispatch import Dispatcher
def test_register_many():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return log[-1]
assert f(0) == "b"
assert log == ["b"]
def test_return_not_implemented():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return NotImplemented
assert f(0) == "a"
assert log == ["b", "a"]
def test_super():
f = Dispatcher("f")
log = []
@f.register()
def _(x: int):
log.append("a")
return log[-1]
@f.register()
def _(x: int):
log.append("b")
return f.super(x)
assert f(0) == "a"
assert log == ["b", "a"]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册