From 850d5c0f4df804362a778432fcbb05cd84e2032f Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 1 Sep 2022 11:33:04 -0700 Subject: [PATCH] Fix conflict resolution error Signed-off-by: Mihai Maruseac --- .../quantization_ops/quantization_ops_test.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py b/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py index 724f6fdc509..7f23b69c2a1 100644 --- a/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py +++ b/tensorflow/python/kernel_tests/quantization_ops/quantization_ops_test.py @@ -428,6 +428,14 @@ class QuantizeAndDequantizeV3OpTest(test_util.TensorFlowTestCase): input=inputs, input_min=[], input_max=4.0, out_type=dtypes.quint8)) + input_value = constant_op.constant([-0.8, -0.5, 0, 0.3, 0.8, -2.0], + shape=(6,), + dtype=dtypes.float32), + input_min = constant_op.constant(-127, shape=(), dtype=dtypes.float32) + input_max = constant_op.constant(127, shape=(), dtype=dtypes.float32) + # Tensor with invalid shape and invalid number of elements. + num_bits = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + # Test that running the op raises error. It raises different errors # depending on whether the shape inference is run first or the op's # Compute() is run first. @@ -454,13 +462,16 @@ class QuantizeDownAndShrinkRangeOpTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes def test_invalid_inputs(self): - input_value = constant_op.constant([-0.8, -0.5, 0, 0.3, 0.8, -2.0], - shape=(6,), - dtype=dtypes.float32), - input_min = constant_op.constant(-127, shape=(), dtype=dtypes.float32) - input_max = constant_op.constant(127, shape=(), dtype=dtypes.float32) - # Tensor with invalid shape and invalid number of elements. - num_bits = constant_op.constant([], shape=(0,), dtype=dtypes.int32) + inputs = constant_op.constant( + np.int32(0), shape=[3, 3, 3, 3], dtype=dtypes.qint32) + + with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), + "must be rank 0"): + self.evaluate( + math_ops.quantize_down_and_shrink_range(input=inputs, + input_min=[], + input_max=4.0, + out_type=dtypes.quint8)) if __name__ == "__main__": -- GitLab