Commit 56edb7fc authored by: T TensorFlower Gardener

Merge pull request #58400 from SaoirseARM:toupstream/int6x8_32_accum

PiperOrigin-RevId: 565106502
......@@ -1222,7 +1222,7 @@ TEST_F(QuantizeFCTest, VerifyFCFor16x8) {
EXPECT_THAT(model_.operator_codes, SizeIs(1));
EXPECT_THAT(GetBuiltinCode(model_.operator_codes[0].get()),
Eq(BuiltinOperator_FULLY_CONNECTED));
ASSERT_THAT(model_.operator_codes[0]->version, Eq(5));
ASSERT_THAT(model_.operator_codes[0]->version, Eq(11));
// Check the scale value. The scale value will be smaller than the int8 scale
// since the scale is calculated by dividing by 2^bit_num.
......
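As a standalone illustration of the scale arithmetic referenced in the comment above (the range value below is hypothetical and the snippet is not part of the change): deriving the scale by dividing the real-valued range by 2^bit_num makes the 16-bit scale 2^8 = 256 times smaller than the 8-bit one.

#include <cstdint>
#include <cstdio>

// Scale derived as range / 2^bit_num, as described in the test comment.
double ScaleFor(double range, int bit_num) {
  return range / static_cast<double>(int64_t{1} << bit_num);
}

int main() {
  const double range = 1.0;  // hypothetical real-valued range of the tensor
  std::printf("8-bit scale:  %g\n", ScaleFor(range, 8));   // 0.00390625
  std::printf("16-bit scale: %g\n", ScaleFor(range, 16));  // ~1.5259e-05
  std::printf("ratio: %g\n", ScaleFor(range, 8) / ScaleFor(range, 16));  // 256
  return 0;
}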
......@@ -1305,6 +1305,9 @@ TfLiteStatus ParseConv2D(const Operator* op, ErrorReporter* error_reporter,
params->dilation_width_factor = schema_params->dilation_w_factor();
params->dilation_height_factor = schema_params->dilation_h_factor();
TF_LITE_ENSURE_STATUS(
ConvertTensorType(schema_params->quantized_bias_type(),
&params->quantized_bias_type, error_reporter));
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
......@@ -1519,7 +1522,9 @@ TfLiteStatus ParseFullyConnected(const Operator* op,
params->keep_num_dims = schema_params->keep_num_dims();
params->asymmetric_quantize_inputs =
schema_params->asymmetric_quantize_inputs();
TF_LITE_ENSURE_STATUS(
ConvertTensorType(schema_params->quantized_bias_type(),
&params->quantized_bias_type, error_reporter));
switch (schema_params->weights_format()) {
case FullyConnectedOptionsWeightsFormat_DEFAULT:
params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
......@@ -2450,6 +2455,9 @@ TfLiteStatus ParseTransposeConv(const Operator* op,
params->activation =
ConvertActivation(transpose_conv_params->fused_activation_function());
TF_LITE_ENSURE_STATUS(
ConvertTensorType(transpose_conv_params->quantized_bias_type(),
&params->quantized_bias_type, error_reporter));
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
......
......@@ -91,6 +91,10 @@ typedef struct {
// Note: Version 2 supports dilation values not equal to 1.
int dilation_width_factor;
int dilation_height_factor;
// Parameters for CONV_2D version 8 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
TfLiteType quantized_bias_type;
} TfLiteConvParams;
typedef struct {
......@@ -194,6 +198,10 @@ typedef struct {
// If set to true and the weights are quantized, then non constant inputs
// are quantized at evaluation time with asymmetric quantization.
bool asymmetric_quantize_inputs;
// Parameters for FullyConnected version 11 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
TfLiteType quantized_bias_type;
} TfLiteFullyConnectedParams;
typedef enum {
......@@ -431,6 +439,10 @@ typedef struct {
// Parameters supported by version 4:
TfLiteFusedActivation activation;
// Parameters for TransposeConv version 5 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
TfLiteType quantized_bias_type;
} TfLiteTransposeConvParams;
typedef struct {
......
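To make the intent of the new quantized_bias_type fields concrete, here is a minimal standalone sketch (it uses a local stand-in enum rather than the real TfLiteType, and the helper name is made up): the flatbuffer default maps to kTfLiteFloat32, which the kernels treat as "unset" and keep the existing int64 behaviour for 16x8, while an explicit kTfLiteInt32 or kTfLiteInt64 selects the bias/accumulator type, mirroring the Prepare() checks later in this change.

#include <cassert>
#include <cstdio>

// Local stand-in for the relevant TfLiteType values (illustration only).
enum BiasTypeSketch { kSketchNoType, kSketchFloat32, kSketchInt32, kSketchInt64 };

// Hypothetical helper: resolve the effective bias/accumulator type for a
// 16x8 kernel from the (possibly unset) quantized_bias_type parameter.
BiasTypeSketch EffectiveBiasType(BiasTypeSketch quantized_bias_type) {
  if (quantized_bias_type == kSketchInt32 || quantized_bias_type == kSketchInt64) {
    return quantized_bias_type;  // Explicitly requested by the model.
  }
  return kSketchInt64;  // Unset: keep the legacy 16x8 default.
}

int main() {
  assert(EffectiveBiasType(kSketchFloat32) == kSketchInt64);  // unset
  assert(EffectiveBiasType(kSketchInt32) == kSketchInt32);    // 32-bit accumulation
  assert(EffectiveBiasType(kSketchInt64) == kSketchInt64);
  std::printf("ok\n");
  return 0;
}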
......@@ -57,7 +57,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(),
/* min_version = */ 1,
/* max_version = */ 7);
/* max_version = */ 8);
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
/* min_version = */ 1,
/* max_version = */ 7);
......@@ -82,7 +82,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
Register_EMBEDDING_LOOKUP_SPARSE());
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
/* min_version = */ 1,
/* max_version = */ 10);
/* max_version = */ 11);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
......@@ -221,7 +221,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV(),
/* min_version = */ 1,
/* max_version = */ 4);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_TILE, Register_TILE(),
/* min_version = */ 1,
/* max_version = */ 3);
......
......@@ -121,6 +121,8 @@ struct OpData {
// Number of convolution groups.
int32_t groups = 1;
TfLiteType quantized_bias_type = kTfLiteNoType;
};
inline PaddingType RuntimePaddingType(TfLitePadding padding) {
......@@ -359,10 +361,6 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
input_type == kTfLiteInt8 || input_type == kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
if (input_type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
}
// Filter must have zero zero-points in per-channel quantization.
if (input_type == kTfLiteInt16 || input_type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, filter->quantization.type,
......@@ -396,6 +394,21 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
if (input_type == kTfLiteInt16) {
// Quantization should be symmetric.
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// Check quantized_bias_type is either kTfLiteInt64 or kTfLiteInt32.
if (params->quantized_bias_type != kTfLiteFloat32) {
TF_LITE_ENSURE(context, params->quantized_bias_type == kTfLiteInt32 ||
params->quantized_bias_type == kTfLiteInt64);
TF_LITE_ENSURE(context, (bias == nullptr) ||
bias->type == params->quantized_bias_type);
data->quantized_bias_type = params->quantized_bias_type;
}
}
const bool is_hybrid =
(input->type == kTfLiteFloat32 &&
(filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
......@@ -855,34 +868,37 @@ void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node,
filter->params.zero_point ||
output->params.zero_point;
// Fallback to reference kernel when bias_type is int64 as
// there is no optimized kernel for int64 bias yet.
if (bias && bias->type == kTfLiteInt64) {
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<std::int64_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output));
} else if (effective_kernel_type == kReference || has_non_zero_point) {
if (data->quantized_bias_type == kTfLiteInt32) {
if (effective_kernel_type == kReference || has_non_zero_point) {
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output));
} else {
optimized_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output), GetTensorShape(im2col),
GetTensorData<int16_t>(im2col),
CpuBackendContext::GetFromContext(context));
}
} else {
TFLITE_DCHECK(!has_non_zero_point);
// Fallback to reference kernel when bias_type is int64 as
// there is no optimized kernel for int64 bias yet.
reference_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(filter),
GetTensorData<int8>(filter), GetTensorShape(bias),
GetTensorData<std::int32_t>(bias), GetTensorShape(output),
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output));
} else {
optimized_integer_ops::ConvPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<std::int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output), GetTensorShape(im2col),
GetTensorData<int16_t>(im2col),
CpuBackendContext::GetFromContext(context));
}
}
......
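Since the removed and added lines are interleaved in the hunk above, here is a condensed standalone summary of the dispatch the new code arrives at for 16x8 convolution (kernel names are returned as plain strings purely for illustration): an int32 bias can use the optimized per-channel kernel when the reference path is not forced and all zero points are zero, while an int64 bias still falls back to the reference kernel.

#include <cassert>
#include <cstdio>
#include <string>

enum BiasKind { kBiasInt32, kBiasInt64 };
enum KernelKind { kRef, kOptimized };

// Illustrative-only summary of the 16x8 conv dispatch after this change.
std::string PickConvPerChannel16x8(BiasKind bias, KernelKind kernel,
                                   bool has_non_zero_point) {
  if (bias == kBiasInt32) {
    if (kernel == kRef || has_non_zero_point) {
      return "reference ConvPerChannel, int32 bias";
    }
    return "optimized ConvPerChannel, int32 bias";
  }
  // No optimized kernel for int64 bias yet.
  return "reference ConvPerChannel, int64 bias";
}

int main() {
  assert(PickConvPerChannel16x8(kBiasInt32, kOptimized, false) ==
         "optimized ConvPerChannel, int32 bias");
  assert(PickConvPerChannel16x8(kBiasInt64, kOptimized, false) ==
         "reference ConvPerChannel, int64 bias");
  std::printf("ok\n");
  return 0;
}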
......@@ -56,8 +56,8 @@ class BaseConvolutionOpModel : public SingleOpModel {
int stride_height = 2, enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
int dilation_width_factor = 1, int dilation_height_factor = 1,
int num_threads = -1,
std::initializer_list<FilterType> filter_data = {}) {
int num_threads = -1, std::initializer_list<FilterType> filter_data = {},
const TensorType bias_type = TensorType_INT32) {
input_ = AddInput(input);
if (filter_data.size()) {
......@@ -85,11 +85,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
input.scale * filter.per_channel_quantization_scales[i];
bias_zero_points[i] = 0;
}
tflite::TensorType bias_type = TensorType_INT32;
if (input.type == TensorType_INT16) {
// In case of 16-bit, the bias type is set to be int 64.
bias_type = TensorType_INT64;
}
TensorData bias{bias_type,
{bias_size},
/*min=*/0,
......@@ -104,7 +99,7 @@ class BaseConvolutionOpModel : public SingleOpModel {
} else {
// per tensor quantization.
auto bias_scale = GetScale(input_) * GetScale(filter_);
TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
TensorData bias{bias_type, {bias_size}, 0, 0, bias_scale};
bias_ = AddInput(bias);
}
}
......@@ -114,7 +109,7 @@ class BaseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
CreateConv2DOptions(
builder_, padding, stride_width, stride_height, activation,
dilation_width_factor, dilation_height_factor)
dilation_width_factor, dilation_height_factor, bias_type)
.Union());
resolver_ = std::make_unique<SingleOpResolver>(BuiltinOperator_CONV_2D,
......@@ -1657,8 +1652,9 @@ class PerChannelQuantizedConvolutionOpModel
public:
using BaseConvolutionOpModel::BaseConvolutionOpModel;
void SetInput(std::initializer_list<float> data) {
QuantizeAndPopulate<int8_t>(input_, data);
template <typename T>
void SetInput(const std::vector<float>& data) {
QuantizeAndPopulate<T>(input_, data);
}
void SetFilter(std::initializer_list<float> data) {
......@@ -1669,10 +1665,15 @@ class PerChannelQuantizedConvolutionOpModel
PerChannelQuantizeBias(bias_, data);
}
std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
template <typename T>
std::vector<float> GetDequantizedOutput() {
return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
GetZeroPoint(output_));
return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
GetZeroPoint(output_));
}
};
......@@ -1714,7 +1715,7 @@ TEST_P(ConvolutionOpTest, SimplePerTensorTest) {
/*channel_index=*/0},
{TensorType_INT8, {}, -63.5, 64, 0.5, -1},
/*stride_width=*/1, /*stride_height=*/1);
m.SetInput({
m.SetInput<int8_t>({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
......@@ -1740,9 +1741,9 @@ TEST_P(ConvolutionOpTest, SimplePerTensorTest) {
// Invoke and verify output.
// output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({31, 56, -57, -44})));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 111, -115, -89}));
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({61, 111, -115, -89}));
}
TEST_P(ConvolutionOpTest, SimplePerChannelTest) {
......@@ -1761,7 +1762,7 @@ TEST_P(ConvolutionOpTest, SimplePerChannelTest) {
/*channel_index=*/0},
{TensorType_INT8, {}, -63.5, 64, 0.5, -1},
/*stride_width=*/1, /*stride_height=*/1);
m.SetInput({
m.SetInput<int8_t>({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
......@@ -1787,9 +1788,123 @@ TEST_P(ConvolutionOpTest, SimplePerChannelTest) {
// Invoke and verify output.
// output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({31, 64, -57, -46})));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 127, -115, -93}));
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({61, 127, -115, -93}));
}
TEST_P(ConvolutionOpTest, SimplePerChannel16x8Bias32) {
const float scale = 128.0 / 65536;
PerChannelQuantizedConvolutionOpModel m(
GetRegistration(), {TensorType_INT16, {1, 2, 3, 2}, 0, 0, scale, 0},
{TensorType_INT8,
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
{2, 2, 2, 2},
0,
0,
0,
0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/{1, 2},
/*per_channel_quantization_offsets=*/{0, 0},
/*channel_index=*/0},
{TensorType_INT16, {}, 0, 0, scale, 0},
/*stride_width=*/1, /*stride_height=*/1,
/*padding=*/Padding_VALID,
/*activation=*/ActivationFunctionType_NONE,
/*dilation_width_factor=*/1,
/*dilation_height_factor=*/1,
/*num_threads=*/-1,
/*filter_data=*/{},
/*bias_type=*/TensorType_INT32);
m.SetInput<int16_t>({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
-2, -3, // batch = 0, y = 0, x = 2
4, 3, // batch = 0, y = 1, x = 0
2, -2, // batch = 0, y = 1, x = 1
-3, -4, // batch = 0, y = 1, x = 2
});
m.SetFilter(
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
{
1, 2, // out channel = 0, y = 0, x = 0
3, 4, // out channel = 0, y = 0, x = 1
3, 4, // out channel = 0, y = 1, x = 0
5, 6, // out channel = 0, y = 1, x = 1
7, 8, // out channel = 1, y = 0, x = 0
5, 6, // out channel = 1, y = 0, x = 1
3, 4, // out channel = 1, y = 1, x = 0
1, 2, // out channel = 1, y = 1, x = 1
});
m.SetBias({3, -2});
// Invoke and verify output.
// output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear({31, 63.99804688, -57, -46})));
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAreArray({15872, 32767, -29184, -23552}));
}
TEST_P(ConvolutionOpTest, SimplePerChannel16x8Bias64) {
const float scale = 128.0 / 65536;
PerChannelQuantizedConvolutionOpModel m(
GetRegistration(), {TensorType_INT16, {1, 2, 3, 2}, 0, 0, scale, 0},
{TensorType_INT8,
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
{2, 2, 2, 2},
0,
0,
0,
0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/{1, 2},
/*per_channel_quantization_offsets=*/{0, 0},
/*channel_index=*/0},
{TensorType_INT16, {}, 0, 0, scale, 0},
/*stride_width=*/1, /*stride_height=*/1,
/*padding=*/Padding_VALID,
/*activation=*/ActivationFunctionType_NONE,
/*dilation_width_factor=*/1,
/*dilation_height_factor=*/1,
/*num_threads=*/-1,
/*filter_data=*/{},
/*bias_type=*/TensorType_INT64);
m.SetInput<int16_t>({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
-2, -3, // batch = 0, y = 0, x = 2
4, 3, // batch = 0, y = 1, x = 0
2, -2, // batch = 0, y = 1, x = 1
-3, -4, // batch = 0, y = 1, x = 2
});
m.SetFilter(
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
{
1, 2, // out channel = 0, y = 0, x = 0
3, 4, // out channel = 0, y = 0, x = 1
3, 4, // out channel = 0, y = 1, x = 0
5, 6, // out channel = 0, y = 1, x = 1
7, 8, // out channel = 1, y = 0, x = 0
5, 6, // out channel = 1, y = 0, x = 1
3, 4, // out channel = 1, y = 1, x = 0
1, 2, // out channel = 1, y = 1, x = 1
});
m.SetBias({3, -2});
// Invoke and verify output.
// output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear({31, 63.99804688, -57, -46})));
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAreArray({15872, 32767, -29184, -23552}));
}
TEST_P(ConvolutionOpTest, Simple4bitPerChannelTest) {
......@@ -1808,7 +1923,7 @@ TEST_P(ConvolutionOpTest, Simple4bitPerChannelTest) {
/*channel_index=*/0},
{TensorType_INT8, {}, -63.5, 64, 0.5, -1},
/*stride_width=*/1, /*stride_height=*/1);
m.SetInput({
m.SetInput<int8_t>({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
......@@ -1834,9 +1949,9 @@ TEST_P(ConvolutionOpTest, Simple4bitPerChannelTest) {
// Invoke and verify output.
// output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel]
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput(),
EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
ElementsAreArray(ArrayFloatNear({31, 64, -57, -46})));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 127, -115, -93}));
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({61, 127, -115, -93}));
}
class HybridPerChannelConvolutionOpModel
......
......@@ -135,6 +135,7 @@ struct OpData {
bool ledger_initialized;
// Used for 4bit hybrid
std::unique_ptr<optimized_4bit::OpData4Bit> op_data_4bit = nullptr;
TfLiteType quantized_bias_type = kTfLiteNoType;
};
constexpr int kInputTensor = 0;
......@@ -464,6 +465,15 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node,
if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// Check quantized_bias_type is either kTfLiteInt64 or kTfLiteInt32.
if (params->quantized_bias_type != kTfLiteFloat32) {
TF_LITE_ENSURE(context, params->quantized_bias_type == kTfLiteInt32 ||
params->quantized_bias_type == kTfLiteInt64);
TF_LITE_ENSURE(context, (bias == nullptr) ||
bias->type == params->quantized_bias_type);
data->quantized_bias_type = params->quantized_bias_type;
}
}
// If we have to perform on-the-fly quantization (with quantized weights and
......@@ -1097,17 +1107,18 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
op_params.output_shift = data->output_shift;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
if (bias && bias->type == kTfLiteInt64) {
if (data->quantized_bias_type == kTfLiteInt32) {
reference_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int64_t>(bias),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
} else {
reference_integer_ops::FullyConnected(
op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
GetTensorShape(filter), GetTensorData<int8_t>(filter),
GetTensorShape(bias), GetTensorData<int32_t>(bias),
GetTensorShape(bias), GetTensorData<int64_t>(bias),
GetTensorShape(output), GetTensorData<int16_t>(output));
}
}
......@@ -1161,13 +1172,14 @@ void FullyConnectedPerChannelInt16(const OpData* data,
op_params.output_offset = output->params.zero_point;
op_params.quantized_activation_min = data->output_activation_min;
op_params.quantized_activation_max = data->output_activation_max;
if (bias && bias->type == kTfLiteInt64) {
if (data->quantized_bias_type == kTfLiteInt32) {
reference_integer_ops::FullyConnectedPerChannel(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
} else {
reference_integer_ops::FullyConnectedPerChannel(
......@@ -1175,7 +1187,7 @@ void FullyConnectedPerChannelInt16(const OpData* data,
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16_t>(input), GetTensorShape(filter),
GetTensorData<int8_t>(filter), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output));
}
}
......@@ -1328,13 +1340,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
bool has_non_zero_point = input->params.zero_point ||
filter->params.zero_point ||
output->params.zero_point;
if (kernel_type == kReference || has_non_zero_point ||
(bias && bias->type == kTfLiteInt64)) {
is_per_channel ? FullyConnectedPerChannelInt16<kernel_type>(
data, input, filter, bias, output)
: FullyConnectedInt16<kernel_type>(
data, input, filter, bias, output);
} else {
if (kernel_type == kGenericOptimized &&
data->quantized_bias_type == kTfLiteInt32 &&
!has_non_zero_point) {
is_per_channel
? optimized_integer_ops::FullyConnectedPerChannel(
op_params, data->per_channel_output_multiplier.data(),
......@@ -1351,6 +1359,11 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16_t>(output),
CpuBackendContext::GetFromContext(context));
} else {
is_per_channel ? FullyConnectedPerChannelInt16<kernel_type>(
data, input, filter, bias, output)
: FullyConnectedInt16<kernel_type>(
data, input, filter, bias, output);
}
} else if (kernel_type == kReference) {
reference_ops::FullyConnected(
......
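The fully_connected.cc hunk above follows the same pattern; as a short standalone sketch (boolean inputs stand in for the real tensors and op data), the optimized 16x8 path is taken only for a generic-optimized kernel with an int32 bias and all-zero zero points, and everything else goes through the reference Int16 helpers, which in turn pick int32 or int64 bias internally.

#include <cassert>
#include <cstdio>

// Illustrative-only gate for the optimized 16x8 fully-connected path.
bool UseOptimizedInt16Path(bool is_generic_optimized_kernel,
                           bool bias_is_int32, bool has_non_zero_point) {
  return is_generic_optimized_kernel && bias_is_int32 && !has_non_zero_point;
}

int main() {
  assert(UseOptimizedInt16Path(true, true, false));    // optimized kernel
  assert(!UseOptimizedInt16Path(true, false, false));  // int64 bias -> reference
  assert(!UseOptimizedInt16Path(false, true, false));  // reference kernel forced
  assert(!UseOptimizedInt16Path(true, true, true));    // non-zero zero point
  std::printf("ok\n");
  return 0;
}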
......@@ -237,8 +237,9 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
BuiltinOptions_FullyConnectedOptions,
CreateFullyConnectedOptions(builder_, activation_func,
weights_format, keep_num_dims)
CreateFullyConnectedOptions(
builder_, activation_func, weights_format, keep_num_dims,
/*asymmetric_quantize_inputs=*/true, bias_type)
.Union());
resolver_ = std::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, registration);
......@@ -796,6 +797,38 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt8) {
EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(23, 24, 25, 57, 58, 59));
}
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16NoBias) {
const float scale = 128.0 / 65536;
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
/*input=*/{TensorType_INT16, {2, 10}, 0, 0, scale, 0},
/*output=*/{TensorType_INT16, {}, 0, 0, scale, 0},
/*bias_type=*/TensorType_INT64,
/*keep_num_dims=*/false, /*bool bias_tensor_optional=*/true,
/*ActivationFunctionType activation_func=*/ActivationFunctionType_RELU,
/*FullyConnectedOptionsWeightsFormat weights_format=*/
FullyConnectedOptionsWeightsFormat_DEFAULT);
// Note: input_product_scale < output_scale does not hold in this test.
m.SetWeights<int8_t>({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetInput<int16_t>({
1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
EXPECT_THAT(m.GetOutput<int16_t>(),
ElementsAre(11776, 11776, 11776, 29184, 29184, 29184));
}
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias32) {
const float scale = 128.0 / 65536;
QuantizedFullyConnectedOpModel m(
......
......@@ -242,7 +242,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_REF());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONVOLUTION_REF(),
/* min_version = */ 1,
/* max_version = */ 6);
/* max_version = */ 8);
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
Register_DEPTHWISE_CONVOLUTION_REF(),
/* min_version = */ 1,
......@@ -268,7 +268,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
Register_EMBEDDING_LOOKUP_SPARSE());
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED_REF(),
/* min_version */ 1,
/* max_version */ 9);
/* max_version */ 11);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX_REF(),
......@@ -407,7 +407,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
AddBuiltin(BuiltinOperator_COS, Register_COS());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF(),
/* min_version = */ 1,
/* max_version = */ 3);
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_TILE, Register_TILE(),
/* min_version = */ 1,
/* max_version = */ 2);
......
......@@ -90,6 +90,8 @@ struct OpData {
bool has_col2im = false;
bool weights_are_transposed = false;
TfLiteType quantized_bias_type = kTfLiteNoType;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
......@@ -240,6 +242,8 @@ TfLiteStatus ResizeAndTransposeWeights(TfLiteContext* context,
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
bool has_bias = NumInputs(node) == 4;
......@@ -295,6 +299,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
// Check quantized_bias_type is either kTfLiteInt64 or kTfLiteInt32.
if (params->quantized_bias_type != kTfLiteFloat32) {
TF_LITE_ENSURE(context, params->quantized_bias_type == kTfLiteInt32 ||
params->quantized_bias_type == kTfLiteInt64);
TF_LITE_ENSURE(context, (bias == nullptr) ||
bias->type == params->quantized_bias_type);
data->quantized_bias_type = params->quantized_bias_type;
}
} else {
TF_LITE_ENSURE_TYPES_EQ(context, weights->type, input->type);
}
......@@ -354,7 +367,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(
context, GetTemporarySafe(context, node, data->scratch_tensor_index,
&scratch_buffer));
if (input->type == kTfLiteInt16 && bias && bias->type == kTfLiteInt64) {
if (data->quantized_bias_type != kTfLiteNoType) {
scratch_buffer->type = data->quantized_bias_type;
} else if (input->type == kTfLiteInt16) {
scratch_buffer->type = kTfLiteInt64;
} else {
scratch_buffer->type = kTfLiteInt32;
......@@ -565,9 +581,31 @@ void EvalQuantizedPerChannel16x8(
weights->params.zero_point ||
output->params.zero_point;
// Fallback to reference kernel when bias_type is int64 as
// there is no optimized kernel for int64 bias yet.
if (bias && bias->type == kTfLiteInt64) {
if (data->quantized_bias_type == kTfLiteInt32) {
if (kernel_type == kReference || has_non_zero_point) {
reference_integer_ops::TransposeConv(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(weights),
GetTensorData<int8>(weights), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
} else {
optimized_integer_ops::TransposeConvV2(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(transposed_weights),
GetTensorData<int8>(transposed_weights), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output), GetTensorShape(col2im),
GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
CpuBackendContext::GetFromContext(context));
}
} else {
TFLITE_DCHECK(!has_non_zero_point);
// Fallback to reference kernel when bias_type is int64 as
// there is no optimized kernel for int64 bias yet.
reference_integer_ops::TransposeConv(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
......@@ -576,25 +614,6 @@ void EvalQuantizedPerChannel16x8(
GetTensorData<int64_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int64_t>(scratch_buffer));
} else if (kernel_type == kReference || has_non_zero_point) {
reference_integer_ops::TransposeConv(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(weights),
GetTensorData<int8>(weights), GetTensorShape(bias),
GetTensorData<int32_t>(bias), GetTensorShape(output),
GetTensorData<int16>(output), GetTensorShape(col2im),
GetTensorData<int8>(col2im), GetTensorData<int32_t>(scratch_buffer));
} else {
optimized_integer_ops::TransposeConvV2(
op_params, data->per_channel_output_multiplier.data(),
data->per_channel_output_shift.data(), GetTensorShape(input),
GetTensorData<int16>(input), GetTensorShape(transposed_weights),
GetTensorData<int8>(transposed_weights), GetTensorShape(bias),
GetTensorData<int32>(bias), GetTensorShape(output),
GetTensorData<int16>(output), GetTensorShape(col2im),
GetTensorData<int32>(col2im), GetTensorData<int32>(scratch_buffer),
CpuBackendContext::GetFromContext(context));
}
}
......
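One detail specific to transpose_conv.cc above is how the scratch (accumulator) buffer type is picked in Prepare(); here is a minimal standalone sketch of that selection (local stand-in enum, not the real TfLiteType): an explicit quantized_bias_type wins, otherwise int16 inputs keep the legacy int64 accumulator and everything else uses int32.

#include <cassert>
#include <cstdio>

enum TypeSketch { kSketchNoType, kSketchInt16, kSketchInt32, kSketchInt64, kSketchFloat32 };

// Illustrative-only summary of the scratch buffer type selection after this change.
TypeSketch ScratchBufferType(TypeSketch quantized_bias_type, TypeSketch input_type) {
  if (quantized_bias_type != kSketchNoType) return quantized_bias_type;
  if (input_type == kSketchInt16) return kSketchInt64;
  return kSketchInt32;
}

int main() {
  assert(ScratchBufferType(kSketchInt32, kSketchInt16) == kSketchInt32);   // explicit request
  assert(ScratchBufferType(kSketchNoType, kSketchInt16) == kSketchInt64);  // legacy 16x8 default
  assert(ScratchBufferType(kSketchNoType, kSketchFloat32) == kSketchInt32);
  std::printf("ok\n");
  return 0;
}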
......@@ -61,7 +61,8 @@ class BaseTransposeConvOpModel : public SingleOpModel {
const TensorData& input, const TensorData& output,
Padding padding, int stride_w, int stride_h,
tflite::ActivationFunctionType fused_activation,
TestType test_type, int version = 1) {
TestType test_type, int version = 1,
const TensorType& bias_type = TensorType_INT32) {
// Just to be confusing, transpose_conv has an _input_ named "output_shape"
// that sets the shape of the output tensor of the op :). It must always be
// an int32 1D four element tensor.
......@@ -76,11 +77,11 @@ class BaseTransposeConvOpModel : public SingleOpModel {
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_TRANSPOSE_CONV,
BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w,
stride_h, fused_activation)
.Union());
SetBuiltinOp(
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h,
fused_activation, bias_type)
.Union());
resolver_ = std::make_unique<SingleOpResolver>(
BuiltinOperator_TRANSPOSE_CONV, registration, version);
BuildInterpreter(
......@@ -613,7 +614,7 @@ class PerChannelQuantizedTransposeConvOpModel16x8
}
};
TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8) {
TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8NoBiasInt32) {
const std::initializer_list<float> filter_data = {
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
1, 2, // out channel = 0, y = 0, x = 0
......@@ -652,7 +653,8 @@ TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8) {
/*zero_point=*/0},
/*padding=*/Padding_SAME,
/*stride_w=*/1, /*stride_h=*/1,
/*fused_activation_function=*/ActivationFunctionType_NONE, GetTestType());
/*fused_activation_function=*/ActivationFunctionType_NONE, GetTestType(),
/*bias_type=*/TensorType_INT32);
model.SetInput({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
......@@ -739,18 +741,80 @@ TEST_P(TransposeConvOpTest,
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
}
template <typename InputType>
TEST_P(TransposeConvOpTest, SimpleTestQuantizedPerChannel16x8NoBiasInt64) {
const std::initializer_list<float> filter_data = {
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
1, 2, // out channel = 0, y = 0, x = 0
3, 4, // out channel = 0, y = 0, x = 1
3, 4, // out channel = 0, y = 1, x = 0
5, 6, // out channel = 0, y = 1, x = 1
7, 8, // out channel = 1, y = 0, x = 0
5, 6, // out channel = 1, y = 0, x = 1
3, 4, // out channel = 1, y = 1, x = 0
1, 2, // out channel = 1, y = 1, x = 1
};
PerChannelQuantizedTransposeConvOpModel16x8 model(
GetRegistration(),
/*output_shape_data=*/{1, 2, 3, 2},
/*filter=*/
{TensorType_INT8,
/*shape=*/{2, 2, 2, 2},
/*min=*/-64, /*max=*/64,
/*scale=*/0, /*zero_point=*/0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/{7.0 / 127, 8.0 / 127},
/*per_channel_quantization_offsets=*/{0, 0},
/*channel_index=*/0},
/*filter_data=*/{},
/*input=*/
{TensorType_INT16,
/*shape=*/{1, 2, 3, 2},
/*min=*/0, /*max=*/0,
/*scale=*/4.0 / 127,
/*zero_point=*/0},
/*output=*/
{TensorType_INT16,
/*shape=*/{},
/*min=*/0, /*max=*/0,
/*scale=*/1.0,
/*zero_point=*/0},
/*padding=*/Padding_SAME,
/*stride_w=*/1, /*stride_h=*/1,
/*fused_activation_function=*/ActivationFunctionType_NONE, GetTestType(),
/*version=*/1,
/*bias_type=*/TensorType_INT64);
model.SetInput({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
-2, -3, // batch = 0, y = 0, x = 2
4, 3, // batch = 0, y = 1, x = 0
2, -2, // batch = 0, y = 1, x = 1
-3, -4, // batch = 0, y = 1, x = 2
});
model.SetFilter(filter_data);
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(
{7, 37, 16, 26, -9, -39, 27, 69, 48, 42, -32, -74}, 1e-5)));
// GetOutputShape() should always be same as model.SetOutputShape(...);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
}
template <typename InputType, typename FilterType>
class BaseTransposeConvBiasOpModel : public SingleOpModel {
public:
BaseTransposeConvBiasOpModel(TfLiteRegistration* registration,
std::initializer_list<int> output_shape_data,
const TensorData& filter,
std::initializer_list<InputType> filter_data,
const TensorData& input,
const TensorData& output, Padding padding,
int stride_w, int stride_h,
tflite::ActivationFunctionType fused_activation,
TestType test_type, int version = 3) {
BaseTransposeConvBiasOpModel(
TfLiteRegistration* registration,
std::initializer_list<int> output_shape_data, const TensorData& filter,
std::initializer_list<FilterType> filter_data, const TensorData& input,
const TensorData& output, Padding padding, int stride_w, int stride_h,
tflite::ActivationFunctionType fused_activation, TestType test_type,
int version = 3, const TensorType& bias_type = TensorType_INT32) {
bias_type_ = bias_type;
if (test_type == TestType::kDynamic) {
output_shape_ = AddInput({TensorType_INT32, {4}});
filter_ = AddInput(filter);
......@@ -759,54 +823,59 @@ class BaseTransposeConvBiasOpModel : public SingleOpModel {
filter_ = AddConstInput(filter, filter_data);
}
input_ = AddInput(input);
int bias_size = GetShape(filter_)[0];
if (input.type == TensorType_FLOAT32) {
bias_type_ = TensorType_FLOAT32;
bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
} else if (input.type == TensorType_INT8) {
// per channel quantization.
std::vector<float> bias_scale(
filter.per_channel_quantization_scales.size());
std::vector<int64_t> bias_zero_points(
filter.per_channel_quantization_scales.size());
for (size_t i = 0; i < filter.per_channel_quantization_scales.size();
++i) {
bias_scale[i] = input.scale * filter.per_channel_quantization_scales[i];
bias_zero_points[i] = 0;
}
TensorData bias{TensorType_INT32,
{bias_size},
/*min=*/0,
/*max=*/0,
/*scale=*/0,
/*zero_point=*/0,
true,
/*per_channel_quantization_scales=*/bias_scale,
/*per_channel_quantization_offsets=*/bias_zero_points,
/*channel_index=*/0};
bias_ = AddInput(bias);
} else {
// per tensor quantization.
auto bias_scale = GetScale(input_) * GetScale(filter_);
TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
bias_ = AddInput(bias);
if (filter.per_channel_quantization) {
// per channel quantization.
std::vector<float> bias_scale(
filter.per_channel_quantization_scales.size());
std::vector<int64_t> bias_zero_points(
filter.per_channel_quantization_scales.size());
for (size_t i = 0; i < filter.per_channel_quantization_scales.size();
++i) {
bias_scale[i] =
input.scale * filter.per_channel_quantization_scales[i];
bias_zero_points[i] = 0;
}
TensorData bias{bias_type,
{bias_size},
/*min=*/0,
/*max=*/0,
/*scale=*/0,
/*zero_point=*/0,
true,
/*per_channel_quantization_scales=*/bias_scale,
/*per_channel_quantization_offsets=*/bias_zero_points,
/*channel_index=*/0};
bias_ = AddInput(bias);
} else {
// per tensor quantization.
auto bias_scale = GetScale(input_) * GetScale(filter_);
TensorData bias{bias_type, {bias_size}, 0, 0, bias_scale};
bias_ = AddInput(bias);
}
}
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_TRANSPOSE_CONV,
BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w,
stride_h, fused_activation)
.Union());
SetBuiltinOp(
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h,
fused_activation, bias_type)
.Union());
resolver_ = std::make_unique<SingleOpResolver>(
BuiltinOperator_TRANSPOSE_CONV, registration, version);
BuildInterpreter({GetShape(output_shape_), GetShape(filter_),
GetShape(input_), GetShape(bias_)});
if (test_type == TestType::kDynamic) {
PopulateTensor<int32_t>(output_shape_, output_shape_data);
PopulateTensor<InputType>(filter_, filter_data);
if (!std::is_same<InputType, int16_t>::value &&
!std::is_same<InputType, int8_t>::value) {
PopulateTensor<FilterType>(filter_, filter_data);
}
}
}
......@@ -815,6 +884,8 @@ class BaseTransposeConvBiasOpModel : public SingleOpModel {
QuantizeAndPopulate<uint8_t>(input_, data);
} else if (std::is_same<InputType, int8_t>::value) {
QuantizeAndPopulate<int8_t>(input_, data);
} else if (std::is_same<InputType, int16_t>::value) {
QuantizeAndPopulate<int16_t>(input_, data);
} else {
PopulateTensor(input_, data);
}
......@@ -823,7 +894,7 @@ class BaseTransposeConvBiasOpModel : public SingleOpModel {
void SetBias(std::initializer_list<float> bias) {
if (std::is_same<InputType, uint8_t>::value) {
QuantizeAndPopulate<int32_t>(bias_, bias);
} else if (std::is_same<InputType, int8_t>::value) {
} else if (std::is_same<FilterType, int8_t>::value) {
PerChannelQuantizeBias(bias_, bias);
} else {
PopulateTensor(bias_, bias);
......@@ -838,9 +909,11 @@ class BaseTransposeConvBiasOpModel : public SingleOpModel {
int input_;
int bias_;
int output_;
TensorType bias_type_;
};
class TransposeConvOpBiasModel : public BaseTransposeConvBiasOpModel<float> {
class TransposeConvOpBiasModel
: public BaseTransposeConvBiasOpModel<float, float> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
......@@ -925,7 +998,7 @@ TEST_P(TransposeConvOpTest, MultiChannelBiasWithFusedActivationTest) {
}
class QuantizedTransposeConvBiasOpModel
: public BaseTransposeConvBiasOpModel<uint8_t> {
: public BaseTransposeConvBiasOpModel<uint8_t, uint8_t> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
......@@ -989,7 +1062,7 @@ TEST_P(TransposeConvOpTest, SimpleBiasWithFusedActivationTestQuantized) {
}
class PerChannelQuantizedTransposeConvBiasOpModel
: public BaseTransposeConvBiasOpModel<int8_t> {
: public BaseTransposeConvBiasOpModel<int8_t, int8_t> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
......@@ -1035,6 +1108,149 @@ TEST_P(TransposeConvOpTest, SimpleBiasTestQuantizedPerChannelSingleChannel) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
class PerChannel16x8TransposeConvBiasOpModel
: public BaseTransposeConvBiasOpModel<int16_t, int8_t> {
public:
using BaseTransposeConvBiasOpModel::BaseTransposeConvBiasOpModel;
std::vector<float> GetDequantizedOutput() {
return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
void SetFilter(const std::initializer_list<float>& data) {
PerChannelSymmetricQuantizeAndPopulate(filter_, data);
}
};
TEST_P(TransposeConvOpTest, SimpleBiasTestQuantizedPerChannel16x8Bias32) {
const float scale = 128.0 / 65536;
const std::initializer_list<float> filter_data = {
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
1, 2, // out channel = 0, y = 0, x = 0
3, 4, // out channel = 0, y = 0, x = 1
3, 4, // out channel = 0, y = 1, x = 0
5, 6, // out channel = 0, y = 1, x = 1
7, 8, // out channel = 1, y = 0, x = 0
5, 6, // out channel = 1, y = 0, x = 1
3, 4, // out channel = 1, y = 1, x = 0
1, 2, // out channel = 1, y = 1, x = 1
};
PerChannel16x8TransposeConvBiasOpModel model(
GetRegistration(),
/*output_shape_data=*/{1, 2, 3, 2},
/*filter=*/
{TensorType_INT8,
/*shape=*/{2, 2, 2, 2},
/*min=*/-64, /*max=*/64,
/*scale=*/0, /*zero_point=*/0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/{7.0 / 127, 8.0 / 127},
/*per_channel_quantization_offsets=*/{0, 0},
/*channel_index=*/0},
/*filter_data=*/{},
/*input=*/
{TensorType_INT16,
/*shape=*/{1, 2, 3, 2},
/*min=*/0, /*max=*/0,
/*scale=*/4.0 / 127,
/*zero_point=*/0},
/*output=*/
{TensorType_INT16,
/*shape=*/{},
/*min=*/0, /*max=*/0,
/*scale=*/scale,
/*zero_point=*/0},
/*padding=*/Padding_SAME,
/*stride_w=*/1, /*stride_h=*/1,
/*fused_activation_function=*/ActivationFunctionType_NONE, GetTestType(),
/*bias_type=*/TensorType_INT32);
model.SetInput({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
-2, -3, // batch = 0, y = 0, x = 2
4, 3, // batch = 0, y = 1, x = 0
2, -2, // batch = 0, y = 1, x = 1
-3, -4, // batch = 0, y = 1, x = 2
});
model.SetFilter(filter_data);
model.SetBias({3, -2});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(
{10, 35, 19, 24, -6, -41, 30, 64, 51, 40, -29, -64}, 0.19)));
// GetOutputShape() should always be same as model.SetOutputShape(...);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
}
TEST_P(TransposeConvOpTest, SimpleBiasTestQuantizedPerChannel16x8Bias64) {
const float scale = 128.0 / 65536;
const std::initializer_list<float> filter_data = {
// [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel]
1, 2, // out channel = 0, y = 0, x = 0
3, 4, // out channel = 0, y = 0, x = 1
3, 4, // out channel = 0, y = 1, x = 0
5, 6, // out channel = 0, y = 1, x = 1
7, 8, // out channel = 1, y = 0, x = 0
5, 6, // out channel = 1, y = 0, x = 1
3, 4, // out channel = 1, y = 1, x = 0
1, 2, // out channel = 1, y = 1, x = 1
};
PerChannel16x8TransposeConvBiasOpModel model(
GetRegistration(),
/*output_shape_data=*/{1, 2, 3, 2},
/*filter=*/
{TensorType_INT8,
/*shape=*/{2, 2, 2, 2},
/*min=*/-64, /*max=*/64,
/*scale=*/0, /*zero_point=*/0,
/*per_channel_quantization=*/true,
/*per_channel_quantization_scales=*/{7.0 / 127, 8.0 / 127},
/*per_channel_quantization_offsets=*/{0, 0},
/*channel_index=*/0},
/*filter_data=*/{},
/*input=*/
{TensorType_INT16,
/*shape=*/{1, 2, 3, 2},
/*min=*/0, /*max=*/0,
/*scale=*/4.0 / 127,
/*zero_point=*/0},
/*output=*/
{TensorType_INT16,
/*shape=*/{},
/*min=*/0, /*max=*/0,
/*scale=*/scale,
/*zero_point=*/0},
/*padding=*/Padding_SAME,
/*stride_w=*/1, /*stride_h=*/1,
/*fused_activation_function=*/ActivationFunctionType_NONE, GetTestType(),
/*bias_type=*/TensorType_INT64);
model.SetInput({
// [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
3, 2, // batch = 0, y = 0, x = 0
1, -1, // batch = 0, y = 0, x = 1
-2, -3, // batch = 0, y = 0, x = 2
4, 3, // batch = 0, y = 1, x = 0
2, -2, // batch = 0, y = 1, x = 1
-3, -4, // batch = 0, y = 1, x = 2
});
model.SetFilter(filter_data);
model.SetBias({3, -2});
ASSERT_EQ(model.Invoke(), kTfLiteOk);
EXPECT_THAT(model.GetDequantizedOutput(),
ElementsAreArray(ArrayFloatNear(
{10, 35, 19, 24, -6, -41, 30, 64, 51, 40, -29, -64}, 0.19)));
// GetOutputShape() should always be same as model.SetOutputShape(...);
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3, 2}));
}
INSTANTIATE_TEST_SUITE_P(
TransposeConvOpTest, TransposeConvOpTest,
::testing::Combine(
......
......@@ -632,6 +632,8 @@ class TFLiteConverterBase:
self._experimental_tf_quantization_mode = None
# If unset, the bias defaults to int32, except for 16x8 quantization,
# where an int64 bias is used by default to prevent overflow.
# The accumulator type will match the bias type set by
# full_integer_quantization_bias_type.
self._experimental_full_integer_quantization_bias_type = None
# Provides specs for quantization, whether preset or custom.
self._experimental_quantization_options = None
......
......@@ -789,6 +789,9 @@ table Conv2DOptions {
fused_activation_function:ActivationFunctionType;
dilation_w_factor:int = 1;
dilation_h_factor:int = 1;
// Parameters for Conv2D version 8 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
quantized_bias_type: TensorType;
}
// Options for both Conv3D and Conv3DTranspose.
......@@ -896,6 +899,10 @@ table FullyConnectedOptions {
// If set to true, then weights-only op will use asymmetric quantization for
// inputs.
asymmetric_quantize_inputs: bool;
// Parameters for FullyConnected version 11 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
quantized_bias_type: TensorType;
}
table SoftmaxOptions {
......@@ -1155,6 +1162,11 @@ table TransposeConvOptions {
// Parameters supported by version 4:
fused_activation_function:ActivationFunctionType = NONE;
// Parameters for TransposeConv version 5 or above.
// When set, quantized_bias_type defines the dtype for both bias and accumulator.
quantized_bias_type: TensorType;
}
table ExpandDimsOptions {
......
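A minimal usage sketch of the new schema field through the generated builders further down (the wrapper function here is made up; CreateConv2DOptions and the enum values come from the generated code): leaving quantized_bias_type at its default, TensorType_FLOAT32, means "unset", while passing TensorType_INT32 requests 32-bit bias and accumulation.

#include "flatbuffers/flatbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Hypothetical helper: build Conv2DOptions that request an int32 bias and
// accumulator for a 16x8 quantized convolution.
flatbuffers::Offset<tflite::Conv2DOptions> BuildConv2DOptionsInt32Bias(
    flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateConv2DOptions(
      fbb, tflite::Padding_SAME, /*stride_w=*/1, /*stride_h=*/1,
      tflite::ActivationFunctionType_NONE,
      /*dilation_w_factor=*/1, /*dilation_h_factor=*/1,
      /*quantized_bias_type=*/tflite::TensorType_INT32);
}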
......@@ -7521,6 +7521,7 @@ struct Conv2DOptionsT : public ::flatbuffers::NativeTable {
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE;
int32_t dilation_w_factor = 1;
int32_t dilation_h_factor = 1;
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32;
};
struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
......@@ -7532,7 +7533,8 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VT_STRIDE_H = 8,
VT_FUSED_ACTIVATION_FUNCTION = 10,
VT_DILATION_W_FACTOR = 12,
VT_DILATION_H_FACTOR = 14
VT_DILATION_H_FACTOR = 14,
VT_QUANTIZED_BIAS_TYPE = 16
};
tflite::Padding padding() const {
return static_cast<tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
......@@ -7552,6 +7554,9 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
int32_t dilation_h_factor() const {
return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
}
tflite::TensorType quantized_bias_type() const {
return static_cast<tflite::TensorType>(GetField<int8_t>(VT_QUANTIZED_BIAS_TYPE, 0));
}
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
......@@ -7560,6 +7565,7 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR, 4) &&
VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR, 4) &&
VerifyField<int8_t>(verifier, VT_QUANTIZED_BIAS_TYPE, 1) &&
verifier.EndTable();
}
Conv2DOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
......@@ -7589,6 +7595,9 @@ struct Conv2DOptionsBuilder {
void add_dilation_h_factor(int32_t dilation_h_factor) {
fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
}
void add_quantized_bias_type(tflite::TensorType quantized_bias_type) {
fbb_.AddElement<int8_t>(Conv2DOptions::VT_QUANTIZED_BIAS_TYPE, static_cast<int8_t>(quantized_bias_type), 0);
}
explicit Conv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
......@@ -7607,12 +7616,14 @@ inline ::flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(
int32_t stride_h = 0,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
int32_t dilation_w_factor = 1,
int32_t dilation_h_factor = 1) {
int32_t dilation_h_factor = 1,
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) {
Conv2DOptionsBuilder builder_(_fbb);
builder_.add_dilation_h_factor(dilation_h_factor);
builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
builder_.add_quantized_bias_type(quantized_bias_type);
builder_.add_fused_activation_function(fused_activation_function);
builder_.add_padding(padding);
return builder_.Finish();
......@@ -8418,6 +8429,7 @@ struct FullyConnectedOptionsT : public ::flatbuffers::NativeTable {
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
bool keep_num_dims = false;
bool asymmetric_quantize_inputs = false;
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32;
};
struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
......@@ -8427,7 +8439,8 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Ta
VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_WEIGHTS_FORMAT = 6,
VT_KEEP_NUM_DIMS = 8,
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10
VT_ASYMMETRIC_QUANTIZE_INPUTS = 10,
VT_QUANTIZED_BIAS_TYPE = 12
};
tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
......@@ -8441,12 +8454,16 @@ struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Ta
bool asymmetric_quantize_inputs() const {
return GetField<uint8_t>(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0;
}
tflite::TensorType quantized_bias_type() const {
return static_cast<tflite::TensorType>(GetField<int8_t>(VT_QUANTIZED_BIAS_TYPE, 0));
}
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT, 1) &&
VerifyField<uint8_t>(verifier, VT_KEEP_NUM_DIMS, 1) &&
VerifyField<uint8_t>(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) &&
VerifyField<int8_t>(verifier, VT_QUANTIZED_BIAS_TYPE, 1) &&
verifier.EndTable();
}
FullyConnectedOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
......@@ -8470,6 +8487,9 @@ struct FullyConnectedOptionsBuilder {
void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) {
fbb_.AddElement<uint8_t>(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast<uint8_t>(asymmetric_quantize_inputs), 0);
}
void add_quantized_bias_type(tflite::TensorType quantized_bias_type) {
fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_QUANTIZED_BIAS_TYPE, static_cast<int8_t>(quantized_bias_type), 0);
}
explicit FullyConnectedOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
......@@ -8486,8 +8506,10 @@ inline ::flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
bool keep_num_dims = false,
bool asymmetric_quantize_inputs = false) {
bool asymmetric_quantize_inputs = false,
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) {
FullyConnectedOptionsBuilder builder_(_fbb);
builder_.add_quantized_bias_type(quantized_bias_type);
builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs);
builder_.add_keep_num_dims(keep_num_dims);
builder_.add_weights_format(weights_format);
......@@ -11061,6 +11083,7 @@ struct TransposeConvOptionsT : public ::flatbuffers::NativeTable {
int32_t stride_w = 0;
int32_t stride_h = 0;
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE;
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32;
};
struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
......@@ -11070,7 +11093,8 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Tab
VT_PADDING = 4,
VT_STRIDE_W = 6,
VT_STRIDE_H = 8,
VT_FUSED_ACTIVATION_FUNCTION = 10
VT_FUSED_ACTIVATION_FUNCTION = 10,
VT_QUANTIZED_BIAS_TYPE = 12
};
tflite::Padding padding() const {
return static_cast<tflite::Padding>(GetField<int8_t>(VT_PADDING, 0));
......@@ -11081,15 +11105,20 @@ struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Tab
int32_t stride_h() const {
return GetField<int32_t>(VT_STRIDE_H, 0);
}
tflite::ActivationFunctionType fused_activation_function() const {
return static_cast<tflite::ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
tflite::TensorType quantized_bias_type() const {
return static_cast<tflite::TensorType>(GetField<int8_t>(VT_QUANTIZED_BIAS_TYPE, 0));
}
bool Verify(::flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING, 1) &&
VerifyField<int32_t>(verifier, VT_STRIDE_W, 4) &&
VerifyField<int32_t>(verifier, VT_STRIDE_H, 4) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) &&
VerifyField<int8_t>(verifier, VT_QUANTIZED_BIAS_TYPE, 1) &&
verifier.EndTable();
}
TransposeConvOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
......@@ -11113,6 +11142,9 @@ struct TransposeConvOptionsBuilder {
void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(TransposeConvOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
void add_quantized_bias_type(tflite::TensorType quantized_bias_type) {
fbb_.AddElement<int8_t>(TransposeConvOptions::VT_QUANTIZED_BIAS_TYPE, static_cast<int8_t>(quantized_bias_type), 0);
}
explicit TransposeConvOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
......@@ -11129,11 +11161,13 @@ inline ::flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(
tflite::Padding padding = tflite::Padding_SAME,
int32_t stride_w = 0,
int32_t stride_h = 0,
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) {
tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE,
tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) {
TransposeConvOptionsBuilder builder_(_fbb);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
builder_.add_fused_activation_function(fused_activation_function);
builder_.add_quantized_bias_type(quantized_bias_type);
builder_.add_padding(padding);
return builder_.Finish();
}
......@@ -17188,6 +17222,7 @@ inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const ::flatbuffers::res
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }
{ auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }
{ auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; }
}
inline ::flatbuffers::Offset<Conv2DOptions> Conv2DOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
......@@ -17204,6 +17239,7 @@ inline ::flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(::flatbuffers::F
auto _fused_activation_function = _o->fused_activation_function;
auto _dilation_w_factor = _o->dilation_w_factor;
auto _dilation_h_factor = _o->dilation_h_factor;
auto _quantized_bias_type = _o->quantized_bias_type;
return tflite::CreateConv2DOptions(
_fbb,
_padding,
......@@ -17211,7 +17247,8 @@ inline ::flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(::flatbuffers::F
_stride_h,
_fused_activation_function,
_dilation_w_factor,
_dilation_h_factor);
_dilation_h_factor,
_quantized_bias_type);
}
inline Conv3DOptionsT *Conv3DOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
......@@ -17545,6 +17582,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const ::
{ auto _e = weights_format(); _o->weights_format = _e; }
{ auto _e = keep_num_dims(); _o->keep_num_dims = _e; }
{ auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; }
{ auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; }
}
inline ::flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
......@@ -17559,12 +17597,14 @@ inline ::flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
auto _weights_format = _o->weights_format;
auto _keep_num_dims = _o->keep_num_dims;
auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs;
auto _quantized_bias_type = _o->quantized_bias_type;
return tflite::CreateFullyConnectedOptions(
_fbb,
_fused_activation_function,
_weights_format,
_keep_num_dims,
_asymmetric_quantize_inputs);
_asymmetric_quantize_inputs,
_quantized_bias_type);
}
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
......@@ -18838,6 +18878,7 @@ inline void TransposeConvOptions::UnPackTo(TransposeConvOptionsT *_o, const ::fl
{ auto _e = stride_w(); _o->stride_w = _e; }
{ auto _e = stride_h(); _o->stride_h = _e; }
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
{ auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; }
}
inline ::flatbuffers::Offset<TransposeConvOptions> TransposeConvOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
......@@ -18852,12 +18893,14 @@ inline ::flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(::
auto _stride_w = _o->stride_w;
auto _stride_h = _o->stride_h;
auto _fused_activation_function = _o->fused_activation_function;
auto _quantized_bias_type = _o->quantized_bias_type;
return tflite::CreateTransposeConvOptions(
_fbb,
_padding,
_stride_w,
_stride_h,
_fused_activation_function);
_fused_activation_function,
_quantized_bias_type);
}
inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
......@@ -272,6 +272,38 @@ std::unordered_set<string> PopulateRealValueOpSet(
return real_value_op_set;
}
// Set quantized_bias_type for CONV_2D/FULLY_CONNECTED/TRANSPOSE_CONV so that
// the accumulator is initialized to the appropriate default value when the bias
// is NULL.
void SetOperatorPropertyBiasType(ModelT* model, const TensorType& bias_type) {
for (int subgraph_idx = 0, end = model->subgraphs.size(); subgraph_idx < end;
subgraph_idx++) {
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
// Iterate backward to avoid messing with index.
for (int op_idx = subgraph->operators.size() - 1; op_idx >= 0; op_idx--) {
OperatorT* op = subgraph->operators[op_idx].get();
OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get();
if (op_code && op_code->builtin_code == BuiltinOperator_FULLY_CONNECTED) {
auto* options = op->builtin_options.AsFullyConnectedOptions();
if (options) {
options->quantized_bias_type = bias_type;
}
} else if (op_code && op_code->builtin_code == BuiltinOperator_CONV_2D) {
auto* options = op->builtin_options.AsConv2DOptions();
if (options) {
options->quantized_bias_type = bias_type;
}
} else if (op_code &&
op_code->builtin_code == BuiltinOperator_TRANSPOSE_CONV) {
auto* options = op->builtin_options.AsTransposeConvOptions();
if (options) {
options->quantized_bias_type = bias_type;
}
}
}
}
}
TfLiteStatus QuantizeBias(ModelT* model, const TensorT* input_tensor,
const TensorT* weight_tensor, TensorT* bias_tensor,
bool is_per_channel, int channel_dim_index,
......@@ -1923,6 +1955,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
TF_LITE_ENSURE_STATUS(ApplyConstraints(model, operator_names,
real_value_op_set, activations_type,
error_reporter));
SetOperatorPropertyBiasType(model, bias_type);
TF_LITE_ENSURE_STATUS(QuantizeBiases(model, operator_names, real_value_op_set,
activations_type, bias_type,
disable_per_channel, error_reporter));
......
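The net effect of SetOperatorPropertyBiasType plus the call above is that, after quantization, every CONV_2D, FULLY_CONNECTED and TRANSPOSE_CONV operator in the model carries the requested bias type. A hedged sketch of checking this on a quantized ModelT, where the variables model and bias_type are assumptions rather than code from this change:
// Assumes <cassert> and the mutable schema types used above.
for (const auto& subgraph : model->subgraphs) {
  for (const auto& op : subgraph->operators) {
    // Only the three affected builtins carry the field; other ops return null.
    if (auto* opts = op->builtin_options.AsConv2DOptions()) {
      assert(opts->quantized_bias_type == bias_type);
    }
  }
}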
......@@ -27,6 +27,7 @@ cc_library(
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:minimal_logging",
"//tensorflow/lite/core/c:c_api_types",
"//tensorflow/lite/core/c:common",
"//tensorflow/lite/kernels/internal:compatibility",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_fbs_with_mutable",
......@@ -47,6 +48,7 @@ tf_cc_test(
":versioning",
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite/core/c:c_api_types",
"//tensorflow/lite/core/c:common",
"//tensorflow/lite/core/kernels:builtin_ops",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/schema:schema_fbs_with_mutable",
......
......@@ -24,8 +24,10 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/mutable/schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
......@@ -53,7 +55,19 @@ int GetInputMaxDims(const OpSignature& op_sig) {
int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
switch (op_sig.op) {
case BuiltinOperator_CONV_2D:
case BuiltinOperator_CONV_2D: {
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
op_sig.inputs.at(1).type == kTfLiteInt8 &&
op_sig.outputs.at(0).type == kTfLiteInt16) {
// `quantized_bias_type` is supported at version 8.
auto conv_params =
reinterpret_cast<TfLiteConvParams*>(op_sig.builtin_data);
TFLITE_DCHECK(conv_params != nullptr);
if (conv_params->quantized_bias_type) {
return 8;
}
}
if (op_sig.ext_options.conv_2d.is_grouped_convolution) {
return 6;
}
......@@ -90,7 +104,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
return 2;
}
return 1;
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
// If the op accepts int16, we return version 5.
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
......@@ -155,6 +169,19 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
// | Quantized Int8 | 4 | 4 |
// +-----------------+--------------------+--------------------------+
auto fully_connected_params =
reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
TFLITE_DCHECK(fully_connected_params != nullptr);
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
op_sig.inputs.at(1).type == kTfLiteInt8 &&
op_sig.outputs.at(0).type == kTfLiteInt16) {
// `quantized_bias_type` is supported at version 11.
if (fully_connected_params->quantized_bias_type) {
return 11;
}
}
// FullyConnected with sparse weight is supported at version 8.
if (op_sig.ext_options.fully_connected.sparse_weight) {
return 8;
......@@ -172,9 +199,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
if (op_sig.inputs.size() == 2) {
return 6;
}
auto fully_connected_params =
reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
TFLITE_DCHECK(fully_connected_params != nullptr);
// `keep_num_dims` is supported at version 5.
if (fully_connected_params->keep_num_dims) {
return 5;
......@@ -335,6 +359,15 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
case BuiltinOperator_TRANSPOSE_CONV: {
auto transpose_conv_params =
reinterpret_cast<TfLiteTransposeConvParams*>(op_sig.builtin_data);
if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
op_sig.inputs.at(1).type == kTfLiteInt8 &&
op_sig.outputs.at(0).type == kTfLiteInt16) {
// `quantized_bias_type` is supported at version 5.
TFLITE_DCHECK(transpose_conv_params != nullptr);
if (transpose_conv_params->quantized_bias_type) {
return 5;
}
}
// TransposeConvOp has fused activation function from version 4.
if (transpose_conv_params != nullptr &&
......
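All three version bumps in op_version.cc share the same gate; the helper below is a hypothetical condensation for readability, not a function introduced by this change:
// Condensed form of the gating logic above: the new op version is only
// reported for 16x8 signatures that explicitly set a quantized bias type
// (kTfLiteNoType is 0, so the upstream code tests the field's truthiness).
bool NeedsQuantizedBiasVersion(const OpSignature& op_sig,
                               TfLiteType quantized_bias_type) {
  return op_sig.inputs.at(0).type == kTfLiteInt16 &&
         op_sig.inputs.at(1).type == kTfLiteInt8 &&
         op_sig.outputs.at(0).type == kTfLiteInt16 &&
         quantized_bias_type != kTfLiteNoType;
}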
......@@ -19,6 +19,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/schema/schema_generated.h"
......@@ -698,6 +699,16 @@ TEST(OpVersionTest, VersioningFullyConnectedTest) {
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
fully_connected_params.asymmetric_quantize_inputs = true;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 9);
fake_op_sig = {
.op = BuiltinOperator_FULLY_CONNECTED,
.inputs = CreateOpSignatureTensorSpecs(
std::vector<TfLiteType>{kTfLiteInt16, kTfLiteInt8}),
.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16),
.builtin_data = reinterpret_cast<void*>(&fully_connected_params),
};
fully_connected_params.quantized_bias_type = kTfLiteInt32;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 11);
}
TEST(OpVersionTest, VersioningDequantizeTest) {
......@@ -789,6 +800,17 @@ TEST(OpVersionTest, VersioningConv2DTest) {
fake_op_sig.ext_options.conv_2d.is_grouped_convolution = true;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 6);
TfLiteConvParams conv_params = {};
fake_op_sig = {
.op = BuiltinOperator_CONV_2D,
.inputs = CreateOpSignatureTensorSpecs(
std::vector<TfLiteType>{kTfLiteInt16, kTfLiteInt8}),
.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16),
.builtin_data = reinterpret_cast<void*>(&conv_params),
};
conv_params.quantized_bias_type = kTfLiteInt32;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 8);
}
TEST(OpVersionTest, VersioningFloorDivOperatorTest) {
......@@ -864,6 +886,17 @@ TEST(OpVersionTest, VersioningTransposeConvOperatorTest) {
.builtin_data = reinterpret_cast<void*>(&transpose_conv_params),
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
transpose_conv_params = {};
fake_op_sig = {
.op = BuiltinOperator_TRANSPOSE_CONV,
.inputs = CreateOpSignatureTensorSpecs(
std::vector<TfLiteType>{kTfLiteInt16, kTfLiteInt8}),
.outputs = CreateOpSignatureTensorSpecs(kTfLiteInt16),
.builtin_data = reinterpret_cast<void*>(&transpose_conv_params),
};
transpose_conv_params.quantized_bias_type = kTfLiteInt32;
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 5);
}
TEST(OpVersionTest, VersioningSVDFOperatorTest) {
......
......@@ -73,6 +73,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_CONV_2D, 5}, "2.4.0"},
{{BuiltinOperator_CONV_2D, 6}, "2.9.0"},
{{BuiltinOperator_CONV_2D, 7}, "2.11.0"},
{{BuiltinOperator_CONV_2D, 8}, "2.15.0"},
{{BuiltinOperator_DEPTHWISE_CONV_2D, 1}, "1.5.0"},
{{BuiltinOperator_DEPTHWISE_CONV_2D, 2}, "1.12.0"},
{{BuiltinOperator_DEPTHWISE_CONV_2D, 3}, "1.14.0"},
......@@ -130,6 +131,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_FULLY_CONNECTED, 8}, "2.3.0"},
{{BuiltinOperator_FULLY_CONNECTED, 9}, "2.3.0"},
{{BuiltinOperator_FULLY_CONNECTED, 10}, "2.11.0"},
{{BuiltinOperator_FULLY_CONNECTED, 11}, "2.15.0"},
{{BuiltinOperator_GATHER, 1}, "1.6.0"},
{{BuiltinOperator_GATHER, 2}, "1.14.0"},
{{BuiltinOperator_GATHER, 3}, "1.15.0"},
......@@ -265,6 +267,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_TRANSPOSE_CONV, 2}, "2.2.0"},
{{BuiltinOperator_TRANSPOSE_CONV, 3}, "2.3.0"},
{{BuiltinOperator_TRANSPOSE_CONV, 4}, "2.13.0"},
{{BuiltinOperator_TRANSPOSE_CONV, 5}, "2.15.0"},
{{BuiltinOperator_SPARSE_TO_DENSE, 1}, "1.9.0"},
{{BuiltinOperator_SPARSE_TO_DENSE, 2}, "1.14.0"},
{{BuiltinOperator_SPARSE_TO_DENSE, 3}, "1.15.0"},
......
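With the three new table entries, FindMinimumRuntimeVersionForOp reports 2.15.0 as the minimum runtime for models using the bumped versions. A test-style sketch of the expected lookups, assuming the second parameter is the op version (the full signature is not visible in the fragment above):
EXPECT_EQ(FindMinimumRuntimeVersionForOp(BuiltinOperator_CONV_2D, 8), "2.15.0");
EXPECT_EQ(FindMinimumRuntimeVersionForOp(BuiltinOperator_FULLY_CONNECTED, 11), "2.15.0");
EXPECT_EQ(FindMinimumRuntimeVersionForOp(BuiltinOperator_TRANSPOSE_CONV, 5), "2.15.0");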