Unverified commit b92f248a authored by deqiangc, committed by GitHub

Import third_party/hexagon (#1028)

Importing third_party/hexagon from codelinaro at commit:
5bee22654219b89efdd35e447cc89d50369ad168

This is generated by running ci/import_third_party_hexagon.sh

BUG=https://b/227665919
Parent b0fd265b
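For reproducibility, the import can be regenerated by re-running the script named in the commit message. This is only a sketch of the invocation; it assumes the script is run from the repository root and takes no arguments:

ci/import_third_party_hexagon.sh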
@@ -72,13 +72,17 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
tflite::FullyConnectedParams op_params;
op_params.float_activation_min = output_activation_min;
op_params.float_activation_max = output_activation_max;
const float* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<float>(bias) : nullptr;
tflite::reference_ops::FullyConnected(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<float>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<float>(bias),
bias_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
return kTfLiteOk;
......
@@ -53,7 +53,9 @@ ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "third_party/hexagon/hexagon_fully_connected.h"
#include "third_party/hexagon/hexagon_tflm_translation_fully_connected.h"
namespace tflite {
namespace {
@@ -70,19 +72,22 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
op_params.output_offset = data.reference_op_data.output_zero_point;
op_params.output_multiplier = data.reference_op_data.output_multiplier;
// TODO(b/138810107): Figure out whether output shift should be inverted
op_params.output_shift = -data.reference_op_data.output_shift;
op_params.output_shift = data.reference_op_data.output_shift;
op_params.quantized_activation_min =
data.reference_op_data.output_activation_min;
op_params.quantized_activation_max =
data.reference_op_data.output_activation_max;
const int32_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int32_t>(bias) : nullptr;
reference_integer_ops::FullyConnected(
op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias),
tflite::micro::GetTensorData<int32_t>(bias),
bias_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
@@ -132,9 +137,19 @@ TfLiteStatus HexagonFullyConnectedPrepare(TfLiteContext* context, TfLiteNode* no
AllocateTempOutputTensor(node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
context, params->activation, input->type,
input, filter, bias, output, &data->reference_op_data));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(filter);
micro_context->DeallocateTempTfLiteTensor(bias);
if (bias != nullptr) {
micro_context->DeallocateTempTfLiteTensor(bias);
}
micro_context->DeallocateTempTfLiteTensor(output);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
@@ -145,10 +160,8 @@ TfLiteStatus HexagonFullyConnectedPrepare(TfLiteContext* context, TfLiteNode* no
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
return tflite::hexagon_fully_connected::HexagonPrepare(context, node);
} else {
return CalculateOpDataFullyConnected(context, params->activation, input->type, input,
filter, bias, output, &data->reference_op_data);
}
return kTfLiteOk;
}
TfLiteStatus HexagonFullyConnectedEvalInt8(TfLiteContext* context, TfLiteNode* node) {
@@ -170,7 +183,9 @@ TfLiteStatus HexagonFullyConnectedEvalInt8(TfLiteContext* context, TfLiteNode* n
// This kernel only implements the int8 version of the fully_connected kernel.
TFLITE_DCHECK(input->type == kTfLiteInt8);
TFLITE_DCHECK(filter->type == kTfLiteInt8);
TFLITE_DCHECK(bias->type == kTfLiteInt32);
if (bias != nullptr) {
TFLITE_DCHECK(bias->type == kTfLiteInt32);
}
TFLITE_DCHECK(output->type == kTfLiteInt8);
if (tflite::hexagon_fully_connected::HexagonOptimizable(context, node)) {
......