From f0d6424da5f77d3e23491593027330b53517d591 Mon Sep 17 00:00:00 2001
From: Jian Li
Date: Wed, 17 Jul 2019 07:18:23 -0700
Subject: [PATCH] Add int16 support to Quant.

PiperOrigin-RevId: 258563058
---
 tensorflow/lite/kernels/quantize.cc                 | 13 ++++++++++---
 tensorflow/lite/kernels/quantize_test.cc            | 11 +++++++++++
 tensorflow/lite/kernels/register.cc                 |  4 +++-
 tensorflow/lite/tools/optimize/operator_property.cc |  2 +-
 .../lite/tools/optimize/quantize_model_test.cc      |  2 +-
 5 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/tensorflow/lite/kernels/quantize.cc b/tensorflow/lite/kernels/quantize.cc
index 066bc72618b..35a0bb54f35 100644
--- a/tensorflow/lite/kernels/quantize.cc
+++ b/tensorflow/lite/kernels/quantize.cc
@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   OpContext op_context(context, node);
 
   TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);
 
   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
   // Currently this only support affine per-layer quantization.
@@ -69,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
   // For requantize use case.
   const bool is_requantize = (op_context.input->type == kTfLiteUInt8 ||
-                              op_context.input->type == kTfLiteInt8) &&
+                              op_context.input->type == kTfLiteInt8 ||
+                              op_context.input->type == kTfLiteInt16) &&
                              (op_context.output->type == kTfLiteUInt8 ||
-                              op_context.output->type == kTfLiteInt8);
+                              op_context.output->type == kTfLiteInt8 ||
+                              op_context.output->type == kTfLiteInt16);
   if (is_requantize) {
     const double effective_output_scale =
         static_cast<double>(op_context.input->params.scale) /
@@ -104,6 +107,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     optimized_ops::AffineQuantize(
         op_params, GetTensorShape(input), GetTensorData<float>(input),
         GetTensorShape(output), GetTensorData<int8_t>(output));
+  } else if (output->type == kTfLiteInt16) {
+    optimized_ops::AffineQuantize(
+        op_params, GetTensorShape(input), GetTensorData<float>(input),
+        GetTensorShape(output), GetTensorData<int16_t>(output));
   } else {
     context->ReportError(
         context,
diff --git a/tensorflow/lite/kernels/quantize_test.cc b/tensorflow/lite/kernels/quantize_test.cc
index b381c041859..e720f74728e 100644
--- a/tensorflow/lite/kernels/quantize_test.cc
+++ b/tensorflow/lite/kernels/quantize_test.cc
@@ -79,6 +79,17 @@ TEST(QuantizeOpTest, INT8) {
                   {-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
 }
 
+TEST(QuantizeOpTest, INT16) {
+  QuantizeOpModel m({TensorType_FLOAT32, {2, 5}},
+                    {TensorType_INT16, {2, 5}, 0, 0, 0.005, 0});
+
+  m.SetInput({-63.5, -63, -3, -2, -1, 1, 2, 3, 63.5, 64});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutput<int16_t>(),
+              ElementsAreArray({-12700, -12600, -600, -400, -200, 200, 400,
+                                600, 12700, 12800}));
+}
+
 // Input scale 0.500000, output scale 0.500000, input zeropoint -1, output
 // zeropoint -1
 TEST(QuantizeOpTest, Int8Int8SameScale) {
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 700b321c7a9..f17c445a063 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -376,7 +376,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_ELU, Register_ELU());
   AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
   AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
-  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
+  AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
+             /* min_version */ 1,
+             /* max_version */ 2);
   AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc
index f5b2736415e..ad974e7aca3 100644
--- a/tensorflow/lite/tools/optimize/operator_property.cc
+++ b/tensorflow/lite/tools/optimize/operator_property.cc
@@ -169,7 +169,7 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
     case BuiltinOperator_QUANTIZE:
       property.inputs = {{0, {}}};
       property.outputs = {{0, {}}};
-      property.version = 1;
+      property.version = 2;
       break;
    case BuiltinOperator_RESHAPE:
      property.inputs = {{0, {}}};
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index f41bf077cd3..7de8398c287 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -370,7 +370,7 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
             BuiltinOperator_CONCATENATION);
   EXPECT_EQ(model_.operator_codes[0]->version, 2);
   EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
-  EXPECT_EQ(model_.operator_codes[1]->version, 1);
+  EXPECT_EQ(model_.operator_codes[1]->version, 2);
 }
 
 class QuantizeConvModel1Test : public QuantizeModelTest {
-- 
GitLab
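
Note (not part of the patch): for readers unfamiliar with the operation being
extended, the sketch below illustrates the affine float-to-int16 mapping that
the new kernel branch and the QuantizeOpTest.INT16 test exercise. It is a
minimal standalone sketch, not the TFLite code path; the real kernel is
optimized_ops::AffineQuantize, and the helper name QuantizeToInt16 here is
hypothetical.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Affine quantization: q = clamp(round(x / scale) + zero_point,
    // -32768, 32767). Hypothetical helper, for illustration only.
    std::vector<int16_t> QuantizeToInt16(const std::vector<float>& input,
                                         float scale, int32_t zero_point) {
      std::vector<int16_t> output;
      output.reserve(input.size());
      for (float x : input) {
        int32_t q = static_cast<int32_t>(std::round(x / scale)) + zero_point;
        q = std::min<int32_t>(32767, std::max<int32_t>(-32768, q));
        output.push_back(static_cast<int16_t>(q));
      }
      return output;
    }

With scale = 0.005 and zero_point = 0, as in the new test, -63.5 maps to
-12700 and 64 maps to 12800, matching the expected values in the test above.
For the requantize path added in Prepare, the rescaling uses the effective
scale input_scale / output_scale.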