未验证 提交 70aed11c 编写于 作者: S suleshahid 提交者: GitHub

Adds FFT Auto Scale Op (#2134)

This PR adds additional FFT op functionality in the Signal library, namely adding the FFT Auto Scale operation.
Testing added in the original `fft_test.cc` and `fft_ops_test.py`.
BUG=[287346710](http://b/287346710)
上级 52007f6a
......@@ -51,6 +51,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddEthosU();
AddExp();
AddExpandDims();
AddFftAutoScale();
AddFill();
AddFloor();
AddFloorDiv();
......
......@@ -46,6 +46,10 @@ py_tflm_signal_library(
py_test(
name = "fft_ops_test",
srcs = ["ops/fft_ops_test.py"],
data = [
"//python/tflite_micro/signal/ops/testdata:fft_auto_scale_test1.txt",
"//python/tflite_micro/signal/ops/testdata:rfft_test1.txt",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
......
......@@ -65,5 +65,22 @@ def _fft_wrapper(fft_fn, default_name):
return _fft
def _fft_auto_scale_wrapper(fft_auto_scale_fn, default_name):
"""Wrapper around gen_fft_ops.fft_auto_scale*."""
def _fft_auto_scale(input_tensor, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
dim_list = input_tensor.shape.as_list()
if len(dim_list) != 1:
raise ValueError("Input tensor must have a rank of 1")
return fft_auto_scale_fn(input_tensor, name=name)
return _fft_auto_scale
rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
"signal_fft_auto_scale")
tf.no_gradient("signal_rfft")
tf.no_gradient("signal_fft_auto_scale")
......@@ -33,6 +33,31 @@ class RfftOpTest(tf.test.TestCase):
file_text = f.read()
return file_text
def SingleFftAutoScaleTest(self, filename):
lines = self.GetResource(filename).splitlines()
func = tf.function(fft_ops.fft_auto_scale)
input_size = len(lines[0].split())
concrete_function = func.get_concrete_function(
tf.TensorSpec(input_size, dtype=tf.int16))
interpreter = util.get_tflm_interpreter(concrete_function, func)
i = 0
while i < len(lines):
in_frame = np.array([int(j) for j in lines[i].split()], dtype=np.int16)
out_frame_exp = [int(j) for j in lines[i + 1].split()]
scale_exp = [int(j) for j in lines[i + 2].split()]
# TFLM
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
scale = interpreter.get_output(1)
self.assertAllEqual(out_frame_exp, out_frame)
self.assertEqual(scale_exp, scale)
# TF
out_frame, scale = self.evaluate(fft_ops.fft_auto_scale(in_frame))
self.assertAllEqual(out_frame_exp, out_frame)
self.assertEqual(scale_exp, scale)
i += 3
def SingleRfftTest(self, filename):
lines = self.GetResource(filename).splitlines()
args = lines[0].split()
......@@ -43,8 +68,6 @@ class RfftOpTest(tf.test.TestCase):
tf.TensorSpec(input_size, dtype=tf.int16), fft_length)
# TODO(b/286252893): make test more robust (vs scipy)
interpreter = util.get_tflm_interpreter(concrete_function, func)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Skip line 0, which contains the configuration params.
# Read lines in pairs <input, expected>
i = 1
......@@ -53,9 +76,9 @@ class RfftOpTest(tf.test.TestCase):
out_frame_exp = [int(j) for j in lines[i + 1].split()]
# Compare TFLM inference against the expected golden values
# TODO(b/286252893): validate usage of testing vs interpreter here
interpreter.set_tensor(input_details[0]['index'], in_frame)
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_tensor(output_details[0]['index'])
out_frame = interpreter.get_output(0)
self.assertAllEqual(out_frame_exp, out_frame)
# TF
out_frame = self.evaluate(fft_ops.rfft(in_frame, fft_length))
......@@ -83,11 +106,9 @@ class RfftOpTest(tf.test.TestCase):
concrete_function = func.get_concrete_function(
tf.TensorSpec(np.shape(in_frames), dtype=tf.int16), fft_length)
interpreter = util.get_tflm_interpreter(concrete_function, func)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], in_frames)
interpreter.set_input(in_frames, 0)
interpreter.invoke()
out_frame = interpreter.get_tensor(output_details[0]['index'])
out_frame = interpreter.get_output(0)
self.assertAllEqual(out_frames_exp, out_frame)
# TF
out_frames = self.evaluate(fft_ops.rfft(in_frames, fft_length))
......@@ -204,6 +225,12 @@ class RfftOpTest(tf.test.TestCase):
delta=1)
fft_length = 2 * fft_length
def testRfft(self):
self.SingleRfftTest('testdata/rfft_test1.txt')
def testRfftLargeOuterDimension(self):
self.MultiDimRfftTest('testdata/rfft_test1.txt')
def testFftTooLarge(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_input = np.zeros(round(fft_ops._MAX_FFT_LENGTH * 2), dtype=dtype)
......@@ -224,6 +251,9 @@ class RfftOpTest(tf.test.TestCase):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(fft_ops.rfft(fft_input, 127))
def testAutoScale(self):
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')
def testPow2FftLengthTest(self):
fft_length, fft_bits = fft_ops.get_pow2_fft_length(131)
self.assertEqual(fft_length, 256)
......
......@@ -7,6 +7,7 @@ package(
)
exports_files([
"fft_auto_scale_test1.txt",
"rfft_test1.txt",
"window_test1.txt",
])
......@@ -10,6 +10,7 @@ cc_library(
srcs = [
"delay.cc",
"energy.cc",
"fft_auto_scale.cc",
"filter_bank.cc",
"filter_bank_log.cc",
"filter_bank_spectral_subtraction.cc",
......@@ -30,6 +31,7 @@ cc_library(
deps = [
"//signal/src:circular_buffer",
"//signal/src:energy",
"//signal/src:fft_auto_scale",
"//signal/src:filter_bank",
"//signal/src:filter_bank_log",
"//signal/src:filter_bank_spectral_subtraction",
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kScaleBitTensor = 1;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteTensor* scale_bit =
micro_context->AllocateTempOutputTensor(node, kScaleBitTensor);
TF_LITE_ENSURE(context, scale_bit != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(scale_bit), 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, scale_bit->type, kTfLiteInt32);
micro_context->DeallocateTempTfLiteTensor(scale_bit);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* scale_bit =
tflite::micro::GetEvalOutput(context, node, kScaleBitTensor);
const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
int32_t* scale_bit_data = tflite::micro::GetTensorData<int32_t>(scale_bit);
*scale_bit_data =
tflm_signal::FftAutoScale(input_data, output->dims->data[0], output_data);
return kTfLiteOk;
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_FFT_AUTO_SCALE() {
static TFLMRegistration r = tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
......@@ -85,6 +85,43 @@ TfLiteStatus TestFFT(int* input_dims_data, const T* input_data,
return kTfLiteOk;
}
TfLiteStatus TestFFTAutoScale(int* input_dims_data, const int16_t* input_data,
int* output_dims_data, const int16_t* golden,
int* scale_bit_dims_data,
const int32_t scale_bit_golden,
const TFLMRegistration registration,
const uint8_t* flexbuffers_data,
const int flexbuffers_data_len,
int16_t* output_data, int32_t* scale_bit) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
TfLiteIntArray* scale_bit_dims = IntArrayFromInts(scale_bit_dims_data);
constexpr int kInputsSize = 1;
constexpr int kOutputsSize = 2;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
CreateTensor(scale_bit, scale_bit_dims),
};
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TF_LITE_ENSURE_STATUS(ValidateFFTGoldens<int16_t>(
tensors, kTensorsSize, inputs_array, outputs_array, output_len, golden,
registration, flexbuffers_data, flexbuffers_data_len, output_data, 0));
TF_LITE_MICRO_EXPECT_EQ(scale_bit_golden, *scale_bit);
return kTfLiteOk;
}
} // namespace
} // namespace testing
......@@ -266,4 +303,61 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
const int16_t input[] = {0x0000, 0x1111, 0x2222, 0x3333,
0x3333, 0x2222, 0x1111, 0x0000};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
const int16_t golden[] = {0x0000, 0x2222, 0x4444, 0x6666,
0x6666, 0x4444, 0x2222, 0x0000};
const int32_t scale_bit_golden = 1;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, input, shape, golden, scale_bit_shape, scale_bit_golden,
*registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestScaleBit) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
const int16_t input[] = {238, 113, -88, -243, -5, -130, 159, -70};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
const int16_t golden[] = {30464, 14464, -11264, -31104,
-640, -16640, 20352, -8960};
const int32_t scale_bit_golden = 7;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, input, shape, golden, scale_bit_shape, scale_bit_golden,
*registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestLarge) {
constexpr int kTensorsSize = 400;
int shape[] = {1, kTensorsSize};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
const int32_t scale_bit_golden = 0;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, tflite::kFftAutoScaleLength512Input, shape,
tflite::kFftAutoScaleLength512Golden, scale_bit_shape,
scale_bit_golden, *registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TESTS_END
......@@ -8,6 +8,22 @@ cc_library(
hdrs = ["complex.h"],
)
cc_library(
name = "fft_auto_scale",
srcs = ["fft_auto_scale.cc"],
hdrs = ["fft_auto_scale.h"],
deps = [
":max_abs",
":msb_32",
],
)
cc_library(
name = "max_abs",
srcs = ["max_abs.cc"],
hdrs = ["max_abs.h"],
)
cc_library(
name = "square_root_32",
srcs = ["square_root_32.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include <stddef.h>
#include <stdint.h>
#include "signal/src/max_abs.h"
#include "signal/src/msb.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
int FftAutoScale(const int16_t* input, int size, int16_t* output) {
const int16_t max = MaxAbs16(input, size);
int scale_bits = (sizeof(int16_t) * 8) - MostSignificantBit32(max) - 1;
if (scale_bits <= 0) {
scale_bits = 0;
}
for (int i = 0; i < size; i++) {
// (input[i] << scale_bits) is undefined if input[i] is negative.
// Multiply explicitly to make the code portable.
output[i] = input[i] * (1 << scale_bits);
}
return scale_bits;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FFT_AUTO_SCALE_H_
#define SIGNAL_SRC_FFT_AUTO_SCALE_H_
#include <stddef.h>
#include <stdint.h>
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// Auto scales `input` and write the result to `output`
// Elements in `input` are left shifted to maximize the amplitude without
// clipping,
// * both `input` and `output` must be of size `size`
int FftAutoScale(const int16_t* input, int size, int16_t* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FFT_AUTO_SCALE_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/max_abs.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
int16_t MaxAbs16(const int16_t* input, int size) {
int16_t max = 0;
for (int i = 0; i < size; i++) {
const int16_t value = input[i];
if (value > max) {
max = value;
} else if (-value > max) {
max = -value;
}
}
return max;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_MAX_ABS_H_
#define SIGNAL_SRC_MAX_ABS_H_
#include <stdint.h>
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// Returns the maximum absolute value of the `size` elements in `input`
int16_t MaxAbs16(const int16_t* input, int size);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_MAX_ABS_H_
......@@ -9,6 +9,7 @@ tflm_signal_kernel_library(
name = "fft_kernel",
srcs = ["fft_kernels.cc"],
deps = [
"//signal/src:fft_auto_scale",
"//signal/src:rfft",
"@tensorflow_cc_deps//:cc_library",
],
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include "signal/src/rfft.h"
#include "tensorflow/core/framework/op_kernel.h"
......@@ -81,7 +82,30 @@ class RfftOp : public tensorflow::OpKernel {
Tensor state_tensor_;
};
class FftAutoScaleOp : public tensorflow::OpKernel {
public:
explicit FftAutoScaleOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const int16_t* input = input_tensor.flat<int16_t>().data();
// Create an output tensor
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
int16_t* output = output_tensor->flat<int16_t>().data();
tensorflow::Tensor* scale_bit_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {}, &scale_bit_tensor));
scale_bit_tensor->scalar<int32_t>()() = tflite::tflm_signal::FftAutoScale(
input, output_tensor->NumElements(), output);
}
};
// TODO(b/286250473): change back name after name clash resolved
REGISTER_KERNEL_BUILDER(
Name("SignalFftAutoScale").Device(tensorflow::DEVICE_CPU), FftAutoScaleOp);
REGISTER_KERNEL_BUILDER(
Name("SignalRfft")
.Device(tensorflow::DEVICE_CPU)
......
......@@ -62,5 +62,27 @@ fft_length: The length of the FFT operation. An input signal that's shorter
will be zero padded to fft_length.
)doc");
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalFftAutoScale")
.Input("input: int16")
.Output("output: int16")
.Output("scale_bits: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
ShapeHandle out;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
c->set_output(0, out);
c->set_output(1, c->Scalar());
return OkStatus();
})
.Doc(R"doc(
Shifts the input left until the amplitude is maximized without clipping. Returns
the amount of left shift for compensation later. This op can be used to maximize
precision of integer FFT implementations, especially 16-bit.
input: A 1-D time domain signal.
output: A 1-D time domain signal after auto scaling.
scale_bits: Scalar. The number of left shifts applied to the input signal.
)doc");
} // namespace signal
} // namespace tensorflow
......@@ -571,4 +571,98 @@ const float kRfftFloatLength512Golden[] = {
-5.16014671e+00, 6.53038979e+00, 5.11271954e-02, -3.40430737e-01,
-2.98720551e+00, 0.00000000e+00};
const int16_t kFftAutoScaleLength512Input[] = {
27728, 28180, -23037, -999, 7627, 19097, 4809, -28251, 25421,
21584, 5775, -514, 31389, -13221, 28700, 3928, -32678, 9413,
21553, -11903, 19367, 18168, 3923, 13968, -19808, -8946, 9707,
4996, -2033, 26133, 8465, -29982, 145, 4190, -27992, 18817,
-7428, 24828, 23734, -1823, -24872, -14335, 12368, 24986, 26849,
14131, -21011, 26386, 22377, -11729, 18306, -9073, 14716, -14157,
17728, -6062, 21593, -5285, -20960, -5598, -10791, 2156, -1026,
12427, 8910, 12726, 26340, 20350, -28668, 5619, 29298, 10801,
11427, -5965, 6491, -21680, 1107, 9373, 14507, -16623, -17628,
-3940, 3115, -28103, 8256, 23231, 13811, -27905, 20126, 7647,
-31981, 4270, -2447, -24855, -15303, 29773, 22757, 13823, -2623,
-18288, -29900, -32254, 2751, 21161, 13991, -4786, 2118, -5267,
-13969, -7855, 17443, -4976, 26066, 29508, -22460, -27974, 28652,
-24058, 25214, 18340, 15171, 3101, 28870, 11533, -3481, 3047,
6802, 23452, 15039, -32103, -29734, -9678, 11514, 29903, -30983,
-30655, -19125, 18490, 14868, 18846, -3849, 18431, 31245, 12547,
24499, -15543, 8004, 30226, -17948, -27011, -10040, -21747, -4077,
29493, 25322, -10680, -23061, 16320, -262, 19493, -29407, -30065,
-29158, -14538, 17318, 5900, 22016, -8523, -3482, 8226, 6269,
-9888, -18543, 11548, 32126, 716, 12042, -2344, 28403, 16777,
18536, 19821, -13224, -32000, 18846, 29756, 15152, -25000, -2854,
2894, 5817, -20644, -14967, -22120, 19426, 24206, -10895, 9359,
22642, 15413, -544, 11456, 29289, -29467, -7007, 1946, -12714,
29003, -9754, -11247, -8743, -26529, 12775, 17819, -4068, 10532,
7749, -15982, -3309, -15015, -14728, -5061, 4456, -23045, 4875,
18613, -26598, 20189, -2459, -1949, -27655, 10714, -25641, -31826,
-14901, 26440, -29853, -21380, -9872, 7330, -24977, 28143, 22335,
-24296, 1775, 19950, -31505, -23314, -7708, 8747, -14274, 30659,
-31703, -16215, -7103, -7876, 25772, 773, 28262, 16517, 26455,
-15645, 18958, -1342, 30649, 6825, 8075, -13666, 16635, 31946,
-22845, -27888, 11845, 7597, -20615, -27995, 11419, -2343, -6894,
7419, 30308, 15120, 24538, -25659, -26220, 25970, -11688, 26728,
-27865, -8426, 24771, 30570, 27041, -20003, 13894, 16227, -32113,
-4925, -7249, -27491, 743, 11549, -18304, 6082, -27239, 22277,
-914, 5237, -30772, -6916, 15278, -28297, 9274, 14611, 23071,
-9831, 1675, -31961, -17243, 16597, -21968, 12045, -16939, 9563,
7989, 1251, 22767, 28480, 31961, 31297, 30398, -2645, 1837,
-15697, -19268, -15887, 29292, -10900, 812, -10870, -2759, 20450,
-20981, 28539, -30402, -17263, -19693, -32710, 6172, -30003, 27373,
24939, -28543, -8928, 1198, -326, -3504, -23640, 24945, -24141,
17787, 20449, -7981, -10926, -26171, -20678, 4107, -32513, 8184,
-12479, 16854, -18552, -21534, -8804, -30278, 18573, -16409, 14746,
-17123, 24656, 25243, 4516, 19254, -2165, 24230, -7639, -19385,
-15505, -15386, 21841, -12507, 9168, 4469, -2649, 21013, 23788,
21282, 27991, -31716, 22753};
const int16_t kFftAutoScaleLength512Golden[] = {
27728, 28180, -23037, -999, 7627, 19097, 4809, -28251, 25421,
21584, 5775, -514, 31389, -13221, 28700, 3928, -32678, 9413,
21553, -11903, 19367, 18168, 3923, 13968, -19808, -8946, 9707,
4996, -2033, 26133, 8465, -29982, 145, 4190, -27992, 18817,
-7428, 24828, 23734, -1823, -24872, -14335, 12368, 24986, 26849,
14131, -21011, 26386, 22377, -11729, 18306, -9073, 14716, -14157,
17728, -6062, 21593, -5285, -20960, -5598, -10791, 2156, -1026,
12427, 8910, 12726, 26340, 20350, -28668, 5619, 29298, 10801,
11427, -5965, 6491, -21680, 1107, 9373, 14507, -16623, -17628,
-3940, 3115, -28103, 8256, 23231, 13811, -27905, 20126, 7647,
-31981, 4270, -2447, -24855, -15303, 29773, 22757, 13823, -2623,
-18288, -29900, -32254, 2751, 21161, 13991, -4786, 2118, -5267,
-13969, -7855, 17443, -4976, 26066, 29508, -22460, -27974, 28652,
-24058, 25214, 18340, 15171, 3101, 28870, 11533, -3481, 3047,
6802, 23452, 15039, -32103, -29734, -9678, 11514, 29903, -30983,
-30655, -19125, 18490, 14868, 18846, -3849, 18431, 31245, 12547,
24499, -15543, 8004, 30226, -17948, -27011, -10040, -21747, -4077,
29493, 25322, -10680, -23061, 16320, -262, 19493, -29407, -30065,
-29158, -14538, 17318, 5900, 22016, -8523, -3482, 8226, 6269,
-9888, -18543, 11548, 32126, 716, 12042, -2344, 28403, 16777,
18536, 19821, -13224, -32000, 18846, 29756, 15152, -25000, -2854,
2894, 5817, -20644, -14967, -22120, 19426, 24206, -10895, 9359,
22642, 15413, -544, 11456, 29289, -29467, -7007, 1946, -12714,
29003, -9754, -11247, -8743, -26529, 12775, 17819, -4068, 10532,
7749, -15982, -3309, -15015, -14728, -5061, 4456, -23045, 4875,
18613, -26598, 20189, -2459, -1949, -27655, 10714, -25641, -31826,
-14901, 26440, -29853, -21380, -9872, 7330, -24977, 28143, 22335,
-24296, 1775, 19950, -31505, -23314, -7708, 8747, -14274, 30659,
-31703, -16215, -7103, -7876, 25772, 773, 28262, 16517, 26455,
-15645, 18958, -1342, 30649, 6825, 8075, -13666, 16635, 31946,
-22845, -27888, 11845, 7597, -20615, -27995, 11419, -2343, -6894,
7419, 30308, 15120, 24538, -25659, -26220, 25970, -11688, 26728,
-27865, -8426, 24771, 30570, 27041, -20003, 13894, 16227, -32113,
-4925, -7249, -27491, 743, 11549, -18304, 6082, -27239, 22277,
-914, 5237, -30772, -6916, 15278, -28297, 9274, 14611, 23071,
-9831, 1675, -31961, -17243, 16597, -21968, 12045, -16939, 9563,
7989, 1251, 22767, 28480, 31961, 31297, 30398, -2645, 1837,
-15697, -19268, -15887, 29292, -10900, 812, -10870, -2759, 20450,
-20981, 28539, -30402, -17263, -19693, -32710, 6172, -30003, 27373,
24939, -28543, -8928, 1198, -326, -3504, -23640, 24945, -24141,
17787, 20449, -7981, -10926, -26171, -20678, 4107, -32513, 8184,
-12479, 16854, -18552, -21534, -8804, -30278, 18573, -16409, 14746,
-17123, 24656, 25243, 4516, 19254, -2165, 24230, -7639, -19385,
-15505, -15386, 21841, -12507, 9168, 4469, -2649, 21013, 23788,
21282, 27991, -31716, 22753};
} // namespace tflite
......@@ -31,6 +31,9 @@ extern const int32_t kRfftInt32Length512Golden[];
extern const float kRfftFloatLength512Input[];
extern const float kRfftFloatLength512Golden[];
extern const int16_t kFftAutoScaleLength512Input[];
extern const int16_t kFftAutoScaleLength512Golden[];
} // namespace tflite
#endif // SIGNAL_TESTDATA_FFT_TEST_DATA_H_
......@@ -135,6 +135,7 @@ TFLMRegistration Register_ZEROS_LIKE();
// TODO(b/160234179): Change custom OPs to also return by value.
namespace tflm_signal {
TFLMRegistration* Register_DELAY();
TFLMRegistration* Register_FFT_AUTO_SCALE();
TFLMRegistration* Register_FILTER_BANK();
TFLMRegistration* Register_FILTER_BANK_LOG();
TFLMRegistration* Register_FILTER_BANK_SPECTRAL_SUBTRACTION();
......
......@@ -254,6 +254,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseExpandDims);
}
TfLiteStatus AddFftAutoScale() {
// TODO(b/286250473): change back name and remove namespace
return AddCustom("SignalFftAutoScale",
tflite::tflm_signal::Register_FFT_AUTO_SCALE());
}
TfLiteStatus AddFill() {
return AddBuiltin(BuiltinOperator_FILL, tflite::Register_FILL(), ParseFill);
}
......
......@@ -314,6 +314,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_pla
MICROLITE_CC_KERNEL_SRCS := \
$(TENSORFLOW_ROOT)signal/micro/kernels/delay.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/energy.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/fft_auto_scale.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank_log.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank_square_root.cc \
......@@ -325,11 +326,13 @@ $(TENSORFLOW_ROOT)signal/micro/kernels/overlap_add.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/window.cc \
$(TENSORFLOW_ROOT)signal/src/circular_buffer.cc \
$(TENSORFLOW_ROOT)signal/src/energy.cc \
$(TENSORFLOW_ROOT)signal/src/fft_auto_scale.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_log.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_square_root.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_spectral_subtraction.cc \
$(TENSORFLOW_ROOT)signal/src/log.cc \
$(TENSORFLOW_ROOT)signal/src/max_abs.cc \
$(TENSORFLOW_ROOT)signal/src/msb_32.cc \
$(TENSORFLOW_ROOT)signal/src/msb_64.cc \
$(TENSORFLOW_ROOT)signal/src/overlap_add.cc \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册