未验证 提交 f9e68f11 编写于 作者: N Nat Jeffries 提交者: GitHub

Implement svdf and fully_connected kernel variants for hexagon. (#380)

上级 3eca5785
......@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/micro/benchmarks/micro_benchmark.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/softmax.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
......@@ -55,7 +56,7 @@ KeywordBenchmarkRunner* CreateBenchmarkRunner(MicroProfiler* profiler) {
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8());
op_resolver->AddQuantize();
op_resolver->AddSoftmax(tflite::Register_SOFTMAX_INT8_INT16());
op_resolver->AddSvdf();
op_resolver->AddSvdf(tflite::Register_SVDF_INT8());
return new (benchmark_runner_buffer)
KeywordBenchmarkRunner(g_keyword_scrambled_model_data, op_resolver,
......
......@@ -65,14 +65,9 @@ TfLiteStatus CalculateOpDataFullyConnected(
// (reference or optimized) must define this function.
TfLiteRegistration Register_FULLY_CONNECTED();
#if defined(CMSIS_NN)
// The Arduino is a special case where we use the CMSIS kernels, but because of
// the current approach to building for Arduino, we do not support -DCMSIS_NN as
// part of the build. As a result, we use defined(ARDUINO) as proxy for the
// CMSIS kernels for this one special case.
// Returns a TfLiteRegistration struct for cmsis_nn kernel variant that only
// supports int8.
#if defined(CMSIS_NN) || defined(HEXAGON)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int8.
TfLiteRegistration Register_FULLY_CONNECTED_INT8();
#else
......
......@@ -33,13 +33,13 @@ namespace {
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
return context->AllocatePersistentBuffer(context, sizeof(OpDataSvdf));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
......
......@@ -20,7 +20,7 @@ limitations under the License.
namespace tflite {
struct OpData {
struct OpDataSvdf {
int32_t effective_scale_1_a;
int32_t effective_scale_2_a;
// b versions of each scale are kept at int since the numbers are just the
......@@ -55,7 +55,7 @@ void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
const OpData& data);
const OpDataSvdf& data);
void EvalFloatSvdfReference(
TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input,
......@@ -66,6 +66,23 @@ void EvalFloatSvdfReference(
TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
// This is the most generic TfLiteRegistration. The actual supported types may
// still be target dependent. The only requirement is that every implementation
// (reference or optimized) must define this function.
TfLiteRegistration Register_SVDF();
#if defined(HEXAGON)
TfLiteRegistration Register_SVDF_INT8();
#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementation that allow reference kernels to still be used
// from applications that call a more specific kernel variant.
inline TfLiteRegistration Register_SVDF_INT8() { return Register_SVDF(); }
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_SVDF_H_
......@@ -56,7 +56,7 @@ void EvalIntegerSvdfReference(TfLiteContext* context, TfLiteNode* node,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
const OpData& data) {
const OpDataSvdf& data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
......@@ -401,7 +401,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
OpDataSvdf* data = static_cast<OpDataSvdf*>(node->user_data);
if (input->type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, weights_feature->type, kTfLiteInt8);
......
......@@ -48,7 +48,7 @@ constexpr int kOutputTensor = 0;
* This version of SVDF is specific to TFLite Micro. It contains only a full
* integer receipe with optimizations for the Xtensa HiFiMini platform.
*
* Note: passing OpData by value might seem like an oversight but it helps
* Note: passing OpDataSvdf by value might seem like an oversight but it helps
* reduce the latency. See b/155656675 for more details.
*/
void EvalIntegerSvdfHifimini(TfLiteContext* context, TfLiteNode* node,
......@@ -58,7 +58,7 @@ void EvalIntegerSvdfHifimini(TfLiteContext* context, TfLiteNode* node,
const TfLiteEvalTensor* bias_tensor,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor, OpData data) {
TfLiteEvalTensor* output_tensor, OpDataSvdf data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
......@@ -254,7 +254,7 @@ TfLiteStatus EvalIntegerSvdfHifi(TfLiteContext* context, TfLiteNode* node,
const TfLiteSVDFParams* params,
TfLiteEvalTensor* activation_state_tensor,
TfLiteEvalTensor* output_tensor,
const OpData& data) {
const OpDataSvdf& data) {
const int n_rank = params->rank;
const int n_batch = input_tensor->dims->data[0];
const int n_input = input_tensor->dims->data[1];
......@@ -321,7 +321,7 @@ TfLiteStatus EvalIntegerSvdfHifi(TfLiteContext* context, TfLiteNode* node,
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
return context->AllocatePersistentBuffer(context, sizeof(OpDataSvdf));
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
......@@ -422,7 +422,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
1e-5);
TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
OpDataSvdf* data = static_cast<OpDataSvdf*>(node->user_data);
#if defined(HIFIMINI)
QuantizeMultiplierForInt24(effective_scale_1, &data->effective_scale_1_a,
......@@ -471,7 +471,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
#if defined(HIFIMINI)
EvalIntegerSvdfHifimini(context, node, input, weights_feature, weights_time,
......
......@@ -498,8 +498,9 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseSub);
}
TfLiteStatus AddSvdf() {
return AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), ParseSvdf);
TfLiteStatus AddSvdf(
const TfLiteRegistration& registration = Register_SVDF()) {
return AddBuiltin(BuiltinOperator_SVDF, registration, ParseSvdf);
}
TfLiteStatus AddTanh() {
......
MICROLITE_CC_KERNEL_SRCS += \
tensorflow/lite/micro/kernels/hexagon/fully_connected_int8.cc \
tensorflow/lite/micro/kernels/hexagon/svdf_int8.cc
# Full path to the hexagon_tflm static library.
HEXAGON_TFLM_LIB :=
......
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -52,181 +52,14 @@ ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/hexagon/hexagon_fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "third_party/hexagon/hexagon_tflm_translation_fully_connected.h"
namespace tflite {
namespace {
// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
// Output tensor.
constexpr int kOutputTensor = 0;
struct OpData {
// The scaling factor from input to output (aka the 'real multiplier') can
// be represented as a fixed point multiplier plus a left shift.
int32_t output_multiplier;
int output_shift;
// The range of the fused activation layer. For example for kNone and
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
// The index of the temporary tensor where the quantized inputs are cached.
int input_quantized_index;
// Cached zero point values of tensors.
int32_t input_zero_point;
int32_t filter_zero_point;
int32_t output_zero_point;
void* hexagon_data;
};
TfLiteStatus CalculateOpData(TfLiteContext* context,
TfLiteFusedActivation activation,
TfLiteType data_type, const TfLiteTensor* input,
const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output,
OpData* data) {
TfLiteStatus status = kTfLiteOk;
if (data_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
int exponent;
QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
data->output_shift = -exponent;
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, activation, output, &data->output_activation_min,
&data->output_activation_max));
data->input_zero_point = input->params.zero_point;
data->filter_zero_point = filter->params.zero_point;
data->output_zero_point = output->params.zero_point;
}
return status;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
void* data = nullptr;
data = context->AllocatePersistentBuffer(context, sizeof(OpData));
if (data == nullptr) {
return nullptr;
}
OpData* opdata = static_cast<OpData*>(data);
opdata->hexagon_data =
tflite::hexagon_fully_connected::HexagonInit(context, buffer, length);
return data;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
const auto params =
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
#include "hexagon_tflm_translation_fully_connected.h"
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
tflite::hexagon_fully_connected::HexagonOptimizationEvaluation(context, node);
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
return tflite::hexagon_fully_connected::HexagonPrepare(context, node);
} else {
return CalculateOpData(context, params->activation, input->type, input,
filter, bias, output, data);
}
}
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
tflite::FullyConnectedParams op_params;
op_params.input_offset = -data.input_zero_point;
op_params.weights_offset = -data.filter_zero_point;
op_params.output_offset = data.output_zero_point;
op_params.output_multiplier = data.output_multiplier;
// TODO(b/138810107): Figure out whether output shift should be inverted
op_params.output_shift = -data.output_shift;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
reference_integer_ops::FullyConnected(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const OpData& data, const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
const int32_t input_offset = -data.input_zero_point;
const int32_t filter_offset = -data.filter_zero_point;
const int32_t output_offset = data.output_zero_point;
tflite::FullyConnectedParams op_params;
op_params.input_offset = input_offset;
op_params.weights_offset = filter_offset;
op_params.output_offset = output_offset;
op_params.output_multiplier = data.output_multiplier;
// Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.output_shift = -data.output_shift;
op_params.quantized_activation_min = data.output_activation_min;
op_params.quantized_activation_max = data.output_activation_max;
#define TF_LITE_FULLY_CONNECTED(output_data_type) \
reference_ops::FullyConnected( \
op_params, tflite::micro::GetTensorShape(input), \
tflite::micro::GetTensorData<uint8_t>(input), \
tflite::micro::GetTensorShape(filter), \
tflite::micro::GetTensorData<uint8_t>(filter), \
tflite::micro::GetTensorShape(bias), \
tflite::micro::GetTensorData<int32_t>(bias), \
tflite::micro::GetTensorShape(output), \
tflite::micro::GetTensorData<output_data_type>(output))
switch (output->type) {
case kTfLiteUInt8:
TF_LITE_FULLY_CONNECTED(uint8_t);
break;
case kTfLiteInt16:
TF_LITE_FULLY_CONNECTED(int16_t);
break;
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(output->type), output->type);
return kTfLiteError;
}
namespace tflite {
return kTfLiteOk;
}
namespace {
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFusedActivation activation,
......@@ -251,22 +84,23 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace
TfLiteStatus HexagonFullyConnectedEval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
const auto* params =
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kWeightsTensor);
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kBiasTensor);
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
// Checks in Prepare ensure input, output and filter types are all the same.
switch (input->type) {
......@@ -275,16 +109,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output);
case kTfLiteInt8:
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
return tflite::hexagon_fully_connected::HexagonEvalQuantizedInt8(
context, node, node->user_data, input, filter, bias, output);
} else {
return EvalQuantizedInt8(context, node, data, input, filter, bias,
output);
}
case kTfLiteUInt8:
return EvalQuantized(context, node, data, input, filter, bias, output);
return HexagonFullyConnectedEvalInt8(context, node);
default:
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
......@@ -294,13 +119,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
return {/*init=*/Init,
return {/*init=*/HexagonFullyConnectedInit,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*prepare=*/HexagonFullyConnectedPrepare,
/*invoke=*/HexagonFullyConnectedEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
......
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* Copyright 2020 The Qualcomm Innovation Center, Inc. All Rights Reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted (subject to the limitations in the disclaimer
below) provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Qualcomm Innovation Center, Inc. nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
==============================================================================*/
#include "hexagon_tflm_translation_fully_connected.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/hexagon/hexagon_fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
const HexagonOpDataFullyConnected& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
tflite::FullyConnectedParams op_params;
op_params.input_offset = -data.reference_op_data.input_zero_point;
op_params.weights_offset = -data.reference_op_data.filter_zero_point;
op_params.output_offset = data.reference_op_data.output_zero_point;
op_params.output_multiplier = data.reference_op_data.output_multiplier;
// TODO(b/138810107): Figure out whether output shift should be inverted
op_params.output_shift = -data.reference_op_data.output_shift;
op_params.quantized_activation_min =
data.reference_op_data.output_activation_min;
op_params.quantized_activation_max =
data.reference_op_data.output_activation_max;
reference_integer_ops::FullyConnected(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
} // namespace
void* HexagonFullyConnectedInit(TfLiteContext* context, const char* buffer,
size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
void* data = nullptr;
data = context->AllocatePersistentBuffer(context,
sizeof(HexagonOpDataFullyConnected));
if (data == nullptr) {
return nullptr;
}
HexagonOpDataFullyConnected* opdata =
static_cast<HexagonOpDataFullyConnected*>(data);
opdata->hexagon_data =
tflite::hexagon_fully_connected::HexagonInit(context, buffer, length);
return data;
}
TfLiteStatus HexagonFullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->user_data != nullptr);
TFLITE_DCHECK(node->builtin_data != nullptr);
HexagonOpDataFullyConnected* data =
static_cast<HexagonOpDataFullyConnected*>(node->user_data);
const auto params =
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
const TfLiteTensor* input =
GetInput(context, node, kFullyConnectedInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
const TfLiteTensor* filter =
GetInput(context, node, kFullyConnectedWeightsTensor);
TF_LITE_ENSURE(context, filter != nullptr);
const TfLiteTensor* bias =
GetOptionalInputTensor(context, node, kFullyConnectedBiasTensor);
TfLiteTensor* output = GetOutput(context, node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
tflite::hexagon_fully_connected::HexagonOptimizationEvaluation(context, node);
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
return tflite::hexagon_fully_connected::HexagonPrepare(context, node);
} else {
return CalculateOpDataFullyConnected(context, params->activation, input->type, input,
filter, bias, output, &data->reference_op_data);
}
}
TfLiteStatus HexagonFullyConnectedEvalInt8(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const HexagonOpDataFullyConnected& data =
*(static_cast<const HexagonOpDataFullyConnected*>(node->user_data));
// This kernel only implements the int8 version of the fully_connected kernel.
TFLITE_DCHECK(input->type == kTfLiteInt8);
TFLITE_DCHECK(filter->type == kTfLiteInt8);
TFLITE_DCHECK(bias->type == kTfLiteInt32);
TFLITE_DCHECK(output->type == kTfLiteInt8);
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
return tflite::hexagon_fully_connected::HexagonEvalQuantizedInt8(
context, node, node->user_data, input, filter, bias, output);
} else {
return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
}
return kTfLiteOk;
}
TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
return {/*init=*/HexagonFullyConnectedInit,
/*free=*/nullptr,
/*prepare=*/HexagonFullyConnectedPrepare,
/*invoke=*/HexagonFullyConnectedEvalInt8,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_FULLY_CONNECTED_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
namespace tflite {
struct HexagonOpDataFullyConnected {
struct OpDataFullyConnected reference_op_data;
void* hexagon_data;
};
void* HexagonFullyConnectedInit(TfLiteContext* context, const char* buffer,
size_t length);
TfLiteStatus HexagonFullyConnectedPrepare(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus HexagonFullyConnectedEvalInt8(TfLiteContext* context, TfLiteNode* node);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_FULLY_CONNECTED_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_SVDF_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_SVDF_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/micro/kernels/svdf.h"
namespace tflite {
struct HexagonOpDataSvdf {
struct OpDataSvdf reference_op_data;
void* hexagon_data;
};
void* HexagonSvdfInit(TfLiteContext* context, const char* buffer, size_t length);
TfLiteStatus HexagonSvdfPrepare(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus HexagonSvdfEvalInt8(TfLiteContext* context, TfLiteNode* node);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_HEXAGON_HEXAGON_SVDF_H_
此差异已折叠。
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* Copyright 2020 The Qualcomm Innovation Center, Inc. All Rights Reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted (subject to the limitations in the disclaimer
below) provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of Qualcomm Innovation Center, Inc. nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
==============================================================================*/
#include <math.h>
#include "hexagon_tflm_translation_svdf.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/activation_utils.h"
#include "tensorflow/lite/micro/kernels/hexagon/hexagon_svdf.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
TfLiteStatus HexagonSvdfEvalInt8(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
const HexagonOpDataSvdf& data =
*(static_cast<const HexagonOpDataSvdf*>(node->user_data));
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
const TfLiteEvalTensor* weights_feature =
tflite::micro::GetEvalInput(context, node, kSvdfWeightsFeatureTensor);
const TfLiteEvalTensor* weights_time =
tflite::micro::GetEvalInput(context, node, kSvdfWeightsTimeTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 5)
? tflite::micro::GetEvalInput(context, node, kSvdfBiasTensor)
: nullptr;
TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
context, node, kSvdfInputActivationStateTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kSvdfOutputTensor);
if (tflite::hexagon_svdf::HexagonOptimizable(context, node)) {
tflite::hexagon_svdf::HexagonEvalIntegerSVDF(
context, node, input, weights_feature, weights_time, bias, params,
activation_state, output, node->user_data);
} else {
EvalIntegerSvdfReference(context, node, input, weights_feature,
weights_time, bias, params, activation_state,
output, data.reference_op_data);
}
return kTfLiteOk;
}
void* HexagonSvdfInit(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
void* data = context->AllocatePersistentBuffer(context, sizeof(OpDataSvdf));
if (data == nullptr) {
return nullptr;
}
HexagonOpDataSvdf* opdata = static_cast<HexagonOpDataSvdf*>(data);
opdata->hexagon_data =
tflite::hexagon_svdf::HexagonInit(context, buffer, length);
return data;
}
TfLiteStatus HexagonSvdfPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus prepare_status = PrepareSvdf(context, node);
if (prepare_status != kTfLiteOk) {
return prepare_status;
}
tflite::hexagon_svdf::HexagonOptimizationEvaluation(context, node);
if (tflite::hexagon_svdf::HexagonOptimizable(context, node)) {
TF_LITE_ENSURE_OK(context,
tflite::hexagon_svdf::HexagonPrepare(context, node));
}
return kTfLiteOk;
}
TfLiteRegistration Register_SVDF_INT8() {
return {/*init=*/HexagonSvdfInit,
/*free=*/nullptr,
/*prepare=*/HexagonSvdfPrepare,
/*invoke=*/HexagonSvdfEvalInt8,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace tflite
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册