提交 fd8e10fb 编写于 作者: M Mihai Maruseac

Fix heap buffer overflow in `tf.raw_ops.SparseFillEmptyRowsGrad`.

Also add tests as they were lacking

PiperOrigin-RevId: 332566071
Change-Id: I44277578e26ff5fb3fdb0dcbba6e91b2ec3e7859
上级 ce7a81fd
......@@ -213,6 +213,9 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
reverse_index_map_t->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
errors::InvalidArgument("grad_values must be a vector, saw: ",
grad_values_t->shape().DebugString()));
const auto reverse_index_map = reverse_index_map_t->vec<int64>();
const auto grad_values = grad_values_t->vec<T>();
......@@ -241,8 +244,13 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
// Locate the index of the output of the forward prop associated
// with this location in the input of the forward prop. Copy
// the gradient into it. Mark it as visited.
d_values(i) = grad_values(reverse_index_map(i));
visited(reverse_index_map(i)) = true;
int64 reverse_index = reverse_index_map(i);
OP_REQUIRES(
context, 0 <= reverse_index && reverse_index < N_full,
errors::InvalidArgument("Elements in reverse index must be in [0, ",
N_full, ") but got ", reverse_index));
d_values(i) = grad_values(reverse_index);
visited(reverse_index) = true;
}
for (int j = 0; j < N_full; ++j) {
// The default value gradient gets the accumulated remainder of
......
......@@ -21,12 +21,14 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
# Need array_grad to register gradient for Identity.
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
from tensorflow.python.ops import math_ops
# Need sparse_grad to register gradient for SparseToDense.
......@@ -134,5 +136,57 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(expected_dense, result_dense)
@test_util.run_all_in_graph_and_eager_modes
class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testSparseFillEmptyRowsGrad(self):
reverse_index_map = [2, 1]
grad_values = [0, 1, 2, 3]
d_values, d_default_value = self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
self.assertAllEqual([2, 1], d_values)
self.assertEqual(3, d_default_value)
def testSparseFillEmptyRowsGradNegativeIndexMapValue(self):
reverse_index_map = [2, -1]
grad_values = [0, 1, 2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r'Elements in reverse index must be in \[0, 4\)'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
def testSparseFillEmptyRowsGradLargeIndexMapValue(self):
reverse_index_map = [2, 10]
grad_values = [0, 1, 2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r'Elements in reverse index must be in \[0, 4\)'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
def testSparseFillEmptyRowsGradMatrix(self):
reverse_index_map = [0, 1]
grad_values = [[0, 1], [2, 3]]
# Note: Eager mode and graph mode throw different errors here. Graph mode
# will fail with a ValueError from the shape checking logic, while Eager
# will fail with an InvalidArgumentError from the kernel itself.
if context.executing_eagerly():
with self.assertRaisesRegex(errors.InvalidArgumentError,
r'grad_values must be a vector'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
else:
with self.assertRaisesRegex(ValueError,
r'Shape must be rank 1 but is rank 2'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
if __name__ == '__main__':
googletest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册