diff --git a/ACKNOWLEDGMENTS b/ACKNOWLEDGMENTS index ba55abab503fae984248bf0c74979f3bb0ae9cfc..addbf8fce322eed4410c16f88963539bce53c07c 100644 --- a/ACKNOWLEDGMENTS +++ b/ACKNOWLEDGMENTS @@ -209,6 +209,44 @@ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +********************************************************************************************************************************* +multipledispatch +-------------------------------------------------------------------- +Copyright (c) 2014 Matthew Rocklin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + a. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + b. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + c. Neither the name of multipledispatch nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +DAMAGE. 
+********************************************************************************************************************************* + + + + + + ********************************************************************************************************************************* Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-party Components therein: -------------------------------------------------------------------- diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 8b1b337685adb179f51c6f3e160507e0cbaa0a97..c30e4113f4fb9fe3f5f4870ff5cf2f79f746274c 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -343,7 +343,7 @@ def default_has_grad_fn(opnode, reached): return False -@apply.add +@apply.register() def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): args = tuple(i if isinstance(i, Tracer) else None for i in args) input_requires_grad = list(map(bool, args)) @@ -385,6 +385,6 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]): return tuple(outputs) -@apply.add +@apply.register() def _(op: Const, *_: typing.Optional[Tracer]): return None diff --git a/imperative/python/megengine/core/ops/builtin/__init__.py b/imperative/python/megengine/core/ops/builtin/__init__.py index 4656cbd2890bed2a80e50b54470f1c5f63357267..997c9d2f122e63781f337abbc5af04197b2deb2f 100644 --- a/imperative/python/megengine/core/ops/builtin/__init__.py +++ b/imperative/python/megengine/core/ops/builtin/__init__.py @@ -19,7 +19,7 @@ from .._internal.helper import PodOpVisitor OpBase.register(OpDef) # forward to apply(OpDef, ...) -@apply.add +@apply.register() def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): return apply(op.to_c(), *args) diff --git a/imperative/python/megengine/core/tensor/core.py b/imperative/python/megengine/core/tensor/core.py index 3a09f5246fc7acf68670a68b9fb1d06dafe7feb5..0c1bcee79cafd1b8b6ed8d1b033504ee06cefdef 100644 --- a/imperative/python/megengine/core/tensor/core.py +++ b/imperative/python/megengine/core/tensor/core.py @@ -13,7 +13,7 @@ import sys import typing from abc import ABC -import multipledispatch +from .multipledispatch import Dispatcher class OpBase(ABC): @@ -29,84 +29,17 @@ class TensorWrapperBase: pass -class Dispatcher(multipledispatch.Dispatcher): - def add(self, f, g=None): - if g is None: - super().add(get_signature(f), f) - else: - super().add(f, g) - - return f - - def __get__(self, instance, owner=None): - if instance is not None: - return self - return functools.partial(self, instance) - - -if sys.version_info < (3, 6): - - def parse_union(ann): - if type(ann) is not typing.UnionMeta: - return - return ann.__union_params__ - - -elif sys.version_info < (3, 7): - - def parse_union(ann): - if type(ann) is not typing._Union: - return - return ann.__args__ - - -elif sys.version_info < (3, 8): - - def parse_union(ann): - if type(ann) is not typing._GenericAlias: - if type(ann) is not typing.Union: - return - else: - if ann.__origin__ is not typing.Union: - return - return ann.__args__ - - -else: - - def parse_union(ann): - if typing.get_origin(ann) is not typing.Union: - return - return typing.get_args(ann) - - -def get_signature(function, op_type=None): - sig = inspect.signature(function) - types = [] - for p in sig.parameters.values(): - ann = p.annotation - ann = parse_union(ann) or ann - if p.kind in ( - inspect.Parameter.POSITIONAL_ONLY, - 
inspect.Parameter.POSITIONAL_OR_KEYWORD, - ): - types.append(ann) - if p.kind == inspect.Parameter.VAR_POSITIONAL: - types.append([ann]) - return tuple(types) - - apply = Dispatcher("apply") OpBase.apply = apply -@apply.add +@apply.register() def _(op: OpBase, *args: TensorBase): raise NotImplementedError -@apply.add +@apply.register() def _(op: OpBase, *args: TensorWrapperBase): assert args Wrapper = type(args[0]) diff --git a/imperative/python/megengine/core/tensor/function.py b/imperative/python/megengine/core/tensor/function.py index 9cbb3d56e2ba9fb91a15c64d24ab7873eaffbc05..6b51c6675cc652d646e5b2094e3a3a31445ca151 100644 --- a/imperative/python/megengine/core/tensor/function.py +++ b/imperative/python/megengine/core/tensor/function.py @@ -102,7 +102,7 @@ class Function: Function.apply = Function.__call__ -@apply.add +@apply.register() def _(op: Function, *args: TensorWrapperBase): assert args Wrapper = type(args[0]) @@ -148,11 +148,11 @@ def _(op: Function, *args: TensorWrapperBase): return tuple(map(Wrapper, outputs)) -@apply.add +@apply.register() def _(op: Function, *args: Tensor): raise NotImplementedError -@apply.add +@apply.register() def _(op: Function, *args: RawTensor): raise NotImplementedError diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 86f7bcc11bfff0450cca6b1bd618ae7543908e6d..3ecda5cf0dcf6e5f07199c748cdc07bb34c1a971 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -111,7 +111,7 @@ def _unwrap(x): return x._node -@apply.add +@apply.register() def _(op: OpDef, *args: VarNode): outputs = _imperative_rt.invoke_op(op, _unwrap(args)) return _wrap(outputs) diff --git a/imperative/python/megengine/core/tensor/multipledispatch/__init__.py b/imperative/python/megengine/core/tensor/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a87b9d9455650335f1ab271a261f586dd0caee86 --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/__init__.py @@ -0,0 +1,10 @@ +# This directory is a fork of multipledispatch. +# +# Repo: https://github.com/mrocklin/multipledispatch +# Commit: 9e3c87d0cee57972fd5cc33fe5cacde77c781834 +# Authors: Matthew Rocklin et al. 
+# +# Refer to ACKNOWLEDGMENTS for copyright and license information + +from .core import dispatch +from .dispatcher import Dispatcher diff --git a/imperative/python/megengine/core/tensor/multipledispatch/conflict.py b/imperative/python/megengine/core/tensor/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..687b051d44334d2072de2ef13f923ab003c838d6 --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/conflict.py @@ -0,0 +1,121 @@ +from .utils import _toposort, groupby +from .variadic import isvariadic + + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """ A is consistent and strictly more specific than B """ + if len(a) < len(b): + # only case is if a is empty and b is variadic + return not a and len(b) == 1 and isvariadic(b[-1]) + elif len(a) == len(b): + return all(map(issubclass, a, b)) + else: + # len(a) > len(b) + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not (isvariadic(cur_a) or isvariadic(cur_b)): + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + assert p1 == len(a) - 1 + return p2 == len(b) - 1 and issubclass(cur_a, cur_b) + elif isvariadic(cur_b): + assert p2 == len(b) - 1 + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + return p2 == len(b) - 1 and p1 == len(a) + + +def consistent(a, b): + """ It is possible for an argument list to satisfy both A and B """ + + # Need to check for empty args + if not a: + return not b or isvariadic(b[0]) + if not b: + return not a or isvariadic(a[0]) + + # Non-empty args check for mutual subclasses + if len(a) == len(b): + return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) + else: + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): + return False + if not (isvariadic(cur_a) or isvariadic(cur_b)): + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + p2 += 1 + elif isvariadic(cur_b): + p1 += 1 + # We only need to check for variadic ends + # Variadic types are guaranteed to be the last element + return isvariadic(cur_a) and p2 == len(b) or isvariadic(cur_b) and p1 == len(a) + + +def ambiguous(a, b): + """ A is consistent with B but neither is strictly more specific """ + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """ All signature pairs such that A is ambiguous with B """ + signatures = list(map(tuple, signatures)) + return set( + (a, b) + for a in signatures + for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) + ) + + +def super_signature(signatures): + """ A signature that would break ambiguities """ + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """ A should be checked before B + + Tie broken by tie_breaker, defaults to ``hash`` + """ + # A either supercedes B and B does not supercede A or if B does then call + # tie_breaker + return supercedes(a, b) and ( + not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) + ) + + +def ordering(signatures): + """ A sane ordering of signatures to check, first to last + + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + 
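Taken together, ``supercedes``, ``ambiguities`` and ``ordering`` combine as in the following minimal sketch. This is an illustration only, assuming a built MegEngine so the vendored package imports cleanly; the example signatures are made up.

from megengine.core.tensor.multipledispatch.conflict import (
    ambiguities,
    ordering,
    supercedes,
)

# (int,) is strictly more specific than (object,), but not the other way around.
assert supercedes((int,), (object,))
assert not supercedes((object,), (int,))

# (int, object) and (object, int) are mutually consistent yet neither
# supercedes the other, so the pair is reported as ambiguous ...
sigs = [(int, object), (object, int), (object, object)]
assert len(ambiguities(sigs)) == 1

# ... while the fully generic signature is always checked last.
assert ordering(sigs)[-1] == (object, object)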
edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(lambda x: x[0], edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = dict((k, [b for a, b in v]) for k, v in edges.items()) + return _toposort(edges) diff --git a/imperative/python/megengine/core/tensor/multipledispatch/core.py b/imperative/python/megengine/core/tensor/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..06fb2319a1ce2865db891dddb3e9b728fbb51fa4 --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/core.py @@ -0,0 +1,88 @@ +import inspect +import sys + +from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn + +global_namespace = dict() + + +def dispatch(*types, **kwargs): + """ Dispatch function on the types of the inputs + + Supports dispatch on all non-keyword arguments. + + Collects implementations based on the function name. Ignores namespaces. + + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Examples + -------- + + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + + Specify an isolated namespace with the namespace keyword argument + + >>> my_namespace = dict() + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + + Dispatch on instance methods within classes + + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... @dispatch(int) + ... def __init__(self, datum): + ... self.data = [datum] + """ + namespace = kwargs.get("namespace", global_namespace) + + types = tuple(types) + + def _df(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( + name, MethodDispatcher(name), + ) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func) + return dispatcher + + return _df + + +def ismethod(func): + """ Is func a method? + + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. 
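The heuristic looks only at the declared parameter names, since at decoration time a future method is indistinguishable from a plain function. A self-contained sketch of the same check (standard library only; the function names are illustrative):

import inspect

def plain(x):
    pass

def looks_like_method(self, x):
    pass

# The only available hint is a parameter literally named "self".
assert inspect.signature(plain).parameters.get("self") is None
assert inspect.signature(looks_like_method).parameters.get("self") is not None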
+ """ + if hasattr(inspect, "signature"): + signature = inspect.signature(func) + return signature.parameters.get("self", None) is not None + else: + if sys.version_info.major < 3: + spec = inspect.getargspec(func) + else: + spec = inspect.getfullargspec(func) + return spec and spec.args and spec.args[0] == "self" diff --git a/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py b/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..de7a0c2eaf1a352350751c19212bbc90c3bb2fc2 --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/dispatcher.py @@ -0,0 +1,401 @@ +import copy +import inspect +import itertools as itl +from warnings import warn + +from ..._imperative_rt.dispatcher import Dispatcher as CDispatcher +from .conflict import AmbiguityWarning, ambiguities, ordering, super_signature +from .utils import expand_tuples, parse_union +from .variadic import Variadic, isvariadic + + +def ambiguity_warn(dispatcher, ambiguities): + """ Raise warning when ambiguity is detected + + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +def variadic_signature_matches_iter(types, full_signature): + """Check if a set of input types matches a variadic signature. + + Notes + ----- + The algorithm is as follows: + + Initialize the current signature to the first in the sequence + + For each type in `types`: + If the current signature is variadic + If the type matches the signature + yield True + Else + Try to get the next signature + If no signatures are left we can't possibly have a match + so yield False + Else + yield True if the type matches the current signature + Get the next signature + """ + sigiter = iter(full_signature) + sig = next(sigiter) + for typ in types: + matches = issubclass(typ, sig) + yield matches + if not isvariadic(sig): + # we're not matching a variadic argument, so move to the next + # element in the signature + sig = next(sigiter) + else: + try: + sig = next(sigiter) + except StopIteration: + assert isvariadic(sig) + yield True + else: + # We have signature items left over, so all of our arguments + # haven't matched + yield False + + +def variadic_signature_matches(types, full_signature): + # No arguments always matches a variadic signature + assert full_signature + return all(variadic_signature_matches_iter(types, full_signature)) + + +def get_func_signature(function): + sig = inspect.signature(function) + types = [] + for p in sig.parameters.values(): + ann = p.annotation + ann = parse_union(ann) or ann + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + types.append(ann) + if p.kind == inspect.Parameter.VAR_POSITIONAL: + types.append([ann]) + return tuple(types) + + +class Frame: + __slots__ = "args", "types", "mro", "mro_offset" + + +class Dispatcher(CDispatcher): + """ Dispatch methods based on type signature + + Use ``dispatch`` to add implementations + + Examples + -------- + + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): + ... 
return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + + __slots__ = "__name__", "name", "funcs", "_ordering", "doc" + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self.doc = doc + + def register(self, *types, **kwargs): + """ register dispatcher with new implementation + + >>> f = Dispatcher('f') + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + + >>> f(1) + 2 + + >>> f(1.0) + 0.0 + + >>> f([1, 2, 3]) + [3, 2, 1] + """ + + def _df(func): + self.add(types, func, **kwargs) + return func + + return _df + + def add(self, signature, func): + """ Add new types/method pair to dispatcher + + >>> D = Dispatcher('add') + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: <int, float> + + When ``add`` detects a warning it calls the ``on_ambiguity`` callback + with a dispatcher/itself, and a set of ambiguous type signature pairs + as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + signature = get_func_signature(func) + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + new_signature = [] + + for index, typ in enumerate(signature, start=1): + if not isinstance(typ, (type, list)): + str_sig = ", ".join( + c.__name__ if isinstance(c, type) else str(c) for c in signature + ) + raise TypeError( + "Tried to dispatch on non-type: %s\n" + "In signature: <%s>\n" + "In function: %s" % (typ, str_sig, self.name) + ) + + # handle variadic signatures + if isinstance(typ, list): + if index != len(signature): + raise TypeError("Variadic signature must be the last element") + + if len(typ) != 1: + raise TypeError( + "Variadic signature must contain exactly one element. " + "To use a variadic union type place the desired types " + "inside of a tuple, e.g., [(int, str)]" + ) + new_signature.append(Variadic[typ[0]]) + else: + new_signature.append(typ) + + l = self.funcs.setdefault(tuple(new_signature), []) + for i in l: + if i is func: + raise ValueError("already registered") + l.append(func) + self.enable(func) + self.clear_cache() + + try: + del self._ordering + except AttributeError: + pass + + @property + def ordering(self): + try: + return self._ordering + except AttributeError: + return self.reorder() + + def reorder(self, on_ambiguity=ambiguity_warn): + self._ordering = od = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + return od + + def __str__(self): + return "<dispatched %s>" % self.name + + __repr__ = __str__ + + def dispatch(self, *types): + """Determine appropriate implementation for this type signature + + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... 
return x + 1 + + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + + >>> print(inc.dispatch(float)) + None + + See Also: + ``multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types][-1] + + for f in self.dispatch_iter(*types): + return f + + def dispatch_iter(self, *types): + + n = len(types) + for signature in self.ordering: + if ( + len(signature) == n + and all(map(issubclass, types, signature)) + or len(signature) + and isvariadic(signature[-1]) + and variadic_signature_matches(types, signature) + ): + yield from self.funcs[signature][::-1] + + def __getstate__(self): + return {"name": self.name, "funcs": self.funcs} + + def __setstate__(self, d): + self.name = d["name"] + self.funcs = d["funcs"] + self._ordering = ordering(self.funcs) + self._cache = dict() + + @property + def __doc__(self): + docs = ["Multiply dispatched method: %s" % self.name] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + funcs = self.funcs[sig] + s = "Inputs: <%s>\n" % str_signature(sig) + sep = "-" * len(s) + "\n" + for i, func in enumerate(funcs): + s += sep + if len(funcs) > 1: + s += "[Handler %d]\n\n" % (i + 1) + if i: + s += "\n\n" + if func.__doc__: + s += func.__doc__.strip() + else: + s += repr(func) + "\n" + docs.append(s) + + return "\n\n".join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """ Print docstring for the function corresponding to inputs """ + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """ Print source code for the function corresponding to inputs """ + print(self._source(*args)) + + +def source(func): + s = "File: %s\n\n" % inspect.getsourcefile(func) + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """ Dispatch methods based on type signature + + See Also: + Dispatcher + """ + + __slots__ = ("obj", "cls") + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + "Could not find signature for %s: <%s>" + % (self.name, str_signature(types)) + ) + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """ String representation of type signature + + >>> str_signature((int, float)) + 'int, float' + """ + return ", ".join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """ The text for ambiguity warnings """ + text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += "\n\n".join( + [ + "@dispatch(" + str_signature(super_signature(s)) + ")\ndef %s(...)" % name + for s in amb + ] + ) + return text diff --git a/imperative/python/megengine/core/tensor/multipledispatch/utils.py b/imperative/python/megengine/core/tensor/multipledispatch/utils.py 
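Putting the pieces of dispatcher.py together, annotation-based registration plus the ordering from conflict.py gives most-specific-first resolution. A hedged end-to-end sketch follows; the handler names are invented, and it assumes a built MegEngine, since Dispatcher inherits from the C implementation in _imperative_rt.dispatcher.

from megengine.core.tensor.multipledispatch import Dispatcher

f = Dispatcher("f")

@f.register()
def handle_any(x: object):
    return "object"

@f.register()
def handle_int(x: int):
    return "int"

# ordering() places (int,) before (object,), so the most specific handler
# wins; anything that is not an int falls through to the generic one.
assert f(1) == "int"
assert f("hello") == "object"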
new file mode 100644 index 0000000000000000000000000000000000000000..64933572bf43013aee01c38fd126648df393165d --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/utils.py @@ -0,0 +1,177 @@ +import sys +import typing +from collections import OrderedDict + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +def expand_tuples(L): + """ + + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + + Closely follows the wikipedia page [2] + + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + + """ + result = OrderedDict() + for key in d: + for val in d[key]: + result[val] = result.get(val, tuple()) + (key,) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """ Group a collection by a key function + + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + See Also: + ``countby`` + """ + + d = OrderedDict() + for item in seq: + key = func(item) + if key not in d: + d[key] = list() + d[key].append(item) + return d + + +def typename(type): + """Get the name of `type`. + + Parameters + ---------- + type : Union[Type, Tuple[Type]] + + Returns + ------- + str + The name of `type` or a tuple of the names of the types in `type`. 
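Two helpers above carry most of the weight elsewhere in the package: ``expand_tuples`` turns tuple ("union") entries of a signature into one concrete signature per member, and ``_toposort`` orders signatures for dispatch. A small sketch of their observable behaviour (assuming a built MegEngine for the import):

from megengine.core.tensor.multipledispatch.utils import _toposort, expand_tuples

# This is how an annotation such as op: (OpDef, Function) ends up registering
# the same handler once per member type.
assert expand_tuples([(int, float), str]) == [(int, str), (float, str)]

# Kahn's algorithm on a {node: successors} mapping: 1 precedes 2 and 3, 2 precedes 3.
assert _toposort({1: (2, 3), 2: (3,)}) == [1, 2, 3]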
+ + Examples + -------- + >>> typename(int) + 'int' + >>> typename((int, float)) + '(int, float)' + """ + try: + return type.__name__ + except AttributeError: + if len(type) == 1: + return typename(*type) + return "(%s)" % ", ".join(map(typename, type)) + + +# parse typing.Union +if sys.version_info < (3, 6): + + def parse_union(ann): + if type(ann) is not typing.UnionMeta: + return + return ann.__union_params__ + + +elif sys.version_info < (3, 7): + + def parse_union(ann): + if type(ann) is not typing._Union: + return + return ann.__args__ + + +elif sys.version_info < (3, 8): + + def parse_union(ann): + if type(ann) is not typing._GenericAlias: + if type(ann) is not typing.Union: + return + else: + if ann.__origin__ is not typing.Union: + return + return ann.__args__ + + +else: + + def parse_union(ann): + if typing.get_origin(ann) is not typing.Union: + return + return typing.get_args(ann) diff --git a/imperative/python/megengine/core/tensor/multipledispatch/variadic.py b/imperative/python/megengine/core/tensor/multipledispatch/variadic.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e26068e9ea334324209e2287ec1f2cdeafacc9 --- /dev/null +++ b/imperative/python/megengine/core/tensor/multipledispatch/variadic.py @@ -0,0 +1,95 @@ +from .utils import typename + + +class VariadicSignatureType(type): + # checking if subclass is a subclass of self + def __subclasscheck__(self, subclass): + other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) + return subclass is self or all( + issubclass(other, self.variadic_type) for other in other_type + ) + + def __eq__(self, other): + """ + Return True if other has the same variadic type + + Parameters + ---------- + other : object (type) + The object (type) to check + + Returns + ------- + bool + Whether or not `other` is equal to `self` + """ + return isvariadic(other) and set(self.variadic_type) == set(other.variadic_type) + + def __hash__(self): + return hash((type(self), frozenset(self.variadic_type))) + + +def isvariadic(obj): + """Check whether the type `obj` is variadic. + + Parameters + ---------- + obj : type + The type to check + + Returns + ------- + bool + Whether or not `obj` is variadic + + Examples + -------- + >>> isvariadic(int) + False + >>> isvariadic(Variadic[int]) + True + """ + return isinstance(obj, VariadicSignatureType) + + +class VariadicSignatureMeta(type): + """A metaclass that overrides ``__getitem__`` on the class. This is used to + generate a new type for Variadic signatures. See the Variadic class for + examples of how this behaves. + """ + + def __getitem__(self, variadic_type): + if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): + raise ValueError( + "Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]" + ) + + if not isinstance(variadic_type, tuple): + variadic_type = (variadic_type,) + return VariadicSignatureType( + "Variadic[%s]" % typename(variadic_type), + (), + dict(variadic_type=variadic_type, __slots__=()), + ) + + +class Variadic(metaclass=VariadicSignatureMeta): + """A class whose getitem method can be used to generate a new type + representing a specific variadic signature. 
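``Variadic`` is what lets handlers written as ``def _(op: OpDef, *args: RawTensor)`` match any number of trailing tensor arguments: the VAR_POSITIONAL annotation is recorded as ``[RawTensor]`` and converted by ``Dispatcher.add`` into ``Variadic[RawTensor]``. A minimal sketch (assuming a built MegEngine for the import):

from megengine.core.tensor.multipledispatch.variadic import Variadic, isvariadic

V = Variadic[int]            # stands for "zero or more int arguments"
assert isvariadic(V)
assert issubclass(int, V)    # an int argument satisfies the variadic tail
assert issubclass(bool, V)   # as does any subclass of int
assert not issubclass(str, V)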
+ + Examples + -------- + >>> Variadic[int] # any number of int arguments + + >>> Variadic[(int, str)] # any number of one of int or str arguments + + >>> issubclass(int, Variadic[int]) + True + >>> issubclass(int, Variadic[(int, str)]) + True + >>> issubclass(str, Variadic[(int, str)]) + True + >>> issubclass(float, Variadic[(int, str)]) + False + """ diff --git a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py index decca86df20b78a6de30341fe20353ebec60373f..da44c689f65fdd5dc9d1fc73845e63d10022ec76 100644 --- a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py +++ b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py @@ -66,13 +66,13 @@ class RawTensor(TensorBase): delete(self._handle) -@apply.add +@apply.register() def _(op: OpDef, *args: RawTensor): outputs = apply_op(op, tuple(i._handle for i in args)) return tuple(map(RawTensor, outputs)) -@apply.add +@apply.register() def _(op: Const, *args: RawTensor): dtype = op.dtype device = as_device(op.device).to_c() diff --git a/imperative/python/megengine/core/tensor/tensor.py b/imperative/python/megengine/core/tensor/tensor.py index 0f2ff9d78121e9f4bb9746ac97494e9949e84c29..7780ea19bc4d565cae09540897e0db3a64a4ceba 100644 --- a/imperative/python/megengine/core/tensor/tensor.py +++ b/imperative/python/megengine/core/tensor/tensor.py @@ -79,7 +79,7 @@ def get_context(): return _context -@apply.add +@apply.register() def tensor_apply(op: OpBase, *args: Tensor): data = tuple(i._data if isinstance(i, Tensor) else i for i in args) # type(Tensor._data) is RawTensor diff --git a/imperative/python/megengine/functional/distributed.py b/imperative/python/megengine/functional/distributed.py index ebb81cf7a2ab8d094a85343d8fa81d27aa6349f2..c5270571635989d32c5481352b607c4e71dec998 100644 --- a/imperative/python/megengine/functional/distributed.py +++ b/imperative/python/megengine/functional/distributed.py @@ -46,7 +46,7 @@ __all__ = [ ] -@apply.add +@apply.register() def _(op: RemoteSend, *args: Tensor): ret = tensor_apply(op, *args) diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt index a2d8a55df078d1b40a2d8b46a75035f485c0a263..3e9b44c98ba3afb32c32b636b0bfca2f56a742b1 100644 --- a/imperative/python/requires.txt +++ b/imperative/python/requires.txt @@ -1,5 +1,4 @@ numpy>=1.18 -multipledispatch==0.6.0 opencv-python pyarrow requests diff --git a/imperative/python/src/dispatcher.cpp b/imperative/python/src/dispatcher.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e54ffc9b0f218b5cec88ebda6c9d1a14b1d71f36 --- /dev/null +++ b/imperative/python/src/dispatcher.cpp @@ -0,0 +1,180 @@ +#include "./dispatcher.h" +#include "./pyext17.h" +#include "megbrain/utils/hash.h" +#include "megbrain/utils/small_vector.h" + +#include +#include + +namespace py = pybind11; +namespace pyx = pyext17; + +namespace { + +struct Handler { + PyObject* func; // borrowed + bool enabled; + + Handler() = default; + Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {} +}; + +using FastSig = mgb::SmallVector; +using MRO = std::vector; + +struct Frame { + MRO* mro; + size_t mro_offset; + + Frame() = default; + Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {} +}; + +struct FastSigHash { + size_t operator()(const FastSig& sig) const { + auto* ptr = &sig.front(); + return mgb::XXHash() + .update(ptr, sig.size() * sizeof(FastSig::value_type)) + .digest(); + } +}; + +struct ObjectIdHash : 
std::hash { + size_t operator()(const py::handle& h) const { + return std::hash::operator()(h.ptr()); + } +}; + +struct Dispatcher { + std::unordered_map, FastSigHash> cache; + std::vector stack; + std::unordered_map, ObjectIdHash> registry; + + inline py::handle self() { + return pyx::wrap::pycast(this); + } + + bool prepare_call(PyObject*const* args, Py_ssize_t nargs) { + FastSig sig(nargs); + for (Py_ssize_t i = 0; i < nargs; ++i) { + sig[i] = Py_TYPE(args[i]); + } + auto it = cache.find(sig); + if (it == cache.end()) { + if (auto mro = resolve(sig)) { + it = cache.emplace(std::move(sig), std::move(mro)).first; + } else { + return false; + } + } + stack.emplace_back(it->second.get()); + return true; + } + + template + PyObject* do_call(T&& caller) { + auto& frame = stack.back(); + auto& mro = *frame.mro; + auto& i = frame.mro_offset; + for (; i < mro.size(); ++i) { + if (mro[i]->enabled) { + auto ret = caller(mro[i]->func); + if (ret != Py_NotImplemented) { + stack.pop_back(); + return ret; + } + Py_DECREF(ret); + } + } + PyErr_SetString(PyExc_NotImplementedError, "mro exhausted"); + stack.pop_back(); + return nullptr; + } + + std::unique_ptr resolve(const FastSig& sig) { + try { + py::tuple args(sig.size()); + for (size_t i = 0; i < sig.size(); ++i) { + args[i] = (PyObject*)sig[i]; + } + auto mro_iter = self().attr("dispatch_iter")(*args); + auto ret = std::make_unique(); + for (auto i : mro_iter) { + auto it = registry.find(py::reinterpret_borrow(i)); + if (it == registry.end()) { + PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function"); + return nullptr; + } + ret->push_back(it->second.get()); + } + return ret; + } catch (py::error_already_set& e) { + e.restore(); + } catch (std::runtime_error& e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + } + return nullptr; + } + +public: + static constexpr auto tp_name = "Dispatcher"; + + PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) { + if (!prepare_call(args, nargs)) return nullptr; + return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); + } + + PyObject* tp_call(PyObject* args, PyObject* kwargs) { + if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr; + return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);}); + } + + PyObject* super(PyObject*const* args, Py_ssize_t nargs) { + if (stack.empty()) { + PyErr_SetString(PyExc_RuntimeError, "super called at top level"); + return nullptr; + } + stack.emplace_back(stack.back()).mro_offset++; + return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);}); + } + + void enable(PyObject* func) { + auto obj = py::reinterpret_borrow(func); + auto it = registry.find(obj); + if (it != registry.end()) { + it->second->enabled = true; + } else { + registry.emplace(std::move(obj), std::make_unique(func)); + } + } + + PyObject* disable(PyObject* func) { + auto obj = py::reinterpret_borrow(func); + auto it = registry.find(obj); + if (it == registry.end()) { + PyErr_SetString(PyExc_ValueError, "function not registered"); + return nullptr; + } else { + it->second->enabled = false; + } + Py_RETURN_NONE; + } + + void clear_cache() { + cache.clear(); + } +}; + +} // namespace + +void init_dispatcher(py::module m) { + auto* dispatcher_type = pyx::wrap::type() + .def<&Dispatcher::enable>("enable") + .def<&Dispatcher::disable>("disable") + .def<&Dispatcher::clear_cache>("clear_cache") + .def<&Dispatcher::tp_vectorcall>("call") + .def<&Dispatcher::super>("super") + 
.finalize(); + if (!dispatcher_type) throw py::error_already_set(); + m.attr("Dispatcher") = dispatcher_type; +} diff --git a/imperative/python/src/dispatcher.h b/imperative/python/src/dispatcher.h new file mode 100644 index 0000000000000000000000000000000000000000..963afa6b78b2706904f99b2fa53c7f3d17051d06 --- /dev/null +++ b/imperative/python/src/dispatcher.h @@ -0,0 +1,5 @@ +#pragma once + +#include + +void init_dispatcher(pybind11::module); diff --git a/imperative/python/src/module.cpp b/imperative/python/src/module.cpp index 447e6ac2ff0fb170bb6c143b28766ff9313bf29c..a0710efc303e6844fe5834fbb05c3deca5fbf5da 100644 --- a/imperative/python/src/module.cpp +++ b/imperative/python/src/module.cpp @@ -21,6 +21,8 @@ #include "./graph_rt.h" #include "./ops.h" +#include "./dispatcher.h" + namespace py = pybind11; #ifndef MODULE_NAME @@ -63,4 +65,6 @@ PYBIND11_MODULE(MODULE_NAME, m) { from .graph import * )", py::getattr(m, "__dict__")); + + init_dispatcher(submodule(m, "dispatcher")); } diff --git a/imperative/python/src/pyext17.h b/imperative/python/src/pyext17.h new file mode 100644 index 0000000000000000000000000000000000000000..c7ad3646be80153bae559107916b20926e53ce82 --- /dev/null +++ b/imperative/python/src/pyext17.h @@ -0,0 +1,270 @@ +#pragma once + +#include +#include +#include +#include + +namespace pyext17 { + +#ifdef METH_FASTCALL +constexpr bool has_fastcall = true; +#else +constexpr bool has_fastcall = false; +#endif + +template +struct invocable_with { + template + constexpr bool operator()(T&& lmb) { + return std::is_invocable_v; + } +}; + +#define HAS_MEMBER_TYPE(T, U) invocable_with{}([](auto&& x) -> typename std::decay_t::U {}) +#define HAS_MEMBER(T, m) invocable_with{}([](auto&& x) -> decltype(&std::decay_t::m) {}) + +inline PyObject* cvt_retval(PyObject* rv) { + return rv; +} + +#define CVT_RET_PYOBJ(...) 
\ + if constexpr (std::is_same_v) { \ + __VA_ARGS__; \ + Py_RETURN_NONE; \ + } else { \ + return cvt_retval(__VA_ARGS__); \ + } + +template +struct wrap { +private: + typedef wrap wrap_t; + +public: + PyObject_HEAD + std::aligned_storage_t storage; + + inline T* inst() { + return reinterpret_cast(&storage); + } + + inline static PyObject* pycast(T* ptr) { + return (PyObject*)((char*)ptr - offsetof(wrap_t, storage)); + } + +private: + // method wrapper + + enum struct meth_type { + noarg, + varkw, + fastcall, + singarg + }; + + template + struct detect_meth_type { + static constexpr meth_type value = []() { + using F = decltype(f); + static_assert(std::is_member_function_pointer_v); + if constexpr (std::is_invocable_v) { + return meth_type::noarg; + } else if constexpr (std::is_invocable_v) { + return meth_type::varkw; + } else if constexpr (std::is_invocable_v) { + return meth_type::fastcall; + } else if constexpr (std::is_invocable_v) { + return meth_type::singarg; + } else { + static_assert(!std::is_same_v); + } + }(); + }; + + template + struct meth {}; + + template + struct meth { + static constexpr int flags = METH_NOARGS; + + static PyObject* impl(PyObject* self, PyObject*) { + auto* inst = reinterpret_cast(self)->inst(); + CVT_RET_PYOBJ((inst->*f)()); + } + }; + + template + struct meth { + static constexpr int flags = METH_VARARGS | METH_KEYWORDS; + + static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { + auto* inst = reinterpret_cast(self)->inst(); + CVT_RET_PYOBJ((inst->*f)(args, kwargs)); + } + }; + + template + struct meth { + #ifdef METH_FASTCALL + static constexpr int flags = METH_FASTCALL; + + static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { + auto* inst = reinterpret_cast(self)->inst(); + CVT_RET_PYOBJ((inst->*f)(args, nargs)); + } + #else + static constexpr int flags = METH_VARARGS; + + static PyObject* impl(PyObject* self, PyObject* args) { + auto* inst = reinterpret_cast(self)->inst(); + auto* arr = &PyTuple_GET_ITEM(args, 0); + auto size = PyTuple_GET_SIZE(args); + CVT_RET_PYOBJ((inst->*f)(arr, size)); + } + #endif + }; + + template + struct meth { + static constexpr int flags = METH_O; + + static PyObject* impl(PyObject* self, PyObject* obj) { + auto* inst = reinterpret_cast(self)->inst(); + CVT_RET_PYOBJ((inst->*f)(obj)); + } + }; + + template + static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) { + using M = meth::value, f>; + return {name, (PyCFunction)M::impl, M::flags, doc}; + } + + // polyfills + + struct tp_new { + static constexpr bool provided = HAS_MEMBER(T, tp_new); + static constexpr bool varkw = std::is_constructible_v; + static constexpr bool noarg = std::is_default_constructible_v; + + template + static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + auto* self = type->tp_alloc(type, 0); + auto* ptr = reinterpret_cast(self)->inst(); + if constexpr (varkw) { + new(ptr) T(args, kwargs); + } else { + new(ptr) T(); + } + return self; + } + + static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new; + else if constexpr (varkw || noarg) return impl<>; + else return nullptr;}(); + }; + + struct tp_dealloc { + static constexpr bool provided = HAS_MEMBER(T, tp_dealloc); + + template + static void impl(PyObject* self) { + reinterpret_cast(self)->inst()->~T(); + Py_TYPE(self)->tp_free(self); + } + + static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc; + else return impl<>;}(); + }; + + struct 
tp_call { + static constexpr bool valid = HAS_MEMBER(T, tp_call); + static constexpr bool static_form = invocable_with{}( + [](auto&& t, auto... args) -> decltype(std::decay_t::tp_call(args...)) {}); + + template + static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { + auto* inst = reinterpret_cast(self)->inst(); + CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); + } + + static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; + else if constexpr (valid) return impl<>; + else return nullptr;}(); + }; + +public: + class TypeBuilder { + std::vector m_methods; + PyTypeObject m_type; + bool m_finalized = false; + bool m_ready = false; + + void check_finalized() { + if (m_finalized) { + throw std::runtime_error("type is already finalized"); + } + } + public: + TypeBuilder(const TypeBuilder&) = delete; + TypeBuilder& operator=(const TypeBuilder&) = delete; + + TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { + // static_assert(HAS_MEMBER(T, tp_name)); + if constexpr (HAS_MEMBER(T, tp_name)) { + m_type.tp_name = T::tp_name; + } + m_type.tp_dealloc = tp_dealloc::value; + m_type.tp_call = tp_call::value; + m_type.tp_basicsize = sizeof(wrap_t); + m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + m_type.tp_new = tp_new::value; + } + + PyTypeObject* operator->() { + return &m_type; + } + + bool ready() const { + return m_ready; + } + + PyObject* finalize() { + if (!m_finalized) { + if (m_methods.size()) { + m_methods.push_back({0}); + if (m_type.tp_methods) { + PyErr_SetString(PyExc_SystemError, "tp_method is already set"); + return nullptr; + } + m_type.tp_methods = &m_methods[0]; + } + if (PyType_Ready(&m_type)) { + return nullptr; + } + m_ready = true; + } + return (PyObject*)&m_type; + } + + template + TypeBuilder& def(const char* name, const char* doc = nullptr) { + check_finalized(); + m_methods.push_back(make_meth_def(name, doc)); + return *this; + } + }; + + static TypeBuilder& type() { + static TypeBuilder type_helper; + return type_helper; + } +}; + +} // namespace pyext17 + +#undef HAS_MEMBER_TYPE +#undef HAS_MEMBER +#undef CVT_RET_PYOBJ diff --git a/imperative/python/test/unit/test_dispatch.py b/imperative/python/test/unit/test_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..ce74341da20bcce063aa2d197d390729b3f31ccf --- /dev/null +++ b/imperative/python/test/unit/test_dispatch.py @@ -0,0 +1,58 @@ +from megengine.core.tensor.multipledispatch import Dispatcher + + +def test_register_many(): + f = Dispatcher("f") + + log = [] + + @f.register() + def _(x: int): + log.append("a") + return log[-1] + + @f.register() + def _(x: int): + log.append("b") + return log[-1] + + assert f(0) == "b" + assert log == ["b"] + + +def test_return_not_implemented(): + f = Dispatcher("f") + + log = [] + + @f.register() + def _(x: int): + log.append("a") + return log[-1] + + @f.register() + def _(x: int): + log.append("b") + return NotImplemented + + assert f(0) == "a" + assert log == ["b", "a"] + + +def test_super(): + f = Dispatcher("f") + + log = [] + + @f.register() + def _(x: int): + log.append("a") + return log[-1] + + @f.register() + def _(x: int): + log.append("b") + return f.super(x) + + assert f(0) == "a" + assert log == ["b", "a"]
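The C++ Dispatcher also exposes enable, disable and clear_cache (see dispatcher.cpp above). The following is a sketch of one more test exercising them; it is not part of the patch, only an illustration of the intended semantics, with handler names chosen for the example.

def test_disable_enable():
    f = Dispatcher("f")

    @f.register()
    def a(x: int):
        return "a"

    @f.register()
    def b(x: int):
        return "b"

    assert f(0) == "b"  # the most recently registered handler wins

    f.disable(b)  # disabled handlers are skipped during MRO traversal
    assert f(0) == "a"

    f.enable(b)  # re-enabling restores the original resolution order
    assert f(0) == "b"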