Unverified commit b942395d, authored by cad-audio, committed by GitHub

REF_CODE_REFACTOR: hard_swish (#307)

* REF_CODE_REFACTOR: hard_swish
Refactoring the reference code for the hard_swish operator.

BUG=refactoring existing code.

* Remove uint8 support.

* fix formatting.
Co-authored-by: Advait Jain <advaitjain@google.com>
Parent: 4cef7ee3
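For context: hard-swish is the activation x * relu6(x + 3) / 6. A minimal float sketch of the math this kernel's reference path implements (illustrative only, not the TFLM kernel itself):

#include <algorithm>

// hard_swish(x) = x * relu6(x + 3) / 6
float HardSwishRef(float x) {
  const float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
  return x * (relu6 / 6.0f);
}

The quantized int8 path computes the same function entirely in fixed point, using the multipliers set up in HardSwishPrepare below.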
tensorflow/lite/micro/kernels/BUILD
@@ -144,6 +144,7 @@ cc_library(
         "gather.cc",
         "gather_nd.cc",
         "hard_swish.cc",
+        "hard_swish_common.cc",
         "if.cc",
         "l2_pool_2d.cc",
         "l2norm.cc",
@@ -192,6 +193,7 @@ cc_library(
         "depthwise_conv.h",
         "ethosu.h",
         "fully_connected.h",
+        "hard_swish.h",
         "logical.h",
         "logistic.h",
         "micro_ops.h",
tensorflow/lite/micro/kernels/hard_swish.cc
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,72 +23,23 @@ limitations under the License.
 #include "tensorflow/lite/kernels/internal/types.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 #include "tensorflow/lite/kernels/op_macros.h"
+#include "tensorflow/lite/micro/kernels/hard_swish.h"
 #include "tensorflow/lite/micro/kernels/kernel_util.h"
+#include "tensorflow/lite/micro/micro_error_reporter.h"
 #include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
-namespace ops {
-namespace micro {
-namespace hard_swish {
-
-constexpr int kInputTensor = 0;
-constexpr int kOutputTensor = 0;
+namespace {
 
 void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(HardSwishParams));
 }
 
-TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
-  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
-  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
-    HardSwishParams* params = static_cast<HardSwishParams*>(node->user_data);
-    params->input_zero_point = input->params.zero_point;
-    params->output_zero_point = output->params.zero_point;
-    const float input_scale = input->params.scale;
-    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
-    const float reluish_scale = 3.0f / 32768.0f;
-    const float output_scale = output->params.scale;
-    const double output_multiplier =
-        static_cast<double>(hires_input_scale / output_scale);
-    int32_t output_multiplier_fixedpoint_int32;
-    QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
-                       &params->output_multiplier_exponent);
-    DownScaleInt32ToInt16Multiplier(
-        output_multiplier_fixedpoint_int32,
-        &params->output_multiplier_fixedpoint_int16);
-    TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0);
-    const double reluish_multiplier =
-        static_cast<double>(hires_input_scale / reluish_scale);
-    int32_t reluish_multiplier_fixedpoint_int32;
-    QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
-                       &params->reluish_multiplier_exponent);
-    DownScaleInt32ToInt16Multiplier(
-        reluish_multiplier_fixedpoint_int32,
-        &params->reluish_multiplier_fixedpoint_int16);
-  }
-  return kTfLiteOk;
-}
 
 TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kHardSwishInputTensor);
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kHardSwishOutputTensor);
   HardSwishParams* params = static_cast<HardSwishParams*>(node->user_data);
 
   switch (input->type) {
@@ -99,13 +50,6 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorShape(output),
           tflite::micro::GetTensorData<float>(output));
     } break;
-    case kTfLiteUInt8: {
-      tflite::reference_ops::HardSwish<uint8_t>(
-          *params, tflite::micro::GetTensorShape(input),
-          tflite::micro::GetTensorData<uint8_t>(input),
-          tflite::micro::GetTensorShape(output),
-          tflite::micro::GetTensorData<uint8_t>(output));
-    } break;
     case kTfLiteInt8: {
      tflite::reference_ops::HardSwish<int8_t>(
          *params, tflite::micro::GetTensorShape(input),
@@ -114,29 +58,24 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
           tflite::micro::GetTensorData<int8_t>(output));
     } break;
     default: {
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Only float32/int8_t/uint8_t are supported currently, got %s",
-          TfLiteTypeGetName(input->type));
+      MicroPrintf("Unsupported type %s", TfLiteTypeGetName(input->type));
       return kTfLiteError;
     }
   }
   return kTfLiteOk;
 }
 
-}  // namespace hard_swish
+}  // namespace
 
 TfLiteRegistration Register_HARD_SWISH() {
-  return {/*init=*/hard_swish::HardSwishInit,
+  return {/*init=*/HardSwishInit,
           /*free=*/nullptr,
-          /*prepare=*/hard_swish::HardSwishPrepare,
-          /*invoke=*/hard_swish::HardSwishEval,
+          /*prepare=*/tflite::HardSwishPrepare,
+          /*invoke=*/HardSwishEval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
           /*custom_name=*/nullptr,
           /*version=*/0};
 }
 
-}  // namespace micro
-}  // namespace ops
 }  // namespace tflite
tensorflow/lite/micro/kernels/hard_swish.h (new file)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_HARD_SWISH_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_HARD_SWISH_H_

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"

namespace tflite {

extern const int kHardSwishInputTensor;
extern const int kHardSwishOutputTensor;

TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node);

}  // namespace tflite

#endif  // TENSORFLOW_LITE_MICRO_KERNELS_HARD_SWISH_H_
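This header is the point of the refactor: it lets an optimized, platform-specific kernel reuse the shared HardSwishPrepare and tensor-index constants while supplying its own Init/Eval. A hedged sketch of what such a port's registration could look like; MyPortHardSwishInit and MyPortHardSwishEval are hypothetical names, while the rest mirrors the registration in hard_swish.cc above:

// Hypothetical port-specific registration that reuses the shared Prepare.
TfLiteRegistration Register_HARD_SWISH() {
  return {/*init=*/MyPortHardSwishInit,          // port-specific (hypothetical)
          /*free=*/nullptr,
          /*prepare=*/tflite::HardSwishPrepare,  // shared, hard_swish_common.cc
          /*invoke=*/MyPortHardSwishEval,        // port-specific (hypothetical)
          /*profiling_string=*/nullptr,
          /*builtin_code=*/0,
          /*custom_name=*/nullptr,
          /*version=*/0};
}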
tensorflow/lite/micro/kernels/hard_swish_common.cc (new file)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/hard_swish.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {

const int kHardSwishInputTensor = 0;
const int kHardSwishOutputTensor = 0;

TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, kHardSwishInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kHardSwishOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);

  if (input->type == kTfLiteInt8) {
    HardSwishParams* params = static_cast<HardSwishParams*>(node->user_data);

    params->input_zero_point = input->params.zero_point;
    params->output_zero_point = output->params.zero_point;

    const float input_scale = input->params.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = output->params.scale;

    const double output_multiplier =
        static_cast<double>(hires_input_scale / output_scale);
    int32_t output_multiplier_fixedpoint_int32;
    QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                       &params->output_multiplier_exponent);
    DownScaleInt32ToInt16Multiplier(
        output_multiplier_fixedpoint_int32,
        &params->output_multiplier_fixedpoint_int16);
    TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0);

    const double reluish_multiplier =
        static_cast<double>(hires_input_scale / reluish_scale);
    int32_t reluish_multiplier_fixedpoint_int32;
    QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                       &params->reluish_multiplier_exponent);
    DownScaleInt32ToInt16Multiplier(
        reluish_multiplier_fixedpoint_int32,
        &params->reluish_multiplier_fixedpoint_int16);
  }

  return kTfLiteOk;
}

}  // namespace tflite
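The int8 branch above prepares two fixed-point multipliers: the input is first rescaled into a high-resolution domain (scale input_scale / 128), from which one multiplier maps onto the output scale and another onto the "reluish" scale 3 / 32768 used for the relu6(x + 3) term. QuantizeMultiplier decomposes each real ratio into a Q31 mantissa plus a power-of-two exponent, and DownScaleInt32ToInt16Multiplier narrows the mantissa to Q15 for 16-bit arithmetic. A standalone sketch of that decomposition, assuming frexp-style rounding (the real QuantizeMultiplier in quantization_util.h handles more edge cases):

#include <cmath>
#include <cstdint>

// Sketch: express m as q * 2^(shift - 31), with q a Q31 value in [2^30, 2^31).
void QuantizeMultiplierSketch(double m, int32_t* q, int* shift) {
  if (m == 0.0) { *q = 0; *shift = 0; return; }
  const double frac = std::frexp(m, shift);  // m == frac * 2^shift, frac in [0.5, 1)
  int64_t q64 = static_cast<int64_t>(std::round(frac * (1ll << 31)));
  if (q64 == (1ll << 31)) {  // frac rounded up to exactly 1.0
    q64 /= 2;
    ++(*shift);
  }
  *q = static_cast<int32_t>(q64);
}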
tensorflow/lite/micro/kernels/hard_swish_test.cc
@@ -105,8 +105,7 @@ void TestHardSwishQuantized(int size, const T* output_data,
   int outputs_array_data[] = {1, 1};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
-  const TfLiteRegistration registration =
-      tflite::ops::micro::Register_HARD_SWISH();
+  const TfLiteRegistration registration = tflite::Register_HARD_SWISH();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                              outputs_array, /*builtin_data=*/nullptr);
@@ -184,8 +183,7 @@ void TestHardSwishQuantizedBias(const int size, const T* output_data,
   int outputs_array_data[] = {1, 1};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
-  const TfLiteRegistration registration =
-      tflite::ops::micro::Register_HARD_SWISH();
+  const TfLiteRegistration registration = tflite::Register_HARD_SWISH();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                              outputs_array, /*builtin_data=*/nullptr);
@@ -236,8 +234,7 @@ void TestHardSwishFloat(const int size, float* output_data,
   int outputs_array_data[] = {1, 1};
   TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
 
-  const TfLiteRegistration registration =
-      tflite::ops::micro::Register_HARD_SWISH();
+  const TfLiteRegistration registration = tflite::Register_HARD_SWISH();
   micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
                              outputs_array, /*builtin_data=*/nullptr);
@@ -294,49 +291,4 @@ TF_LITE_MICRO_TEST(SimpleHardSwishTestInt8) {
   }
 }
 
-TF_LITE_MICRO_TEST(SimpleHardSwishTestUint8) {
-  std::minstd_rand random_engine;
-  constexpr int size = 99;
-  constexpr int pairs = 4, one_pair = 2;
-  constexpr float minmax_pairs[pairs][one_pair] = {
-      {0.f, 1.f}, {-2.f, 1.f}, {-5.f, 10.f}, {-40.f, 60.f}};
-  uint8_t output_data[size] = {0};
-  uint8_t input_data_quantized[size] = {0};
-  float dequantized_output[size] = {0.f};
-  float input_values[size] = {0.f};
-  float output_values[size] = {0.f};
-  for (int x = 0; x < pairs; x++) {
-    for (int y = 0; y < pairs; y++) {
-      float input_min = minmax_pairs[x][0];
-      float input_max = minmax_pairs[x][1];
-      float output_min = minmax_pairs[y][0];
-      float output_max = minmax_pairs[y][1];
-      tflite::testing::TestHardSwishQuantized<uint8_t>(
-          size, output_data, input_data_quantized, dequantized_output,
-          input_min, input_max, output_min, output_max, &random_engine,
-          input_values, output_values);
-    }
-  }
-}
-
-// See the comment in the reference implementation of quantized HardSwish:
-// A numerical issue significantly affecting ImageNet classification accuracy
-// with MobileNet v3 is only observable at the scale of HardSwish unit tests
-// if we monitor specifically bias. This testcase is extracted from one of the
-// HardSwish nodes in that MobileNet v3 that exhibited this issue.
-TF_LITE_MICRO_TEST(SimpleHardSwishTestQuantizedBias) {
-  constexpr int size = 43;
-  uint8_t output_data[size] = {0};
-  uint8_t input_data_quantized[size] = {0};
-  float dequantized_output[size] = {0.f};
-  float input_values[size] = {0.f};
-  float output_values[size] = {0.f};
-  tflite::testing::TestHardSwishQuantizedBias<uint8_t>(
-      size, output_data, input_data_quantized, dequantized_output, -11.654928f,
-      25.036512f, -0.3905796f, 24.50887f, 0.035, input_values, output_values);
-}
 
 TF_LITE_MICRO_TESTS_END
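The surviving int8 and float tests keep the same round-trip structure the deleted uint8 tests had: quantize reference inputs, run the kernel through KernelRunner, dequantize, and compare against the float reference. A condensed sketch of that driver pattern (InitAndPrepare/Invoke assumed from tflite-micro's KernelRunner test utility; tensors and the int arrays are set up as in the test bodies above):

const TfLiteRegistration registration = tflite::Register_HARD_SWISH();
tflite::micro::KernelRunner runner(registration, tensors, tensors_size,
                                   inputs_array, outputs_array,
                                   /*builtin_data=*/nullptr);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
// Dequantize output_data and compare to the float reference within tolerance.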
tensorflow/lite/micro/kernels/micro_ops.h
@@ -47,6 +47,7 @@ TfLiteRegistration Register_FLOOR_DIV();
 TfLiteRegistration Register_FLOOR_MOD();
 TfLiteRegistration Register_GATHER();
 TfLiteRegistration Register_GATHER_ND();
+TfLiteRegistration Register_HARD_SWISH();
 TfLiteRegistration Register_IF();
 TfLiteRegistration Register_L2_POOL_2D();
 TfLiteRegistration Register_LEAKY_RELU();
@@ -85,7 +86,6 @@ TfLiteRegistration Register_EQUAL();
 TfLiteRegistration Register_FLOOR();
 TfLiteRegistration Register_GREATER();
 TfLiteRegistration Register_GREATER_EQUAL();
-TfLiteRegistration Register_HARD_SWISH();
 TfLiteRegistration Register_LESS();
 TfLiteRegistration Register_LESS_EQUAL();
 TfLiteRegistration Register_LOG();
tensorflow/lite/micro/micro_mutable_op_resolver.h
@@ -278,8 +278,7 @@ class MicroMutableOpResolver : public MicroOpResolver {
   }
 
   TfLiteStatus AddHardSwish() {
-    return AddBuiltin(BuiltinOperator_HARD_SWISH,
-                      tflite::ops::micro::Register_HARD_SWISH(),
+    return AddBuiltin(BuiltinOperator_HARD_SWISH, tflite::Register_HARD_SWISH(),
                       ParseHardSwish);
   }
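Client code is unaffected by the namespace move: applications still register the op through the resolver. A minimal usage sketch (the template argument is the resolver's operator capacity, chosen arbitrarily here):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

tflite::MicroMutableOpResolver<1> resolver;  // room for one registration
if (resolver.AddHardSwish() != kTfLiteOk) {
  // registration failed (e.g., capacity exceeded)
}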
tensorflow/lite/micro/tools/make/Makefile
@@ -380,6 +380,7 @@ tensorflow/lite/micro/kernels/fully_connected_common.cc \
 tensorflow/lite/micro/kernels/gather.cc \
 tensorflow/lite/micro/kernels/gather_nd.cc \
 tensorflow/lite/micro/kernels/hard_swish.cc \
+tensorflow/lite/micro/kernels/hard_swish_common.cc \
 tensorflow/lite/micro/kernels/if.cc \
 tensorflow/lite/micro/kernels/kernel_runner.cc \
 tensorflow/lite/micro/kernels/kernel_util.cc \