未验证 提交 6389e83f 编写于 作者: cad-audio 提交者: GitHub

REF_CODE_REFACTOR: logistic (#308)

* REF_CODE_REFACTOR: logistic
Refactoring the reference code for logistic operator.

BUG=refactoring existing code.

* Fix formatting
Co-authored-by: Advait Jain <advaitjain@users.noreply.github.com>
Co-authored-by: Advait Jain <advaitjain@google.com>
上级 4eae8093
...@@ -151,6 +151,7 @@ cc_library( ...@@ -151,6 +151,7 @@ cc_library(
"log_softmax.cc", "log_softmax.cc",
"logical.cc", "logical.cc",
"logistic.cc", "logistic.cc",
"logistic_common.cc",
"maximum_minimum.cc", "maximum_minimum.cc",
"mul.cc", "mul.cc",
"neg.cc", "neg.cc",
...@@ -190,6 +191,7 @@ cc_library( ...@@ -190,6 +191,7 @@ cc_library(
"depthwise_conv.h", "depthwise_conv.h",
"ethosu.h", "ethosu.h",
"fully_connected.h", "fully_connected.h",
"logistic.h",
"micro_ops.h", "micro_ops.h",
"pooling.h", "pooling.h",
"quantize.h", "quantize.h",
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -24,71 +24,24 @@ limitations under the License. ...@@ -24,71 +24,24 @@ limitations under the License.
#include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/logistic.h"
namespace tflite { namespace tflite {
namespace ops {
namespace micro {
namespace activations {
namespace { namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
struct OpData {
int32_t input_zero_point;
int32_t input_range_radius;
int32_t input_multiplier;
int input_left_shift;
};
TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
OpData* data) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point,
std::numeric_limits<int8_t>::min());
static constexpr int kInputIntegerBits = 4;
const double input_real_multiplier =
static_cast<double>(input->params.scale) *
static_cast<double>(1 << (31 - kInputIntegerBits));
data->input_zero_point = input->params.zero_point;
const double q = std::frexp(input_real_multiplier, &data->input_left_shift);
data->input_multiplier = static_cast<int32_t>(TfLiteRound(q * (1ll << 31)));
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
}
return kTfLiteOk;
}
} // namespace
void* LogisticInit(TfLiteContext* context, const char* buffer, size_t length) { void* LogisticInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData)); return context->AllocatePersistentBuffer(context, sizeof(OpDataLogistic));
}
TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
return CalculateArithmeticOpData(context, node, data);
} }
TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor); tflite::micro::GetEvalInput(context, node, kLogisticInputTensor);
TfLiteEvalTensor* output = TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor); tflite::micro::GetEvalOutput(context, node, kLogisticOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr); TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data); OpDataLogistic* data = static_cast<OpDataLogistic*>(node->user_data);
if (input->type == kTfLiteFloat32) { if (input->type == kTfLiteFloat32) {
switch (output->type) { switch (output->type) {
...@@ -133,18 +86,16 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) { ...@@ -133,18 +86,16 @@ TfLiteStatus LogisticEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk; return kTfLiteOk;
} }
} // namespace activations } // namespace
TfLiteRegistration Register_LOGISTIC() { TfLiteRegistration Register_LOGISTIC() {
return {/*init=*/activations::LogisticInit, return {/*init=*/LogisticInit,
/*free=*/nullptr, /*free=*/nullptr,
/*prepare=*/activations::LogisticPrepare, /*prepare=*/LogisticPrepare,
/*invoke=*/activations::LogisticEval, /*invoke=*/LogisticEval,
/*profiling_string=*/nullptr, /*profiling_string=*/nullptr,
/*builtin_code=*/0, /*builtin_code=*/0,
/*custom_name=*/nullptr, /*custom_name=*/nullptr,
/*version=*/0}; /*version=*/0};
} }
} // namespace micro
} // namespace ops
} // namespace tflite } // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LOGISTIC_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_LOGISTIC_H_
#include <cstdint>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
namespace tflite {
// Tensor indices shared by all logistic kernel implementations; the
// definitions live in logistic_common.cc (both are index 0).
extern const int kLogisticInputTensor;
extern const int kLogisticOutputTensor;
// Quantization parameters for the int8 logistic kernel, computed once in
// Prepare and stored in the node's persistent user_data buffer.
struct OpDataLogistic {
  int32_t input_zero_point;    // Zero point copied from the input tensor.
  int32_t input_range_radius;  // Computed via CalculateInputRadius in Prepare.
  int32_t input_multiplier;    // Q31 fixed-point multiplier for the input.
  int input_left_shift;        // Shift produced by frexp of the multiplier.
};
// Fills |data| with the fixed-point parameters above when the input/output
// type is int8; for other types it validates tensors and leaves |data|
// untouched, still returning kTfLiteOk.
TfLiteStatus CalculateArithmeticOpDataLogistic(TfLiteContext* context,
                                               TfLiteNode* node,
                                               OpDataLogistic* data);
// Shared Prepare implementation: populates the OpDataLogistic that Init
// allocated into node->user_data.
TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node);
}  // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LOGISTIC_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/logistic.h"
namespace tflite {
// The logistic op has exactly one input and one output, both at index 0.
const int kLogisticInputTensor = 0;
const int kLogisticOutputTensor = 0;
// Validates the input/output tensors and, for int8 data, derives the
// fixed-point parameters (zero point, Q31 multiplier, shift, range radius)
// that the quantized logistic reference kernel consumes.
TfLiteStatus CalculateArithmeticOpDataLogistic(TfLiteContext* context,
                                               TfLiteNode* node,
                                               OpDataLogistic* data) {
  const TfLiteTensor* input = GetInput(context, node, kLogisticInputTensor);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, kLogisticOutputTensor);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

  // Only int8 needs precomputed parameters; everything else is a no-op.
  if (input->type != kTfLiteInt8) {
    return kTfLiteOk;
  }

  // int8 logistic output must use the full range, i.e. zero point of -128.
  TF_LITE_ENSURE_EQ(context, output->params.zero_point,
                    std::numeric_limits<int8_t>::min());

  static constexpr int kInputIntegerBits = 4;
  const double real_multiplier =
      static_cast<double>(input->params.scale) *
      static_cast<double>(1 << (31 - kInputIntegerBits));

  data->input_zero_point = input->params.zero_point;
  // Split the multiplier into a normalized fraction and a power-of-two shift,
  // then round the fraction into a Q31 integer.
  const double fraction = std::frexp(real_multiplier, &data->input_left_shift);
  data->input_multiplier =
      static_cast<int32_t>(TfLiteRound(fraction * (1ll << 31)));
  data->input_range_radius =
      CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 31);
  return kTfLiteOk;
}
// Prepare hook shared by logistic kernel variants: fills in the persistent
// OpDataLogistic buffer that LogisticInit allocated into node->user_data.
TfLiteStatus LogisticPrepare(TfLiteContext* context, TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  auto* op_data = static_cast<OpDataLogistic*>(node->user_data);
  return CalculateArithmeticOpDataLogistic(context, node, op_data);
}
} // namespace tflite
...@@ -55,8 +55,7 @@ void ValidateLogisticGoldens(TfLiteTensor* tensors, const int tensor_count, ...@@ -55,8 +55,7 @@ void ValidateLogisticGoldens(TfLiteTensor* tensors, const int tensor_count,
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const TfLiteRegistration registration = const TfLiteRegistration registration = tflite::Register_LOGISTIC();
tflite::ops::micro::Register_LOGISTIC();
micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array, micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
outputs_array, nullptr); outputs_array, nullptr);
......
...@@ -51,6 +51,7 @@ TfLiteRegistration Register_IF(); ...@@ -51,6 +51,7 @@ TfLiteRegistration Register_IF();
TfLiteRegistration Register_L2_POOL_2D(); TfLiteRegistration Register_L2_POOL_2D();
TfLiteRegistration Register_LEAKY_RELU(); TfLiteRegistration Register_LEAKY_RELU();
TfLiteRegistration Register_LOG_SOFTMAX(); TfLiteRegistration Register_LOG_SOFTMAX();
TfLiteRegistration Register_LOGISTIC();
TfLiteRegistration Register_MAX_POOL_2D(); TfLiteRegistration Register_MAX_POOL_2D();
TfLiteRegistration Register_QUANTIZE(); TfLiteRegistration Register_QUANTIZE();
TfLiteRegistration Register_RELU(); TfLiteRegistration Register_RELU();
...@@ -89,7 +90,6 @@ TfLiteRegistration Register_LOG(); ...@@ -89,7 +90,6 @@ TfLiteRegistration Register_LOG();
TfLiteRegistration Register_LOGICAL_AND(); TfLiteRegistration Register_LOGICAL_AND();
TfLiteRegistration Register_LOGICAL_NOT(); TfLiteRegistration Register_LOGICAL_NOT();
TfLiteRegistration Register_LOGICAL_OR(); TfLiteRegistration Register_LOGICAL_OR();
TfLiteRegistration Register_LOGISTIC();
TfLiteRegistration Register_MAXIMUM(); TfLiteRegistration Register_MAXIMUM();
TfLiteRegistration Register_MEAN(); TfLiteRegistration Register_MEAN();
TfLiteRegistration Register_MINIMUM(); TfLiteRegistration Register_MINIMUM();
......
...@@ -338,8 +338,8 @@ class MicroMutableOpResolver : public MicroOpResolver { ...@@ -338,8 +338,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddLogistic() { TfLiteStatus AddLogistic() {
return AddBuiltin(BuiltinOperator_LOGISTIC, return AddBuiltin(BuiltinOperator_LOGISTIC, tflite::Register_LOGISTIC(),
tflite::ops::micro::Register_LOGISTIC(), ParseLogistic); ParseLogistic);
} }
TfLiteStatus AddMaximum() { TfLiteStatus AddMaximum() {
......
...@@ -388,6 +388,7 @@ tensorflow/lite/micro/kernels/l2_pool_2d.cc \ ...@@ -388,6 +388,7 @@ tensorflow/lite/micro/kernels/l2_pool_2d.cc \
tensorflow/lite/micro/kernels/leaky_relu.cc \ tensorflow/lite/micro/kernels/leaky_relu.cc \
tensorflow/lite/micro/kernels/logical.cc \ tensorflow/lite/micro/kernels/logical.cc \
tensorflow/lite/micro/kernels/logistic.cc \ tensorflow/lite/micro/kernels/logistic.cc \
tensorflow/lite/micro/kernels/logistic_common.cc \
tensorflow/lite/micro/kernels/log_softmax.cc \ tensorflow/lite/micro/kernels/log_softmax.cc \
tensorflow/lite/micro/kernels/maximum_minimum.cc \ tensorflow/lite/micro/kernels/maximum_minimum.cc \
tensorflow/lite/micro/kernels/mul.cc \ tensorflow/lite/micro/kernels/mul.cc \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册