提交 caf5da8c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Enable RaggedTensor dispatch for tf.where with x=y=None.

PiperOrigin-RevId: 225408570
上级 27cffd79
......@@ -76,6 +76,8 @@ def _get_arg_infos(func, arg_names):
def _is_convertible_to_tensor(value):
"""Returns true if `value` is convertible to a `Tensor`."""
if value is None:
return True
if isinstance(value,
(ops.Tensor, variables.Variable, np.ndarray, int, float, str)):
return True
......
......@@ -555,6 +555,10 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
dict(
op=array_ops.where,
args=(ragged_factory_ops.constant_value([[True, False], [True]]),),
expected=[[0, 0], [1, 0]]),
dict(
op=math_ops.unsorted_segment_sum,
kwargs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册