提交 f55a491b 编写于 作者: S Stephan Hoyer 提交者: TensorFlower Gardener

[tf-numpy] add experimental __array_module__ method

__array_module__ is an experimental protocol for "duck array" compatibility that
for indicating how to find a "numpy compatible" module. The hope is to make it
easier to write generic code that works across a range of array libraries.

A full example, which should work equally work for TF-NumPy as for JAX, can be
found at https://github.com/google/jax/pull/4076. More examples, and motivation for
this protocol can be found at https://numpy.org/neps/nep-0037-array-module.html.

This design has not yet been finalized in NumPy, so at present it requires using
the experimental numpy_dispatch module: https://github.com/seberg/numpy-dispatch.

Unlike NumPy's __array_ufunc__ and __array_function__ protocols, __array_module__
by design has no backwards compatibility consequences. The protocol only controls
the behavior of numpy_dispatch.get_array_module() or numpy.get_array_module() --
it does not change any existing NumPy functions.

PiperOrigin-RevId: 328181308
Change-Id: I98b648a59709ed3bc295aaf97f471a1ff7ae1f05
上级 5c06ed81
......@@ -294,6 +294,19 @@ class ndarray(composite_tensor.CompositeTensor):
# NOTE: we currently prefer interop with TF to allow TF to take precedence.
__array_priority__ = 90
def __array_module__(self, types):
# Experimental support for NumPy's module dispatch with NEP-37:
# https://numpy.org/neps/nep-0037-array-module.html
# Currently requires https://github.com/seberg/numpy-dispatch
# pylint: disable=g-import-not-at-top
import tensorflow.compat.v2 as tf
if all(issubclass(t, (ndarray, np.ndarray)) for t in types):
return tf.experimental.numpy
else:
return NotImplemented
def __index__(self):
"""Returns a python scalar.
......
......@@ -174,6 +174,17 @@ class InteropTest(tf.test.TestCase):
self.assertIsInstance(sq, onp.ndarray)
self.assertEqual(100., sq[0])
def testArrayModule(self):
arr = np.asarray([10])
module = arr.__array_module__((np.ndarray,))
self.assertIs(module, tf.experimental.numpy)
class Dummy:
pass
module = arr.__array_module__((np.ndarray, Dummy))
self.assertIs(module, NotImplemented)
# TODO(nareshmodi): Fails since the autopacking code doesn't use
# nest.flatten.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册