提交 73ee76cc 编写于 作者: T TensorFlower Gardener

Merge pull request #61491 from benbarsdell:upstream-SparseSegmentReduceGradV2-part3

PiperOrigin-RevId: 564529533
......@@ -74,11 +74,22 @@
### Bug Fixes and Other Changes
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* Add TensorFlow Quantizer to TensorFlow pip package.
* `tf.sparse.segment_sum` `tf.sparse.segment_mean` `tf.sparse.segment_sqrt_n`
`SparseSegmentSum/Mean/SqrtN[WithNumSegments]`
* Added `sparse_gradient` option (default=false) that makes the gradient
of these functions/ops sparse (`IndexedSlices`) instead of dense
(`Tensor`), using new `SparseSegmentSum/Mean/SqrtNGradV2` ops.
* `tf.nn.embedding_lookup_sparse`
* Add TensorFlow Quantizer to TensorFlow pip package.
* Optimized this function for some cases by fusing internal operations.
## Keras
......
......@@ -16993,7 +16993,9 @@ dimension, selecting a subset of dimension 0, specified by `indices`.
let arguments = (ins
TF_FloatTensor:$data,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Has same rank as `segment_ids`.}]>:$indices,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids,
DefaultValuedOptionalAttr<BoolAttr, "false">:$sparse_gradient
);
let results = (outs
......@@ -17046,7 +17048,9 @@ for an explanation of segments.
TF_FloatTensor:$data,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Has same rank as `segment_ids`.}]>:$indices,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids,
Arg<TF_I32OrI64Tensor, [{Should equal the number of distinct segment IDs.}]>:$num_segments
Arg<TF_I32OrI64Tensor, [{Should equal the number of distinct segment IDs.}]>:$num_segments,
DefaultValuedOptionalAttr<BoolAttr, "false">:$sparse_gradient
);
let results = (outs
......@@ -17074,7 +17078,9 @@ See `tf.sparse.segment_sum` for usage examples.
let arguments = (ins
TF_FloatTensor:$data,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Has same rank as `segment_ids`.}]>:$indices,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids,
DefaultValuedOptionalAttr<BoolAttr, "false">:$sparse_gradient
);
let results = (outs
......@@ -17131,7 +17137,9 @@ for an explanation of segments.
TF_FloatTensor:$data,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Has same rank as `segment_ids`.}]>:$indices,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids,
Arg<TF_I32OrI64Tensor, [{Should equal the number of distinct segment IDs.}]>:$num_segments
Arg<TF_I32OrI64Tensor, [{Should equal the number of distinct segment IDs.}]>:$num_segments,
DefaultValuedOptionalAttr<BoolAttr, "false">:$sparse_gradient
);
let results = (outs
......@@ -17183,7 +17191,9 @@ tf.segment_sum(c, tf.constant([0, 0, 1]))
let arguments = (ins
TF_IntOrFpTensor:$data,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Has same rank as `segment_ids`.}]>:$indices,
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids
Arg<TF_I32OrI64Tensor, [{A 1-D tensor. Values should be sorted and can be repeated.}]>:$segment_ids,
DefaultValuedOptionalAttr<BoolAttr, "false">:$sparse_gradient
);
let results = (outs
......
......@@ -1417,6 +1417,7 @@ REGISTER_OP("SparseSegmentSum")
.Attr("T: realnumbertype")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentSumWithNumSegments")
......@@ -1429,6 +1430,7 @@ REGISTER_OP("SparseSegmentSumWithNumSegments")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentSumGrad")
......@@ -1464,6 +1466,7 @@ REGISTER_OP("SparseSegmentMean")
.Attr("T: {bfloat16, half, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentMeanWithNumSegments")
......@@ -1476,6 +1479,7 @@ REGISTER_OP("SparseSegmentMeanWithNumSegments")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentMeanGrad")
......@@ -1509,6 +1513,7 @@ REGISTER_OP("SparseSegmentSqrtN")
.Attr("T: {bfloat16, half, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionShapeFn);
REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
......@@ -1521,6 +1526,7 @@ REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.Attr("Tsegmentids: {int32, int64} = DT_INT32")
.Attr("sparse_gradient: bool = false")
.SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
REGISTER_OP("SparseSegmentSqrtNGrad")
......
......@@ -797,9 +797,11 @@ cuda_py_strict_test(
":forwardprop",
":test",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python/compat",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:sparse_tensor",
"//tensorflow/python/framework:tensor_shape",
"//tensorflow/python/framework:tensor_spec",
"//tensorflow/python/framework:test_lib",
......@@ -808,6 +810,7 @@ cuda_py_strict_test(
"//tensorflow/python/ops:cond",
"//tensorflow/python/ops:embedding_ops",
"//tensorflow/python/ops:functional_ops",
"//tensorflow/python/ops:gradients",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:math_ops_gen",
"//tensorflow/python/ops:nn_ops",
......
......@@ -33,6 +33,7 @@ import time
import numpy as np
from tensorflow.python import pywrap_tfe
from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.eager import backprop # pylint: disable=unused-import
from tensorflow.python.eager import benchmarks_test_base
from tensorflow.python.eager import context
......@@ -43,6 +44,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
......@@ -52,6 +54,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
......@@ -101,7 +104,6 @@ def run_benchmark(func, num_iters, execution_mode=None):
class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
def __init__(self):
# used for multiply benchmarks
self._m_2 = random_ops.random_uniform([2])
......@@ -1685,6 +1687,60 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
allow_fast_lookup=True
)
def _RandomIdsAndWeights(self, batch_size, vocab_size, max_val_per_entry):
  """Builds random sparse ids and weights that share one index layout.

  Every batch row receives a random number of values drawn from
  `[1, max_val_per_entry)`. Ids are random integers below `vocab_size` and
  weights are random floats in `[1, 2)`.

  Args:
    batch_size: Number of rows in the generated sparse tensors.
    vocab_size: Exclusive upper bound for the generated ids.
    max_val_per_entry: Size of the second dense dimension; also the exclusive
      upper bound on the number of values generated per row.

  Returns:
    A `(sp_ids, sp_weights)` pair of `SparseTensor`s built from the same
    indices and dense shape.
  """
  row_lengths = np.random.randint(1, max_val_per_entry, size=batch_size)
  total_vals = np.sum(row_lengths)
  id_values = np.random.randint(vocab_size, size=total_vals)
  weight_values = 1 + np.random.rand(total_vals)
  # One [row, column] coordinate per generated value, in row-major order.
  coords = [
      [row, col]
      for row, row_len in enumerate(row_lengths)
      for col in range(row_len)
  ]
  dense_shape = [batch_size, max_val_per_entry]

  def make_sparse(values, values_dtype):
    # Fresh index/shape constants are created for each sparse tensor.
    return sparse_tensor.SparseTensor(
        constant_op.constant(coords, dtypes.int64),
        constant_op.constant(values, values_dtype),
        constant_op.constant(dense_shape, dtypes.int64),
    )

  sp_ids = make_sparse(id_values, dtypes.int32)
  sp_weights = make_sparse(weight_values, dtypes.float32)
  return sp_ids, sp_weights
def _benchmark_embedding_lookup_sparse_with_gradient(
self, params, batch_size, max_val_per_entry, device
):
"""Benchmarks `embedding_lookup_sparse` together with its gradient.

Args:
params: Embedding parameters; the gradient is taken with respect to them.
batch_size: Forwarded to `_RandomIdsAndWeights`.
max_val_per_entry: Forwarded to `_RandomIdsAndWeights`.
device: Device handed to `context.device` while the inputs are created.
"""
def func(sp_ids):
# Unweighted lookup (the weights argument is None) recorded on a tape, then
# differentiated with respect to `params`.
# NOTE(review): `compat.forward_compatible(y, m, d)` is documented as true
# only when the horizon is strictly later than the given date; confirm that
# a horizon of (2023, 9, 26) really enables code gated on that same date.
with forward_compat.forward_compatibility_horizon(2023, 9, 26):
with gradients.GradientTape() as g:
y = embedding_ops.embedding_lookup_sparse(params, sp_ids, None)
params_grad = g.gradient(y, params)
return params_grad
# The first dimension of `params` is used as the exclusive bound for the ids.
vocab_size = params.get_shape()[0]
with context.device(device):
sp_ids, _ = self._RandomIdsAndWeights(
batch_size, vocab_size, max_val_per_entry
)
# One call before the timed loop -- presumably a warm-up; TODO confirm.
func(sp_ids)
self._run(lambda: func(sp_ids), num_iters=2000)
def benchmark_embedding_lookup_sparse_with_gradient(self):
  """Benchmarks sparse embedding lookup plus gradient on `GPU`.

  Uses a `ResourceVariable` table of shape (1024 * 1024, 16) initialised from
  uniform random values that are first copied with `.gpu()`.
  """
  initial_table = random_ops.random_uniform((1024 * 1024, 16)).gpu()
  table = resource_variable_ops.ResourceVariable(initial_table)
  self._benchmark_embedding_lookup_sparse_with_gradient(
      table, batch_size=32768, max_val_per_entry=64, device=GPU
  )
if __name__ == "__main__":
test.main()
......@@ -463,9 +463,11 @@ cuda_py_strict_test(
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:errors",
"//tensorflow/python/framework:for_generated_wrappers",
"//tensorflow/python/framework:indexed_slices",
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/ops:gradient_checker",
"//tensorflow/python/ops:gradient_checker_v2",
"//tensorflow/python/ops:gradients",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:variables",
"//tensorflow/python/platform:client_testlib",
......
......@@ -23,10 +23,12 @@ from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
......@@ -1110,6 +1112,32 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
self.assertAllClose(tf_xgrad, np_xgrad)
self.assertAllClose(tf_unique_indices, np_unique_indices)
@test_util.run_deprecated_v1
def testSparseGradient(self):
"""Checks that `sparse_gradient=True` yields an `IndexedSlices` gradient.

For each tested op, the gradient obtained from a `GradientTape` must be an
`IndexedSlices`, and the theoretical and numerical Jacobians returned by
`gradient_checker.compute_gradient` must agree.
"""
shape = [10, 4]
# Four gathered rows reduced into three segments (ids 0, 1 and 2).
segment_indices = [0, 1, 2, 2]
num_indices = len(segment_indices)
# NOTE(review): `math_ops.sparse_segment_sqrt_n` is not part of this list --
# confirm whether that is intentional.
for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
with self.cached_session():
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtypes_lib.float64
)
with gradients.GradientTape() as g:
# `tf_x` is watched explicitly so the tape records operations on it.
g.watch(tf_x)
s = tf_op(
data=tf_x,
indices=tf_indices,
segment_ids=segment_indices,
sparse_gradient=True,
)
tf_dx = g.gradient(s, tf_x)
self.assertIsInstance(tf_dx, indexed_slices.IndexedSlices)
# [3, 4] is the output shape: three segments by the four columns of `shape`.
# presumably delta=1 is acceptable because these reductions are linear in
# `data` -- TODO confirm.
jacob_t, jacob_n = gradient_checker.compute_gradient(
tf_x, shape, s, [3, 4], x_init_value=np_x.astype(np.double), delta=1
)
self.assertAllClose(jacob_t, jacob_n)
def testGradientValid(self):
# Baseline for the testGradient*Invalid* methods below.
tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
......
......@@ -476,8 +476,11 @@ cuda_py_strict_test(
"no_cuda_asan", # Size limit: b/192505612
],
deps = [
"//tensorflow/python/compat",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:for_generated_wrappers",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:sparse_tensor",
"//tensorflow/python/framework:test_lib",
"//tensorflow/python/ops:array_ops",
......@@ -485,10 +488,13 @@ cuda_py_strict_test(
"//tensorflow/python/ops:data_flow_ops",
"//tensorflow/python/ops:embedding_ops",
"//tensorflow/python/ops:gradient_checker",
"//tensorflow/python/ops:gradients",
"//tensorflow/python/ops:init_ops",
"//tensorflow/python/ops:linalg_ops",
"//tensorflow/python/ops:math_ops",
"//tensorflow/python/ops:partitioned_variables",
"//tensorflow/python/ops:resource_variable_ops",
"//tensorflow/python/ops:sort_ops",
"//tensorflow/python/ops:state_ops",
"//tensorflow/python/ops:variable_scope",
"//tensorflow/python/ops:variables",
......
......@@ -16,10 +16,11 @@
import itertools
import math
from absl.testing import parameterized
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -30,10 +31,13 @@ from tensorflow.python.ops import array_ops_stack
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
......@@ -652,7 +656,7 @@ class EmbeddingLookupTest(test.TestCase):
# tensorflow/python/ops/embedding_ops_test.py
class EmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):
def _RandomIdsAndWeights(self, batch_size, vocab_size, ragged):
def _RandomIdsAndWeights(self, batch_size, vocab_size, ragged=False):
max_val_per_entry = 6
vals_per_batch_entry = np.random.randint(
1, max_val_per_entry, size=batch_size)
......@@ -909,6 +913,104 @@ class EmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):
sp_weights,
)
def _SortByKey(self, keys, vals):
  """Reorders `keys` and `vals` by the argsort permutation of `keys`.

  Args:
    keys: Values to sort by.
    vals: Values gathered with the same permutation as `keys`.

  Returns:
    A `(sorted_keys, reordered_vals)` tuple.
  """
  order = sort_ops.argsort(keys)
  sorted_keys = array_ops.gather(keys, order)
  reordered_vals = array_ops.gather(vals, order)
  return sorted_keys, reordered_vals
def _ExpectedSparseGradient(
self, nnz, param_shape, np_type, sp_ids, sp_weights, combiner
):
"""Returns the expected indices and values of a sparse lookup gradient.

Starts from one row of ones per looked-up id, applies the combiner scaling
and the optional weights, then sorts the rows by id and sums the rows that
share an id.

Args:
nnz: Number of rows in the expected values before deduplication.
param_shape: List of trailing dimensions appended to `[nnz]`.
np_type: Numpy dtype of the initial block of ones.
sp_ids: Sparse ids; `indices[:, 0]` is used as the segment ids and `values`
as the gradient row indices.
sp_weights: Sparse weights whose `values` scale each row, or `None` to use
unit weights.
combiner: `"sum"` applies no per-segment scaling; `"sqrtn"` scales by the
inverse square root of the per-segment sum of squared weights; any other
value scales by the inverse of the per-segment sum of weights.

Returns:
A `(expected_indices, expected_values)` tuple with deduplicated indices.
"""
expected_values = np.ones([nnz] + param_shape, dtype=np_type)
# The first index column is used as the segment id of each value.
segment_ids = sp_ids.indices[:, 0]
ignore_weights = sp_weights is None
weights = (
array_ops.ones(nnz, dtype=dtypes.float32)
if ignore_weights
else sp_weights.values
)
# For "sqrtn" the normaliser is built from squared weights.
if combiner == "sqrtn":
weights = weights**2
segment_weights = math_ops.segment_sum(weights, segment_ids)
if combiner != "sum":
# Broadcast each segment's total back to the values belonging to it.
grad_scale = 1.0 / array_ops.gather(segment_weights, segment_ids)
if combiner == "sqrtn":
grad_scale = math_ops.sqrt(grad_scale)
expected_values *= grad_scale[:, None]
if not ignore_weights:
expected_values *= sp_weights.values[:, None]
expected_indices = sp_ids.values
# Sort and deduplicate the indices in the expected sparse tensor.
expected_indices, expected_values = self._SortByKey(
expected_indices, expected_values
)
# Rows that share an id are summed into a single row.
expected_indices, unique_mapping = array_ops.unique(expected_indices)
expected_values = math_ops.segment_sum(expected_values, unique_mapping)
return expected_indices, expected_values
def testResourceVariableGradientEmbeddingLookupSparse(self):
"""Checks the sparse lookup gradient for a `ResourceVariable` input.

For every combination of combiner, dtype and weighted/unweighted lookup, the
gradient with respect to the variable is compared (indices exactly, values
approximately) against `_ExpectedSparseGradient`.
"""
vocab_size = 128
batch_size = 32
param_shape = [16]
# Only the first two of the five returned values are needed here.
sp_ids, sp_weights, _, _, _ = self._RandomIdsAndWeights(
batch_size, vocab_size
)
for combiner, dtype, ignore_weights in itertools.product(
["sum", "mean", "sqrtn"],
[dtypes.float32, dtypes.float64],
[True, False],
):
# NOTE(review): `compat.forward_compatible(y, m, d)` is documented as true
# only when the horizon is strictly later than the given date; confirm that
# a horizon of (2023, 9, 26) really enables code gated on that same date.
with self.test_session(), forward_compat.forward_compatibility_horizon(
2023, 9, 26
):
x_shape = [vocab_size] + param_shape
# Numpy type code matching `dtype`: "f" for float32, "d" otherwise.
np_type = "f" if dtype == dtypes.float32 else "d"
# Uniform samples shifted by one, so parameter values lie in [1, 2).
x = np.random.uniform(size=x_shape).astype(np_type) + 1
x = resource_variable_ops.ResourceVariable(x)
self.evaluate(variables.global_variables_initializer())
def forward(x_):
y_ = embedding_ops.embedding_lookup_sparse(
x_,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner,
)
return y_
with gradients.GradientTape() as g:
y = forward(x)
dx = g.gradient(y, x)
# The gradient is read through `dense_shape`, `indices` and `values`.
self.assertAllEqual(dx.dense_shape, x_shape)
actual_indices, actual_values = dx.indices, dx.values
# The sort order of the output is not guaranteed, so we must sort it
# into a consistent order before comparing.
actual_indices, actual_values = self._SortByKey(
actual_indices, actual_values
)
nnz = sp_ids.values.get_shape()[0]
expected_indices, expected_values = self._ExpectedSparseGradient(
nnz,
param_shape,
np_type,
sp_ids,
None if ignore_weights else sp_weights,
combiner,
)
self.assertAllEqual(actual_indices, expected_indices)
self.assertAllClose(actual_values, expected_values)
class SafeEmbeddingLookupSparseTest(test.TestCase, parameterized.TestCase):
......
......@@ -1566,12 +1566,15 @@ py_strict_library(
":resource_variable_ops",
":sparse_ops",
":variables",
"//tensorflow/python/compat",
"//tensorflow/python/framework:composite_tensor",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:indexed_slices",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:sparse_tensor",
"//tensorflow/python/framework:tensor_shape",
"//tensorflow/python/types:core",
"//tensorflow/python/util:dispatch",
"//tensorflow/python/util:tf_export",
],
......@@ -2010,6 +2013,7 @@ py_strict_library(
"//tensorflow/python/eager:context",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:indexed_slices",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:tensor",
"//tensorflow/python/framework:tensor_util",
......
......@@ -14,6 +14,8 @@
# ==============================================================================
"""Operations for embeddings."""
from tensorflow.python.compat import compat
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
......@@ -30,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.types import core
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
......@@ -1007,9 +1010,33 @@ def embedding_lookup_sparse_impl(
name,
):
"""Implementation of sparse embedding aggregation."""
if len(params) == 1 and max_norm is None and allow_fast_lookup:
need_sparse_segment_gradient = False
# Ensure we can query the devices below.
segment_ids = ops.convert_to_tensor(segment_ids, name="segment_ids")
if len(params) == 1 and not isinstance(
params[0], (core.Tensor, composite_tensor.CompositeTensor)
):
params = [ops.convert_to_tensor(params[0], name="params")]
# Note that if the params are on a different device (e.g., CPU), we must use
# embedding_lookup() so that the gather operation is colocated with them.
if (
len(params) == 1
and not isinstance(params[0], composite_tensor.CompositeTensor)
and params[0].device == segment_ids.device
and max_norm is None
and (
allow_fast_lookup
or (ignore_weights and compat.forward_compatible(2023, 9, 26))
)
):
idx = ids
embeddings = params[0]
if isinstance(embeddings, resource_variable_ops.BaseResourceVariable):
# Avoid a redundant copy due to copy-on-read semantics for
# sparsely-updated variables.
embeddings = embeddings.read_value_no_copy()
if not allow_fast_lookup:
need_sparse_segment_gradient = True
else:
ids, idx = array_ops.unique(ids)
embeddings = embedding_lookup(
......@@ -1072,15 +1099,27 @@ def embedding_lookup_sparse_impl(
assert idx is not None
if combiner == "sum":
embeddings = math_ops.sparse_segment_sum(
embeddings, idx, segment_ids, name=name
embeddings,
idx,
segment_ids,
name=name,
sparse_gradient=need_sparse_segment_gradient,
)
elif combiner == "mean":
embeddings = math_ops.sparse_segment_mean(
embeddings, idx, segment_ids, name=name
embeddings,
idx,
segment_ids,
name=name,
sparse_gradient=need_sparse_segment_gradient,
)
elif combiner == "sqrtn":
embeddings = math_ops.sparse_segment_sqrt_n(
embeddings, idx, segment_ids, name=name
embeddings,
idx,
segment_ids,
name=name,
sparse_gradient=need_sparse_segment_gradient,
)
else:
assert False, "Unrecognized combiner"
......
......@@ -19,6 +19,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices as indexed_slices_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import tensor_util
......@@ -332,9 +333,45 @@ def _SegmentMeanGrad(op: ops.Operation, grad):
return array_ops.gather(scaled_grad, op.inputs[1]), None
def _SparseSegmentReduceGradV2(op, grad, norm=None):
  """Sparse gradient for SparseSegment(Sum|Mean|SqrtN)[WithNumSegments].

  Args:
    op: The forward op; inputs 0, 1 and 2 are read as `data`, `indices` and
      `segment_ids`.
    grad: Incoming gradient for the output of `op`.
    norm: `None` for the sum variant, or `"mean"` / `"sqrtn"`.

  Returns:
    An `IndexedSlices` built from the two outputs of the selected `*_grad_v2`
    op (values and indices), with the dynamic shape of `data` as dense shape.
  """
  assert norm is None or norm == "mean" or norm == "sqrtn"
  data = op.inputs[0]
  indices = op.inputs[1]
  segment_ids = op.inputs[2]
  # Reuse `data` rather than indexing `op.inputs[0]` a second time.
  data_shape = array_ops.shape(data)
  dense_output_dim0 = data_shape[0]
  if norm == "mean":
    grad_fn = math_ops.sparse_segment_mean_grad_v2
  elif norm == "sqrtn":
    grad_fn = math_ops.sparse_segment_sqrt_n_grad_v2
  else:
    grad_fn = math_ops.sparse_segment_sum_grad_v2
  grad_values, sorted_unique_indices = grad_fn(
      grad, indices, segment_ids, dense_output_dim0
  )
  return indexed_slices_lib.IndexedSlices(
      grad_values, sorted_unique_indices, data_shape
  )
def _GetOpAttrOrNone(op, name):
"""Returns the value of the attr of `op` with the given `name`, or None if no
such attr exists.
"""
try:
return op.get_attr(name)
except ValueError:
return None
@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op: ops.Operation, grad):
"""Gradient for SparseSegmentSum."""
if _GetOpAttrOrNone(op, "sparse_gradient"):
return _SparseSegmentReduceGradV2(op, grad), None, None
dim0 = array_ops.shape(op.inputs[0])[0]
if compat.forward_compatible(2021, 6, 10):
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
......@@ -347,6 +384,8 @@ def _SparseSegmentSumGrad(op: ops.Operation, grad):
@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
def _SparseSegmentSumWithNumSegmentsGrad(op: ops.Operation, grad):
"""Gradient for SparseSegmentSumWithNumSegments."""
if _GetOpAttrOrNone(op, "sparse_gradient"):
return _SparseSegmentReduceGradV2(op, grad), None, None, None
dim0 = array_ops.shape(op.inputs[0])[0]
if compat.forward_compatible(2021, 6, 10):
return (math_ops.sparse_segment_sum_grad(grad, op.inputs[1], op.inputs[2],
......@@ -360,6 +399,8 @@ def _SparseSegmentSumWithNumSegmentsGrad(op: ops.Operation, grad):
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op: ops.Operation, grad):
  """Gradient for SparseSegmentMean.

  Returns a 3-tuple: the gradient for the first input followed by `None` for
  the two remaining inputs. The first entry comes from
  `_SparseSegmentReduceGradV2` when the op's `sparse_gradient` attr is truthy
  and from `sparse_segment_mean_grad` otherwise.
  """
  if _GetOpAttrOrNone(op, "sparse_gradient"):
    sparse_grad = _SparseSegmentReduceGradV2(op, grad, "mean")
    return sparse_grad, None, None
  num_rows = array_ops.shape(op.inputs[0])[0]
  indices, segment_ids = op.inputs[1], op.inputs[2]
  dense_grad = math_ops.sparse_segment_mean_grad(
      grad, indices, segment_ids, num_rows
  )
  return dense_grad, None, None
......@@ -368,6 +409,8 @@ def _SparseSegmentMeanGrad(op: ops.Operation, grad):
@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
def _SparseSegmentMeanWithNumSegmentsGrad(op: ops.Operation, grad):
  """Gradient for SparseSegmentMeanWithNumSegments.

  Returns a 4-tuple: the gradient for the first input followed by `None` for
  the three remaining inputs. The first entry comes from
  `_SparseSegmentReduceGradV2` when the op's `sparse_gradient` attr is truthy
  and from `sparse_segment_mean_grad` otherwise.
  """
  if _GetOpAttrOrNone(op, "sparse_gradient"):
    sparse_grad = _SparseSegmentReduceGradV2(op, grad, "mean")
    return sparse_grad, None, None, None
  num_rows = array_ops.shape(op.inputs[0])[0]
  indices, segment_ids = op.inputs[1], op.inputs[2]
  dense_grad = math_ops.sparse_segment_mean_grad(
      grad, indices, segment_ids, num_rows
  )
  return dense_grad, None, None, None
......@@ -376,6 +419,8 @@ def _SparseSegmentMeanWithNumSegmentsGrad(op: ops.Operation, grad):
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op: ops.Operation, grad):
  """Gradient for SparseSegmentSqrtN.

  Returns a 3-tuple: the gradient for the first input followed by `None` for
  the two remaining inputs. The first entry comes from
  `_SparseSegmentReduceGradV2` when the op's `sparse_gradient` attr is truthy
  and from `sparse_segment_sqrt_n_grad` otherwise.
  """
  if _GetOpAttrOrNone(op, "sparse_gradient"):
    sparse_grad = _SparseSegmentReduceGradV2(op, grad, "sqrtn")
    return sparse_grad, None, None
  num_rows = array_ops.shape(op.inputs[0])[0]
  indices, segment_ids = op.inputs[1], op.inputs[2]
  dense_grad = math_ops.sparse_segment_sqrt_n_grad(
      grad, indices, segment_ids, num_rows
  )
  return dense_grad, None, None
......@@ -384,6 +429,8 @@ def _SparseSegmentSqrtNGrad(op: ops.Operation, grad):
@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
def _SparseSegmentSqrtNWithNumSegmentsGrad(op: ops.Operation, grad):
  """Gradient for SparseSegmentSqrtNWithNumSegments.

  Returns a 4-tuple: the gradient for the first input followed by `None` for
  the three remaining inputs. The first entry comes from
  `_SparseSegmentReduceGradV2` when the op's `sparse_gradient` attr is truthy
  and from `sparse_segment_sqrt_n_grad` otherwise.
  """
  if _GetOpAttrOrNone(op, "sparse_gradient"):
    sparse_grad = _SparseSegmentReduceGradV2(op, grad, "sqrtn")
    return sparse_grad, None, None, None
  num_rows = array_ops.shape(op.inputs[0])[0]
  indices, segment_ids = op.inputs[1], op.inputs[2]
  dense_grad = math_ops.sparse_segment_sqrt_n_grad(
      grad, indices, segment_ids, num_rows
  )
  return dense_grad, None, None, None
......
......@@ -4769,11 +4769,14 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
@tf_export(v1=["sparse.segment_sum", "sparse_segment_sum"])
@deprecation.deprecated_endpoints("sparse_segment_sum")
def sparse_segment_sum(data,
indices,
segment_ids,
name=None,
num_segments=None):
def sparse_segment_sum(
data,
indices,
segment_ids,
name=None,
num_segments=None,
sparse_gradient=False,
):
r"""Computes the sum along sparse segments of a tensor.
Read [the section on
......@@ -4826,6 +4829,10 @@ def sparse_segment_sum(data,
name: A name for the operation (optional).
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the shape as data, except for dimension 0 which
......@@ -4838,18 +4845,28 @@ def sparse_segment_sum(data,
indices=indices,
segment_ids=segment_ids,
num_segments=num_segments,
name=name)
sparse_gradient=sparse_gradient,
name=name,
)
else:
return gen_math_ops.sparse_segment_sum(
data=data, indices=indices, segment_ids=segment_ids, name=name)
data=data,
indices=indices,
segment_ids=segment_ids,
sparse_gradient=sparse_gradient,
name=name,
)
@tf_export("sparse.segment_sum", v1=[])
def sparse_segment_sum_v2(data,
indices,
segment_ids,
num_segments=None,
name=None):
def sparse_segment_sum_v2(
data,
indices,
segment_ids,
num_segments=None,
name=None,
sparse_gradient=False,
):
r"""Computes the sum along sparse segments of a tensor.
Read [the section on
......@@ -4902,6 +4919,10 @@ def sparse_segment_sum_v2(data,
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
name: A name for the operation (optional).
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the shape as data, except for dimension 0 which
......@@ -4909,16 +4930,25 @@ def sparse_segment_sum_v2(data,
inferred for the last element in `segments_ids`.
"""
return sparse_segment_sum(
data, indices, segment_ids, name=name, num_segments=num_segments)
data,
indices,
segment_ids,
name=name,
num_segments=num_segments,
sparse_gradient=sparse_gradient,
)
@tf_export(v1=["sparse.segment_mean", "sparse_segment_mean"])
@deprecation.deprecated_endpoints("sparse_segment_mean")
def sparse_segment_mean(data,
indices,
segment_ids,
name=None,
num_segments=None):
def sparse_segment_mean(
data,
indices,
segment_ids,
name=None,
num_segments=None,
sparse_gradient=False,
):
r"""Computes the mean along sparse segments of a tensor.
Read [the section on
......@@ -4941,6 +4971,10 @@ def sparse_segment_mean(data,
name: A name for the operation (optional).
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the shape as data, except for dimension 0 which
......@@ -4953,18 +4987,28 @@ def sparse_segment_mean(data,
indices=indices,
segment_ids=segment_ids,
num_segments=num_segments,
name=name)
name=name,
sparse_gradient=sparse_gradient,
)
else:
return gen_math_ops.sparse_segment_mean(
data=data, indices=indices, segment_ids=segment_ids, name=name)
data=data,
indices=indices,
segment_ids=segment_ids,
name=name,
sparse_gradient=sparse_gradient,
)
@tf_export("sparse.segment_mean", v1=[])
def sparse_segment_mean_v2(data,
indices,
segment_ids,
num_segments=None,
name=None):
def sparse_segment_mean_v2(
data,
indices,
segment_ids,
num_segments=None,
name=None,
sparse_gradient=False,
):
r"""Computes the mean along sparse segments of a tensor.
Read [the section on
......@@ -4987,6 +5031,10 @@ def sparse_segment_mean_v2(data,
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
name: A name for the operation (optional).
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the shape as data, except for dimension 0 which
......@@ -4994,16 +5042,25 @@ def sparse_segment_mean_v2(data,
inferred for the last element in `segments_ids`.
"""
return sparse_segment_mean(
data, indices, segment_ids, name=name, num_segments=num_segments)
data,
indices,
segment_ids,
name=name,
num_segments=num_segments,
sparse_gradient=sparse_gradient,
)
@tf_export(v1=["sparse.segment_sqrt_n", "sparse_segment_sqrt_n"])
@deprecation.deprecated_endpoints("sparse_segment_sqrt_n")
def sparse_segment_sqrt_n(data,
indices,
segment_ids,
name=None,
num_segments=None):
def sparse_segment_sqrt_n(
data,
indices,
segment_ids,
name=None,
num_segments=None,
sparse_gradient=False,
):
r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
`N` is the size of the segment being reduced.
......@@ -5017,6 +5074,9 @@ def sparse_segment_sqrt_n(data,
name: A name for the operation (optional).
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the shape as data, except for dimension 0 which
......@@ -5029,18 +5089,28 @@ def sparse_segment_sqrt_n(data,
indices=indices,
segment_ids=segment_ids,
num_segments=num_segments,
name=name)
name=name,
sparse_gradient=sparse_gradient,
)
else:
return gen_math_ops.sparse_segment_sqrt_n(
data=data, indices=indices, segment_ids=segment_ids, name=name)
data=data,
indices=indices,
segment_ids=segment_ids,
name=name,
sparse_gradient=sparse_gradient,
)
@tf_export("sparse.segment_sqrt_n", v1=[])
def sparse_segment_sqrt_n_v2(data,
indices,
segment_ids,
num_segments=None,
name=None):
def sparse_segment_sqrt_n_v2(
data,
indices,
segment_ids,
num_segments=None,
name=None,
sparse_gradient=False,
):
r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
Read [the section on
......@@ -5059,6 +5129,10 @@ def sparse_segment_sqrt_n_v2(data,
num_segments: An optional int32 scalar. Indicates the size of the output
`Tensor`.
name: A name for the operation (optional).
sparse_gradient: An optional `bool`. Defaults to `False`. If `True`, the
gradient of this function will be sparse (`IndexedSlices`) instead of
dense (`Tensor`). The sparse gradient will contain one non-zero row for
each unique index in `indices`.
Returns:
A `tensor` of the same shape as `data`, except for dimension 0 which
......@@ -5066,7 +5140,13 @@ def sparse_segment_sqrt_n_v2(data,
inferred from the last element in `segment_ids`.
"""
return sparse_segment_sqrt_n(
data, indices, segment_ids, name=name, num_segments=num_segments)
data,
indices,
segment_ids,
name=name,
num_segments=num_segments,
sparse_gradient=sparse_gradient,
)
@tf_export("tensordot", "linalg.tensordot")
......
......@@ -2262,15 +2262,15 @@ tf_module {
}
member_method {
name: "sparse_segment_mean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "sparse_segment_sqrt_n"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "sparse_segment_sum"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "sparse_slice"
......
......@@ -4542,7 +4542,7 @@ tf_module {
}
member_method {
name: "SparseSegmentMean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentMeanGrad"
......@@ -4554,11 +4554,11 @@ tf_module {
}
member_method {
name: "SparseSegmentMeanWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSqrtN"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSqrtNGrad"
......@@ -4570,11 +4570,11 @@ tf_module {
}
member_method {
name: "SparseSegmentSqrtNWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSum"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSumGrad"
......@@ -4586,7 +4586,7 @@ tf_module {
}
member_method {
name: "SparseSegmentSumWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSlice"
......
......@@ -102,15 +102,15 @@ tf_module {
}
member_method {
name: "segment_mean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "segment_sqrt_n"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "segment_sum"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "slice"
......
......@@ -4542,7 +4542,7 @@ tf_module {
}
member_method {
name: "SparseSegmentMean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentMeanGrad"
......@@ -4554,11 +4554,11 @@ tf_module {
}
member_method {
name: "SparseSegmentMeanWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSqrtN"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSqrtNGrad"
......@@ -4570,11 +4570,11 @@ tf_module {
}
member_method {
name: "SparseSegmentSqrtNWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSum"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSegmentSumGrad"
......@@ -4586,7 +4586,7 @@ tf_module {
}
member_method {
name: "SparseSegmentSumWithNumSegments"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'sparse_gradient\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
name: "SparseSlice"
......
......@@ -82,15 +82,15 @@ tf_module {
}
member_method {
name: "segment_mean"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "segment_sqrt_n"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "segment_sum"
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'num_segments\', \'name\', \'sparse_gradient\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "slice"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册