提交 f0d6424d 编写于 作者: J Jian Li 提交者: TensorFlower Gardener

Add int16 support to Quant.

PiperOrigin-RevId: 258563058
上级 4fd66235
......@@ -55,7 +55,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
TF_LITE_ENSURE(context, op_context.output->type == kTfLiteUInt8 ||
op_context.output->type == kTfLiteInt8);
op_context.output->type == kTfLiteInt8 ||
op_context.output->type == kTfLiteInt16);
// TODO(b/128934713): Add support for fixed-point per-channel quantization.
// Currently this only support affine per-layer quantization.
......@@ -69,9 +70,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// For requantize use case.
const bool is_requantize = (op_context.input->type == kTfLiteUInt8 ||
op_context.input->type == kTfLiteInt8) &&
op_context.input->type == kTfLiteInt8 ||
op_context.input->type == kTfLiteInt16) &&
(op_context.output->type == kTfLiteUInt8 ||
op_context.output->type == kTfLiteInt8);
op_context.output->type == kTfLiteInt8 ||
op_context.output->type == kTfLiteInt16);
if (is_requantize) {
const double effective_output_scale =
static_cast<double>(op_context.input->params.scale) /
......@@ -104,6 +107,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
optimized_ops::AffineQuantize(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
} else if (output->type == kTfLiteInt16) {
optimized_ops::AffineQuantize(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
context->ReportError(
context,
......
......@@ -79,6 +79,17 @@ TEST(QuantizeOpTest, INT8) {
{-128, -127, -126, -125, -124, 123, 124, 125, 126, 127}));
}
TEST(QuantizeOpTest, INT16) {
QuantizeOpModel m({TensorType_FLOAT32, {2, 5}},
{TensorType_INT16, {2, 5}, 0, 0, 0.005, 0});
m.SetInput({-63.5, -63, -3, -2, -1, 1, 2, 3, 63.5, 64});
m.Invoke();
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAreArray({-12700, -12600, -600, -400, -200, 200, 400, 600,
12700, 12800}));
}
// Input scale 0.500000, output scale 0.500000, input zeropoint -1, output
// zeropoint -1
TEST(QuantizeOpTest, Int8Int8SameScale) {
......
......@@ -376,7 +376,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_ELU, Register_ELU());
AddBuiltin(BuiltinOperator_REVERSE_SEQUENCE, Register_REVERSE_SEQUENCE());
AddBuiltin(BuiltinOperator_MATRIX_DIAG, Register_MATRIX_DIAG());
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE());
AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(),
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
......
......@@ -169,7 +169,7 @@ OperatorProperty GetOperatorProperty(const BuiltinOperator& op) {
case BuiltinOperator_QUANTIZE:
property.inputs = {{0, {}}};
property.outputs = {{0, {}}};
property.version = 1;
property.version = 2;
break;
case BuiltinOperator_RESHAPE:
property.inputs = {{0, {}}};
......
......@@ -370,7 +370,7 @@ TEST_F(QuantizeConcatModelTest, AddRequantBeforeConcat) {
BuiltinOperator_CONCATENATION);
EXPECT_EQ(model_.operator_codes[0]->version, 2);
EXPECT_EQ(model_.operator_codes[1]->builtin_code, BuiltinOperator_QUANTIZE);
EXPECT_EQ(model_.operator_codes[1]->version, 1);
EXPECT_EQ(model_.operator_codes[1]->version, 2);
}
class QuantizeConvModel1Test : public QuantizeModelTest {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册