未验证 提交 fe61e819 编写于 作者: T TensorFlow Jenkins 提交者: GitHub

Fix tf.raw_ops.TensorListResize vulnerability with non-scalar input. (#57867)

Check that the size input is valid.
Add graph/eager unit tests. Graph mode was already ok but eager mode was not.

Note: This fix will have to be cherry picked in r2.10, r2.9, and r2.8.
PiperOrigin-RevId: 477002316
Co-authored-by: NAlan Liu <liualan@google.com>
上级 306e17fe
......@@ -375,6 +375,8 @@ class TensorListResize : public OpKernel {
void Compute(OpKernelContext* c) override {
const TensorList* input_list = nullptr;
OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
OP_REQUIRES(c, TensorShapeUtils::IsScalar(c->input(1).shape()),
errors::InvalidArgument("size must be a scalar"));
int32_t size = c->input(1).scalar<int32>()();
OP_REQUIRES(
c, size >= 0,
......
......@@ -1658,6 +1658,15 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
l = list_ops.tensor_list_resize(l, -1)
self.evaluate(l)
@test_util.run_in_graph_and_eager_modes
def testResizeWithNonScalarFails(self):
l = list_ops.tensor_list_from_tensor([3, 4, 5], element_shape=[])
size = np.zeros([0, 2, 3, 3])
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r"Shape must be rank 0 but is rank \d+|"
r"\w+ must be a scalar"):
self.evaluate(gen_list_ops.TensorListResize(input_handle=l, size=size))
@test_util.run_deprecated_v1
@test_util.enable_control_flow_v2
def testSkipEagerResizeGrad(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册