From caf5da8cc5a14266c0f42993687b40ac923ba02c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 13 Dec 2018 11:45:07 -0800 Subject: [PATCH] Enable RaggedTensor dispatch for tf.where with x=y=None. PiperOrigin-RevId: 225408570 --- tensorflow/python/ops/ragged/ragged_dispatch.py | 2 ++ tensorflow/python/ops/ragged/ragged_dispatch_test.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/tensorflow/python/ops/ragged/ragged_dispatch.py b/tensorflow/python/ops/ragged/ragged_dispatch.py index 77990a8b188..ecc7f5d611f 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch.py @@ -76,6 +76,8 @@ def _get_arg_infos(func, arg_names): def _is_convertible_to_tensor(value): """Returns true if `value` is convertible to a `Tensor`.""" + if value is None: + return True if isinstance(value, (ops.Tensor, variables.Variable, np.ndarray, int, float, str)): return True diff --git a/tensorflow/python/ops/ragged/ragged_dispatch_test.py b/tensorflow/python/ops/ragged/ragged_dispatch_test.py index fb3dabc3eb8..9d70470f05a 100644 --- a/tensorflow/python/ops/ragged/ragged_dispatch_test.py +++ b/tensorflow/python/ops/ragged/ragged_dispatch_test.py @@ -555,6 +555,10 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase, ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]), ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])), expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])), + dict( + op=array_ops.where, + args=(ragged_factory_ops.constant_value([[True, False], [True]]),), + expected=[[0, 0], [1, 0]]), dict( op=math_ops.unsorted_segment_sum, kwargs={ -- GitLab