From 57eb92b7781d46f22f57f89f75010b898e236c42 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Dec 2018 10:38:08 -0800 Subject: [PATCH] Internal Change PiperOrigin-RevId: 225212001 --- tensorflow/python/BUILD | 1 + tensorflow/python/ops/ragged/__init__.py | 15 ++++++++++++++- tensorflow/python/ops/ragged/ragged_dispatch.py | 11 +++++++++-- tensorflow/python/ops/standard_ops.py | 6 ++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 8a7c001321f..c11df5534de 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3033,6 +3033,7 @@ py_library( "//tensorflow/python/eager:wrap_function", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/linalg", + "//tensorflow/python/ops/ragged", ], ) diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py index 3d915ee269b..f23f506e06d 100644 --- a/tensorflow/python/ops/ragged/__init__.py +++ b/tensorflow/python/ops/ragged/__init__.py @@ -66,6 +66,15 @@ class documentation. @@RaggedTensorDynamicShape @@broadcast_to @@broadcast_dynamic_shape + + +@@ragged_dispatch +@@ragged_factory_ops +@@ragged_operators +@@ragged_string_ops +@@ragged_tensor +@@ragged_tensor_value +@@ragged_util """ from __future__ import absolute_import @@ -73,8 +82,12 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.ragged import ragged_dispatch +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_operators from tensorflow.python.ops.ragged import ragged_string_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged import ragged_tensor_value +from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged.ragged_array_ops import batch_gather from tensorflow.python.ops.ragged.ragged_array_ops import boolean_mask @@ -133,7 +146,7 @@ from tensorflow.python.util import all_util as _all_util # Register OpDispatchers that override standard TF ops to work w/ RaggedTensors. -__doc__ += ragged_dispatch.register_dispatchers() # pylint: disable=redefined-builtin +__doc__ += ragged_dispatch.ragged_op_list() # pylint: disable=redefined-builtin # Any symbol that is not referenced (with "@@name") in the module docstring # above will be removed. diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index f334f1fc8e5..77990a8b188 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -447,10 +447,17 @@ def register_dispatchers(): for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS: RaggedDispatcher(original_op, ragged_op, args).register(original_op) - docstring = ( + +def ragged_op_list(): + """Returns a string listing operators that have dispathers registered.""" + op_list = ( + _UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS + + _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS]) + return ( '\n\n### Additional ops that support `RaggedTensor`\n\n' + '\n'.join([ '* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op) for op in op_list ])) - return docstring + +register_dispatchers() diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 8ef0fe80706..ba3bd094923 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -71,6 +71,8 @@ from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * from tensorflow.python.ops.partitioned_variables import * +from tensorflow.python.ops.ragged import ragged_dispatch as _ragged_dispatch +from tensorflow.python.ops.ragged import ragged_operators as _ragged_operators from tensorflow.python.ops.random_ops import * from tensorflow.python.ops.script_ops import py_func from tensorflow.python.ops.session_ops import * @@ -102,3 +104,7 @@ from tensorflow.python.ops.variable_scope import * from tensorflow.python.ops.variables import * # pylint: enable=wildcard-import # pylint: enable=g-bad-import-order + + +# These modules were imported to set up RaggedTensor operators and dispatchers: +del _ragged_dispatch, _ragged_operators -- GitLab