Commits (3)
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/24010d0ef8595433f10afdb2b7b25e1fedb2f2de
Input multiplier range check for tanh/logistic kernel. (#2023)
2023-06-21T06:21:17+00:00 lukmanr-cadence 77075253+lukmanr-cadence@users.noreply.github.com

The code snippet below calculates the input multiplier and shift values:

```cc
double multiplier = static_cast<double>(input->params.scale) * 4096.0 * 3.0;
data->input_left_shift = 0;
while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) {
  data->input_left_shift++;
  multiplier = multiplier * 2.0;
}
```

Usually this multiplier fits in 16 bits even though the variable is a 32-bit integer. The while loop above normalizes the value to 16 bits, except when `input->params.scale` is greater than 2.67. For 16-bit activations, the scale factor exceeds 2.67 only when the float dynamic range is greater than 174978.45 (2.67 * 65535), which is extremely rare.

The kernel uses the input multiplier as follows:

```cc
int32_t input_data = ((*ptr_input_data) * input_multiplier + round) >> input_left_shift;
```

This is actually a 16x32-bit multiplication (`*ptr_input_data` is a 16-bit variable and `input_multiplier` is a 32-bit variable) whose result is stored in 32 bits. Only the non-saturated lower 32 bits of the 16x32 product are stored in `input_data`, so the result is wrong whenever the input multiplier needs more than 16 bits. A range check in the prepare function of these kernels is therefore worthwhile, and it lets the programmer use a 16x16-bit operation instead of the 16x32-bit multiplication.

BUG=see description

https://gitcode.net/xusiwei1236/tflite-micro/-/commit/f22888af6502c1a58794a3de9cc8d22eeaf262f4
Adds Signal Library RFFT OP (#2056)
2023-06-21T20:36:55+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com

Second OP for the TFLM Signal library: Real-Valued Fast Fourier Transform (RFFT).

The RFFT OP provides three resolutions: `FLOAT`, `INT16`, `INT32`.

Usage is similar to the previous Window OP:
* `op_resolver.AddRfft()` (which adds all resolutions and determines the type at runtime)
* `op_resolver.AddRfftFloat()`, `op_resolver.AddRfftInt16()`, `op_resolver.AddRfftInt32()` for a specific resolution type
* or via Python, as can be seen in `fft_ops_test.py`

Three testing options are provided:
* Micro (C++): `bazel run signal/micro/kernels:fft_test`
* TensorFlow/Micro (Python): `bazel run python/tflite_micro/signal:fft_ops_test`
* Makefile (C++): `make -f tensorflow/lite/micro/tools/make/Makefile test_kernel_fft_test`

BUG=[287346710](http://b/287346710)

https://gitcode.net/xusiwei1236/tflite-micro/-/commit/37d32b3345b918c9d2f7bb34af332469bfc2ea75
Increase Rfft float test tolerance (#2077)
2023-06-23T03:46:53+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com

Our Cortex-M55 test is failing due to strict tolerance in the Rfft float test. Increasing tolerance for this target.

BUG=[287518815](http://b/287518815)
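The range condition described in the first commit message can be checked numerically. The following is a minimal Python sketch, not the kernel code: it mirrors the normalization loop quoted in the commit message and reports whether the resulting multiplier still fits in a signed 16-bit value. The names `normalize_input_multiplier` and `fits_in_int16` are hypothetical and exist only for this illustration.

```python
INT16_MAX = 32767


def normalize_input_multiplier(scale):
    """Mirror the loop from the commit message; returns (multiplier, left_shift)."""
    multiplier = float(scale) * 4096.0 * 3.0
    left_shift = 0
    # Double the multiplier until it exceeds half of the int16 maximum.
    while multiplier <= INT16_MAX / 2.0 and left_shift <= 30:
        left_shift += 1
        multiplier *= 2.0
    return multiplier, left_shift


def fits_in_int16(scale):
    """Hypothetical check: True when the normalized multiplier is <= 32767."""
    multiplier, _ = normalize_input_multiplier(scale)
    return multiplier <= INT16_MAX


if __name__ == "__main__":
    for scale in (0.001, 1.0, 2.66, 2.67, 3.0):
        multiplier, shift = normalize_input_multiplier(scale)
        print(scale, multiplier, shift, fits_in_int16(scale))
```

The loop never runs when `scale * 12288` already exceeds 16383.5, so the multiplier overflows 16 bits once `scale` is above 32767 / 12288 (about 2.6666), which the commit message rounds to 2.67.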
......@@ -91,6 +91,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddReshape();
AddResizeBilinear();
AddResizeNearestNeighbor();
AddRfft();
AddRound();
AddRsqrt();
AddSelectV2();
......
......@@ -15,6 +15,7 @@ cc_library(
name = "ops_lib",
visibility = [":signal_friends"],
deps = [
":fft_ops_cc",
":window_op_cc",
],
)
......@@ -27,10 +28,33 @@ py_library(
],
srcs_version = "PY3",
deps = [
":fft_ops",
":window_op",
],
)
py_tflm_signal_library(
name = "fft_ops",
srcs = ["ops/fft_ops.py"],
cc_op_defs = ["//signal/tensorflow_core/ops:fft_ops"],
cc_op_kernels = [
"//signal/tensorflow_core/kernels:fft_kernel",
],
)
py_test(
name = "fft_ops_test",
srcs = ["ops/fft_ops_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":fft_ops",
"//python/tflite_micro/signal/utils:util",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
)
py_tflm_signal_library(
name = "window_op",
srcs = ["ops/window_op.py"],
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use FFT ops in python."""

import math

import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util

gen_fft_ops = util.load_custom_op('fft_ops.so')

_MIN_FFT_LENGTH = 64
_MAX_FFT_LENGTH = 2048


def get_pow2_fft_length(input_length):
  """Returns the smallest supported power of 2 FFT length larger than or equal
  to the input_length.

  Only returns FFT lengths that are powers of 2 within the range
  [_MIN_FFT_LENGTH, _MAX_FFT_LENGTH].

  Args:
    input_length: Length of input time domain signal.

  Returns:
    A pair: the smallest length and its log2 (number of bits)

  Raises:
    ValueError: The FFT length needed is not supported
  """
  fft_bits = math.ceil(math.log2(input_length))
  fft_length = pow(2, fft_bits)
  if not _MIN_FFT_LENGTH <= fft_length <= _MAX_FFT_LENGTH:
    raise ValueError("Invalid fft_length. Must be between %d and %d." %
                     (_MIN_FFT_LENGTH, _MAX_FFT_LENGTH))
  return fft_length, fft_bits


def _fft_wrapper(fft_fn, default_name):
  """Wrapper around gen_fft_ops.*rfft*."""

  def _fft(input_tensor, fft_length, name=default_name):
    if not ((_MIN_FFT_LENGTH <= fft_length <= _MAX_FFT_LENGTH) and
            (fft_length % 2 == 0)):
      raise ValueError(
          "Invalid fft_length. Must be an even number between %d and %d." %
          (_MIN_FFT_LENGTH, _MAX_FFT_LENGTH))
    with tf.name_scope(name) as name:
      input_tensor = tf.convert_to_tensor(input_tensor)
      return fft_fn(input_tensor, fft_length=fft_length, name=name)

  return _fft


rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
tf.no_gradient("signal_rfft")
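To make the two length rules in `fft_ops.py` concrete, here is a small TensorFlow-free sketch. It re-implements the rounding done by `get_pow2_fft_length` and the range-and-evenness check done inside `_fft` in plain Python. `is_valid_fft_length` is a hypothetical helper name, not part of the module, and the sketch only reproduces the Python-side validation, not whatever the kernel itself accepts.

```python
import math

MIN_FFT_LENGTH = 64
MAX_FFT_LENGTH = 2048


def get_pow2_fft_length(input_length):
    """Smallest power-of-2 FFT length >= input_length, plus its log2."""
    fft_bits = math.ceil(math.log2(input_length))
    fft_length = 2**fft_bits
    if not MIN_FFT_LENGTH <= fft_length <= MAX_FFT_LENGTH:
        raise ValueError("Invalid fft_length. Must be between %d and %d." %
                         (MIN_FFT_LENGTH, MAX_FFT_LENGTH))
    return fft_length, fft_bits


def is_valid_fft_length(fft_length):
    """Hypothetical helper mirroring the check in _fft: even and in range."""
    return (MIN_FFT_LENGTH <= fft_length <= MAX_FFT_LENGTH
            and fft_length % 2 == 0)


if __name__ == "__main__":
    print(get_pow2_fft_length(131))  # (256, 8)
    print(get_pow2_fft_length(73))   # (128, 7)
    print(is_valid_fft_length(127))  # False: odd length
```

The first two calls use the same inputs as `testPow2FftLengthTest`, so the printed pairs match the values asserted there.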
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for FFT ops."""

import os

import numpy as np
import tensorflow as tf

from tensorflow.python.platform import resource_loader
from tflite_micro.python.tflite_micro.signal.ops import fft_ops
from tflite_micro.python.tflite_micro.signal.utils import util


class RfftOpTest(tf.test.TestCase):

  _PREFIX_PATH = resource_loader.get_path_to_datafile('')

  def GetResource(self, filepath):
    full_path = os.path.join(self._PREFIX_PATH, filepath)
    with open(full_path, 'rt') as f:
      file_text = f.read()
    return file_text

  def SingleRfftTest(self, filename):
    lines = self.GetResource(filename).splitlines()
    args = lines[0].split()
    fft_length = int(args[0])
    func = tf.function(fft_ops.rfft)
    input_size = len(lines[1].split())
    concrete_function = func.get_concrete_function(
        tf.TensorSpec(input_size, dtype=tf.int16), fft_length)
    # TODO(b/286252893): make test more robust (vs scipy)
    interpreter = util.get_tflm_interpreter(concrete_function, func)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # Skip line 0, which contains the configuration params.
    # Read lines in pairs <input, expected>
    i = 1
    while i < len(lines):
      in_frame = np.array([int(j) for j in lines[i].split()], dtype=np.int16)
      out_frame_exp = [int(j) for j in lines[i + 1].split()]
      # Compare TFLM inference against the expected golden values
      # TODO(b/286252893): validate usage of testing vs interpreter here
      interpreter.set_tensor(input_details[0]['index'], in_frame)
      interpreter.invoke()
      out_frame = interpreter.get_tensor(output_details[0]['index'])
      self.assertAllEqual(out_frame_exp, out_frame)
      # TF
      out_frame = self.evaluate(fft_ops.rfft(in_frame, fft_length))
      self.assertAllEqual(out_frame_exp, out_frame)
      i += 2

  def MultiDimRfftTest(self, filename):
    lines = self.GetResource(filename).splitlines()
    args = lines[0].split()
    fft_length = int(args[0])
    func = tf.function(fft_ops.rfft)
    input_size = len(lines[1].split())
    # Since the input starts at line 1, we must add 1. To avoid overflowing,
    # instead subtract 7.
    len_lines_multiple_of_eight = int(len(lines) - len(lines) % 8) - 7
    # Skip line 0, which contains the configuration params.
    # Read lines in pairs <input, expected>
    in_frames = np.array([[int(j) for j in lines[i].split()]
                          for i in range(1, len_lines_multiple_of_eight, 2)],
                         dtype=np.int16)
    out_frames_exp = [[int(j) for j in lines[i + 1].split()]
                      for i in range(1, len_lines_multiple_of_eight, 2)]
    # Compare TFLM inference against the expected golden values
    # TODO(b/286252893): validate usage of testing vs interpreter here
    concrete_function = func.get_concrete_function(
        tf.TensorSpec(np.shape(in_frames), dtype=tf.int16), fft_length)
    interpreter = util.get_tflm_interpreter(concrete_function, func)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.set_tensor(input_details[0]['index'], in_frames)
    interpreter.invoke()
    out_frame = interpreter.get_tensor(output_details[0]['index'])
    self.assertAllEqual(out_frames_exp, out_frame)
    # TF
    out_frames = self.evaluate(fft_ops.rfft(in_frames, fft_length))
    self.assertAllEqual(out_frames_exp, out_frames)

    # Expand outer dims to [4, x, input_size] to test >1 outer dim.
    in_frames_multiple_outer_dims = np.reshape(in_frames, [4, -1, input_size])
    out_frames_exp_multiple_outer_dims = np.reshape(
        out_frames_exp, [4, -1, len(out_frames_exp[0])])
    out_frames_multiple_outer_dims = self.evaluate(
        fft_ops.rfft(in_frames_multiple_outer_dims, fft_length))
    self.assertAllEqual(out_frames_exp_multiple_outer_dims,
                        out_frames_multiple_outer_dims)

  def testRfftOpImpulseTest(self):
    for dtype in [np.int16, np.int32]:
      fft_length = fft_ops._MIN_FFT_LENGTH
      while fft_length <= fft_ops._MAX_FFT_LENGTH:
        max_value = np.iinfo(dtype).max
        # Integer RFFTs are scaled by 1 / fft_length
        expected_real = round(max_value / fft_length)
        expected_imag = 0
        fft_input = np.zeros(fft_length, dtype=dtype)
        fft_input[0] = max_value
        fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
        for i in range(0, int(fft_length / 2 + 1)):
          self.assertEqual(fft_output[2 * i], expected_real)
          self.assertEqual(fft_output[2 * i + 1], expected_imag)
        fft_length = 2 * fft_length

  def testRfftMaxMinAmplitudeTest(self):
    for dtype in [np.int16, np.int32]:
      # Make sure that the FFT doesn't overflow with max/min inputs
      fft_length = fft_ops._MIN_FFT_LENGTH
      while fft_length <= fft_ops._MAX_FFT_LENGTH:
        # Test max
        expected_real = np.iinfo(dtype).max
        expected_imag = 0
        fft_input = expected_real * np.ones(fft_length, dtype=dtype)
        fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
        if dtype == np.int16:
          self.assertAlmostEqual(fft_output[0], expected_real, delta=21)
        elif dtype == np.int32:
          self.assertAlmostEqual(fft_output[0], expected_real, delta=47)
        self.assertAlmostEqual(fft_output[1], expected_imag)
        for i in range(1, int(fft_length / 2 + 1)):
          self.assertEqual(fft_output[2 * i], 0)
          self.assertEqual(fft_output[2 * i + 1], 0)
        # Test min
        expected_real = np.iinfo(dtype).min
        expected_imag = 0
        fft_input = expected_real * np.ones(fft_length, dtype=dtype)
        fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
        self.assertAlmostEqual(fft_output[0], expected_real, delta=22)
        self.assertAlmostEqual(fft_output[1], expected_imag)
        for i in range(1, int(fft_length / 2 + 1)):
          self.assertEqual(fft_output[2 * i], 0)
          self.assertEqual(fft_output[2 * i + 1], 0)
        fft_length = 2 * fft_length

  def testRfftSineTest(self):
    sine_wave_amplitude = 10000
    # how many sine periods per fft_length samples
    sine_wave_angle = (1 / fft_ops._MIN_FFT_LENGTH)
    fft_length = fft_ops._MIN_FFT_LENGTH
    while fft_length <= fft_ops._MAX_FFT_LENGTH:
      fft_input = sine_wave_amplitude * np.sin(
          sine_wave_angle * np.pi * 2 * np.array(range(0, fft_length)))
      fft_input_float = np.float32(fft_input)
      fft_input_int16 = np.int16(np.round(fft_input_float))
      fft_input_int32 = np.int32(np.round(fft_input_float))
      fft_output_float = self.evaluate(
          fft_ops.rfft(fft_input_float, fft_length))
      fft_output_int16 = self.evaluate(
          fft_ops.rfft(fft_input_int16, fft_length))
      fft_output_int32 = np.round(
          self.evaluate(fft_ops.rfft(fft_input_int32, fft_length)))
      sine_bin = round(fft_length / fft_ops._MIN_FFT_LENGTH)
      expected_real = 0
      # The output of floating point RFFT is not scaled
      # This is the expected output of the theoretical DFT
      expected_imag_sine_bin_float = np.float32(-sine_wave_amplitude / 2 *
                                                fft_length)
      # The output of the integer RFFT is scaled by 1 / fft_length
      expected_imag_sine_bin_int16 = np.int16(round(-sine_wave_amplitude / 2))
      expected_imag_sine_bin_int32 = np.int32(round(-sine_wave_amplitude / 2))
      expected_imag_other_bins = 0
      for i in range(0, int(fft_length / 2 + 1)):
        self.assertAlmostEqual(fft_output_float[2 * i],
                               expected_real,
                               delta=0.1)
        self.assertAlmostEqual(fft_output_int16[2 * i], expected_real, delta=2)
        self.assertAlmostEqual(fft_output_int32[2 * i], expected_real, delta=0)
        if i == sine_bin:
          self.assertAlmostEqual(fft_output_float[2 * i + 1],
                                 expected_imag_sine_bin_float,
                                 delta=1.3e-12)
          self.assertAlmostEqual(fft_output_int16[2 * i + 1],
                                 expected_imag_sine_bin_int16,
                                 delta=2)
          self.assertAlmostEqual(fft_output_int32[2 * i + 1],
                                 expected_imag_sine_bin_int32,
                                 delta=2)
        else:
          self.assertAlmostEqual(fft_output_float[2 * i + 1],
                                 expected_imag_other_bins,
                                 delta=0.35)
          self.assertAlmostEqual(fft_output_int16[2 * i + 1],
                                 expected_imag_other_bins,
                                 delta=2)
          self.assertAlmostEqual(fft_output_int32[2 * i + 1],
                                 expected_imag_other_bins,
                                 delta=1)
      fft_length = 2 * fft_length

  def testFftTooLarge(self):
    for dtype in [np.int16, np.int32, np.float32]:
      fft_input = np.zeros(round(fft_ops._MAX_FFT_LENGTH * 2), dtype=dtype)
      with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
        self.evaluate(
            fft_ops.rfft(fft_input, round(fft_ops._MAX_FFT_LENGTH * 2)))

  def testFftTooSmall(self):
    for dtype in [np.int16, np.int32, np.float32]:
      fft_input = np.zeros(round(fft_ops._MIN_FFT_LENGTH / 2), dtype=dtype)
      with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
        self.evaluate(
            fft_ops.rfft(fft_input, round(fft_ops._MIN_FFT_LENGTH / 2)))

  def testFftLengthNoEven(self):
    for dtype in [np.int16, np.int32, np.float32]:
      fft_input = np.zeros(127, dtype=dtype)
      with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
        self.evaluate(fft_ops.rfft(fft_input, 127))

  def testPow2FftLengthTest(self):
    fft_length, fft_bits = fft_ops.get_pow2_fft_length(131)
    self.assertEqual(fft_length, 256)
    self.assertEqual(fft_bits, 8)
    fft_length, fft_bits = fft_ops.get_pow2_fft_length(73)
    self.assertEqual(fft_length, 128)
    self.assertEqual(fft_bits, 7)
    with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
      fft_ops.get_pow2_fft_length(fft_ops._MIN_FFT_LENGTH / 2 - 1)
    with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
      fft_ops.get_pow2_fft_length(fft_ops._MAX_FFT_LENGTH + 1)


if __name__ == '__main__':
  tf.test.main()
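The expected values in `testRfftSineTest` follow from the DFT of a sine wave. For amplitude A and a sine that falls exactly on one bin, the unscaled imaginary part at that bin is -A / 2 * N. Scaling by 1 / N, as the test comments say the integer RFFT does, gives -A / 2. The sketch below checks this with a naive pure-Python DFT. It is an illustration of the arithmetic only, not the TFLM kernel, and `naive_rfft` is a hypothetical name.

```python
import cmath
import math


def naive_rfft(samples):
    """Naive O(N^2) DFT returning bins 0..N/2 as complex numbers."""
    n = len(samples)
    return [
        sum(x * cmath.exp(-2j * math.pi * k * t / n)
            for t, x in enumerate(samples))
        for k in range(n // 2 + 1)
    ]


amplitude = 10000
fft_length = 128
# The test uses one sine period per 64 samples, so the sine lands in
# bin fft_length / 64.
sine_bin = fft_length // 64
signal = [amplitude * math.sin(2 * math.pi * t / 64) for t in range(fft_length)]

spectrum = naive_rfft(signal)
unscaled_imag = spectrum[sine_bin].imag   # close to -amplitude / 2 * fft_length
scaled_imag = unscaled_imag / fft_length  # close to -amplitude / 2
```

With these inputs, `unscaled_imag` is approximately -640000 and `scaled_imag` approximately -5000, matching the float and integer expectations computed in the test.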
......@@ -7,5 +7,6 @@ package(
)
exports_files([
"rfft_test1.txt",
"window_test1.txt",
])
This diff is collapsed.
......@@ -18,7 +18,7 @@ import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
-gen_window_op = util.load_custom_op('window')
+gen_window_op = util.load_custom_op('window_op.so')
def hann_window_weights(window_length, shift, dtype=np.int16):
......
......@@ -39,4 +39,4 @@ def get_tflm_interpreter(concrete_function, trackable_obj):
def load_custom_op(name):
return load_library.load_op_library(
-      resource_loader.get_path_to_datafile('../ops/_' + name + '_op.so'))
+      resource_loader.get_path_to_datafile('../ops/_' + name))
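The hunk above shows two versions of the path expression in `load_custom_op`: one that appends `'_op.so'` to the name and one that uses the name as given. The sketch below reproduces only that string concatenation, so the effect on the resolved file name is easy to see. The function names `old_op_path` and `new_op_path` are hypothetical, and the calls to `load_library.load_op_library` and `resource_loader.get_path_to_datafile` are deliberately left out.

```python
def old_op_path(name):
    """Path expression that appends the '_op.so' suffix to the name."""
    return '../ops/_' + name + '_op.so'


def new_op_path(name):
    """Path expression that expects the caller to pass the full file name."""
    return '../ops/_' + name


if __name__ == "__main__":
    print(old_op_path('window'))        # ../ops/_window_op.so
    print(new_op_path('window_op.so'))  # ../ops/_window_op.so
    print(new_op_path('fft_ops.so'))    # ../ops/_fft_ops.so
```

The suffix-appending form always produces a name ending in `_op.so`, so it cannot resolve `_fft_ops.so`. That appears to be why the full file name is now passed by callers such as `util.load_custom_op('fft_ops.so')`.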
......@@ -10,14 +10,20 @@ package(
cc_library(
name = "register_signal_ops",
srcs = [
"rfft.cc",
"window.cc",
],
hdrs = [
"rfft.h",
],
copts = micro_copts(),
visibility = [
"//tensorflow/lite/micro",
],
deps = [
"//signal/src:rfft",
"//signal/src:window",
"//tensorflow/lite:type_to_tflitetype",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/micro:flatbuffer_utils",
......@@ -29,6 +35,33 @@ cc_library(
],
)
cc_library(
name = "fft_flexbuffers_generated_data",
srcs = [
"fft_flexbuffers_generated_data.cc",
],
hdrs = [
"fft_flexbuffers_generated_data.h",
],
)
cc_test(
name = "fft_test",
srcs = [
"fft_test.cc",
],
deps = [
":fft_flexbuffers_generated_data",
":register_signal_ops",
"//signal/testdata:fft_test_data",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:kernel_runner",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_library(
name = "window_flexbuffers_generated_data",
srcs = [
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// TODO(b/187459424): Find a better way to generate this data.
// This file is generated. See:
// tensorflow/lite/micro/kernels/test_data_generation/README.md
#include "signal/micro/kernels/fft_flexbuffers_generated_data.h"
const int g_gen_data_size_fft_length_64_float = 26;
const unsigned char g_gen_data_fft_length_64_float[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
0x02, 0x00, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01};
const int g_gen_data_size_fft_length_64_int16 = 26;
const unsigned char g_gen_data_fft_length_64_int16[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
0x02, 0x07, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01};
const int g_gen_data_size_fft_length_64_int32 = 26;
const unsigned char g_gen_data_fft_length_64_int32[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
0x02, 0x02, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01};
const int g_gen_data_size_fft_length_512_float = 31;
const unsigned char g_gen_data_fft_length_512_float[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x00,
0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00,
0x00, 0x00, 0x00, 0x02, 0x05, 0x05, 0x06, 0x25, 0x01};
const int g_gen_data_size_fft_length_512_int16 = 31;
const unsigned char g_gen_data_fft_length_512_int16[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x00,
0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00,
0x07, 0x00, 0x00, 0x02, 0x05, 0x05, 0x06, 0x25, 0x01};
const int g_gen_data_size_fft_length_512_int32 = 31;
const unsigned char g_gen_data_fft_length_512_int32[] = {
0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x00,
0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00,
0x02, 0x00, 0x00, 0x02, 0x05, 0x05, 0x06, 0x25, 0x01};
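The generated byte arrays are easier to read once a few bytes are decoded. The sketch below is plain byte inspection, not a FlexBuffers parser. It copies the three length-64 arrays, extracts the two null-terminated ASCII keys at the start, and lists the positions where the arrays differ. Reading byte 19 as the value of `T` and byte 20 as the value of `fft_length` is an inference from the data and from the alphabetical-ordering comment in `rfft.cc`, not something the generated file states. The helper names are hypothetical.

```python
FLOAT_64 = bytes([
    0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
    0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
    0x02, 0x00, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01])
INT16_64 = bytes([
    0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
    0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
    0x02, 0x07, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01])
INT32_64 = bytes([
    0x66, 0x66, 0x74, 0x5f, 0x6c, 0x65, 0x6e, 0x67, 0x74,
    0x68, 0x00, 0x54, 0x00, 0x02, 0x03, 0x0f, 0x02, 0x01,
    0x02, 0x02, 0x40, 0x04, 0x04, 0x04, 0x24, 0x01])


def key_strings(data, count=2):
    """Decode the first `count` null-terminated ASCII strings in the buffer."""
    keys = []
    start = 0
    for _ in range(count):
        end = data.index(0, start)
        keys.append(data[start:end].decode('ascii'))
        start = end + 1
    return keys


def differing_indices(*buffers):
    """Positions at which the equally sized buffers do not all agree."""
    return [i for i in range(len(buffers[0]))
            if len({b[i] for b in buffers}) > 1]


if __name__ == "__main__":
    print(key_strings(FLOAT_64))                            # ['fft_length', 'T']
    print(differing_indices(FLOAT_64, INT16_64, INT32_64))  # [19]
    print(FLOAT_64[20])                                     # 64
```

The three arrays differ only at index 19 (0x00, 0x07, 0x02 for float, int16, int32), and index 20 holds 0x40, which is 64 in decimal.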
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_fft_length_64_float;
extern const unsigned char g_gen_data_fft_length_64_float[];
extern const int g_gen_data_size_fft_length_64_int16;
extern const unsigned char g_gen_data_fft_length_64_int16[];
extern const int g_gen_data_size_fft_length_64_int32;
extern const unsigned char g_gen_data_fft_length_64_int32[];
extern const int g_gen_data_size_fft_length_512_float;
extern const unsigned char g_gen_data_fft_length_512_float[];
extern const int g_gen_data_size_fft_length_512_int16;
extern const unsigned char g_gen_data_fft_length_512_int16[];
extern const int g_gen_data_size_fft_length_512_int32;
extern const unsigned char g_gen_data_fft_length_512_int32[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FFT_FLEXBUFFERS_DATA_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include "signal/micro/kernels/fft_flexbuffers_generated_data.h"
#include "signal/testdata/fft_test_data.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
template <typename T>
TfLiteStatus ValidateFFTGoldens(
TfLiteTensor* tensors, const int tensors_size, TfLiteIntArray* inputs_array,
TfLiteIntArray* outputs_array, int output_len, const T* golden,
const TFLMRegistration registration, const uint8_t* flexbuffers_data,
const int flexbuffers_data_len, T* output_data, T tolerance) {
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
/*builtin_data=*/nullptr);
// TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned
// char*. This small discrepancy results in compiler warnings unless we
// reinterpret_cast right before passing in the flexbuffer bytes to the
// KernelRunner.
TfLiteStatus status = runner.InitAndPrepare(
reinterpret_cast<const char*>(flexbuffers_data), flexbuffers_data_len);
if (status != kTfLiteOk) {
return status;
}
status = runner.Invoke();
if (status != kTfLiteOk) {
return status;
}
for (int i = 0; i < output_len; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], tolerance);
}
return kTfLiteOk;
}
template <typename T>
TfLiteStatus TestFFT(int* input_dims_data, const T* input_data,
int* output_dims_data, const T* golden,
const TFLMRegistration registration,
const uint8_t* flexbuffers_data,
const int flexbuffers_data_len, T* output_data,
T tolerance) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
constexpr int kInputsSize = 1;
constexpr int kOutputsSize = 1;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
};
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TF_LITE_ENSURE_STATUS(
ValidateFFTGoldens<T>(tensors, kTensorsSize, inputs_array, outputs_array,
output_len, golden, registration, flexbuffers_data,
flexbuffers_data_len, output_data, tolerance));
return kTfLiteOk;
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(RfftTestSize64Float) {
constexpr int kOutputLen = 66;
int input_shape[] = {1, 64};
const float input[] = {16384., 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int output_shape[] = {1, kOutputLen};
const float golden[] = {16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0, 16384., 0, 16384., 0,
16384., 0, 16384., 0, 16384., 0};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_float,
g_gen_data_size_fft_length_64_float, output, 1e-7));
}
#if !defined __XTENSA__
// On Xtensa, only the 16-bit RFFT of size 512 is currently supported.
TF_LITE_MICRO_TEST(RfftTestSize64Int16) {
constexpr int kOutputLen = 66;
int input_shape[] = {1, 64};
const int16_t input[] = {16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int output_shape[] = {1, kOutputLen};
const int16_t golden[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int16,
g_gen_data_size_fft_length_64_int16, output, 0));
}
#endif
TF_LITE_MICRO_TEST(RfftTestSize64Int32) {
constexpr int kOutputLen = 66;
int input_shape[] = {1, 64};
const int32_t input[] = {16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int output_shape[] = {1, kOutputLen};
const int32_t golden[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(RfftTestSize64Int32OuterDims4) {
constexpr int kOutputLen = 66;
constexpr int kOuterDim = 2;
int input_shape[] = {3, kOuterDim, kOuterDim, 64};
const int32_t input[] = {
16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16384, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 16384, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0};
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
const int32_t golden[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int32_t output[kOuterDim * kOuterDim * kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(RfftTestSize512Float) {
constexpr int kOutputLen = 514;
int input_shape[] = {1, 512};
int output_shape[] = {1, kOutputLen};
// Outputs are complex floats, which take twice the space of regular floats.
float output[kOutputLen * 2];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, tflite::kRfftFloatLength512Input,
output_shape, tflite::kRfftFloatLength512Golden,
*registration, g_gen_data_fft_length_512_float,
g_gen_data_size_fft_length_512_float, output, 1e-5));
}
TF_LITE_MICRO_TEST(RfftTestSize512Int16) {
constexpr int kOutputLen = 514;
int input_shape[] = {1, 512};
int output_shape[] = {1, kOutputLen};
// Outputs are ComplexInt16 which takes twice the space as regular int16.
int16_t output[kOutputLen * 2];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_INT16();
// See (b/287518815) for why this is needed.
#if defined(HIFI4) || defined(HIFI5)
int tolerance = 9;
#else // defined(HIFI4) || defined(HIFI5)
int tolerance = 3;
#endif // defined(HIFI4) || defined(HIFI5)
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, tflite::kRfftInt16Length512Input,
output_shape, tflite::kRfftInt16Length512Golden,
*registration, g_gen_data_fft_length_512_int16,
g_gen_data_size_fft_length_512_int16, output, tolerance));
}
TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
constexpr int kOutputLen = 514;
int input_shape[] = {1, 512};
int output_shape[] = {1, kOutputLen};
// Outputs are ComplexInt32 which takes twice the space as regular int32.
int32_t output[kOutputLen * 2];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int32_t>(
input_shape, tflite::kRfftInt32Length512Input,
output_shape, tflite::kRfftInt32Length512Golden,
*registration, g_gen_data_fft_length_512_int32,
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/rfft.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include "signal/micro/kernels/rfft.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'T' is added implicitly by the TensorFlow framework when the type is resolved
// during graph construction.
// constexpr int kTypeIndex = 0; // 'T' (unused)
constexpr int kFftLengthIndex = 1; // 'fft_length'
template <typename T>
struct TfLiteAudioFrontendRfftParams {
int32_t fft_length;
int32_t input_size;
int32_t input_length;
int32_t output_length;
TfLiteType fft_type;
T* work_area;
int8_t* state;
};
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t)>
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
auto* params = static_cast<TfLiteAudioFrontendRfftParams<T>*>(
context->AllocatePersistentBuffer(
context, sizeof(TfLiteAudioFrontendRfftParams<T>)));
tflite::FlexbufferWrapper fbw(buffer_t, length);
params->fft_length = fbw.ElementAsInt32(kFftLengthIndex);
params->fft_type = typeToTfLiteType<T>();
params->work_area = static_cast<T*>(context->AllocatePersistentBuffer(
context, params->fft_length * sizeof(T)));
size_t state_size = (*get_needed_memory_func)(params->fft_length);
params->state = static_cast<int8_t*>(
context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t)));
(*init_func)(params->fft_length, params->state, state_size);
return params;
}
template <typename T, TfLiteType TfLiteTypeEnum>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum);
auto* params =
reinterpret_cast<TfLiteAudioFrontendRfftParams<T>*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
params->input_length = input_shape.Dims(input_shape.DimensionsCount() - 1);
params->input_size = input_shape.FlatSize();
// Divide by 2 because output is complex.
params->output_length =
output_shape.Dims(output_shape.DimensionsCount() - 1) / 2;
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T, void (*apply_func)(void*, const T* input, Complex<T>*)>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendRfftParams<T>*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const T* input_data = tflite::micro::GetTensorData<T>(input);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
Complex<T>* output_data = tflite::micro::GetTensorData<Complex<T>>(output);
for (int input_idx = 0, output_idx = 0; input_idx < params->input_size;
input_idx += params->input_length, output_idx += params->output_length) {
memcpy(params->work_area, &input_data[input_idx],
sizeof(T) * params->input_length);
// Zero pad input to FFT length
memset(&params->work_area[params->input_length], 0,
sizeof(T) * (params->fft_length - params->input_length));
(*apply_func)(params->state, params->work_area, &output_data[output_idx]);
}
return kTfLiteOk;
}
void* InitAll(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
auto tensor_type = static_cast<tflite::TensorType>(m["T"].AsInt32());
switch (tensor_type) {
case TensorType_INT16: {
return Init<int16_t, RfftInt16GetNeededMemory, RfftInt16Init>(
context, buffer, length);
}
case TensorType_INT32: {
return Init<int32_t, RfftInt32GetNeededMemory, RfftInt32Init>(
context, buffer, length);
}
case TensorType_FLOAT32: {
return Init<float, RfftFloatGetNeededMemory, RfftFloatInit>(
context, buffer, length);
}
default:
return nullptr;
}
}
TfLiteStatus PrepareAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendRfftParams<void>*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Prepare<int16_t, kTfLiteInt16>(context, node);
}
case kTfLiteInt32: {
return Prepare<int32_t, kTfLiteInt32>(context, node);
}
case kTfLiteFloat32: {
return Prepare<float, kTfLiteFloat32>(context, node);
}
default:
return kTfLiteError;
}
}
TfLiteStatus EvalAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendRfftParams<void>*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Eval<int16_t, RfftInt16Apply>(context, node);
}
case kTfLiteInt32: {
return Eval<int32_t, RfftInt32Apply>(context, node);
}
case kTfLiteFloat32: {
return Eval<float, RfftFloatApply>(context, node);
}
default:
return kTfLiteError;
}
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_RFFT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(InitAll, PrepareAll, EvalAll);
return &r;
}
TFLMRegistration* Register_RFFT_FLOAT() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<float, RfftFloatGetNeededMemory, RfftFloatInit>,
Prepare<float, kTfLiteFloat32>, Eval<float, RfftFloatApply>);
return &r;
}
TFLMRegistration* Register_RFFT_INT16() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int16_t, RfftInt16GetNeededMemory, RfftInt16Init>,
Prepare<int16_t, kTfLiteInt16>, Eval<int16_t, RfftInt16Apply>);
return &r;
}
TFLMRegistration* Register_RFFT_INT32() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int32_t, RfftInt32GetNeededMemory, RfftInt32Init>,
Prepare<int32_t, kTfLiteInt32>, Eval<int32_t, RfftInt32Apply>);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
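The per-frame loop in `Eval` above copies each input frame into a work area of `fft_length` elements and zero-fills the remainder before transforming it. A minimal standalone sketch of that padding pattern follows. The names `SumFrame` and `ProcessFrames` are hypothetical and not part of the library, and the explicit `input_length > fft_length` guard is an assumption of this sketch rather than something the kernel above performs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for an RFFT apply function: it only sums the padded
// frame, so the effect of the zero padding is easy to verify.
inline int32_t SumFrame(const int16_t* frame, size_t fft_length) {
  int32_t sum = 0;
  for (size_t i = 0; i < fft_length; ++i) sum += frame[i];
  return sum;
}

// Walks `input` one frame of `input_length` elements at a time, copies each
// frame into a scratch buffer of `fft_length` elements, zero-fills the tail,
// and processes the buffer. Returns an empty vector when a frame would not
// fit in the scratch buffer (a guard added for this sketch).
inline std::vector<int32_t> ProcessFrames(const std::vector<int16_t>& input,
                                          size_t input_length,
                                          size_t fft_length) {
  std::vector<int32_t> results;
  if (input_length == 0 || input_length > fft_length) return results;
  std::vector<int16_t> work_area(fft_length);
  for (size_t idx = 0; idx + input_length <= input.size();
       idx += input_length) {
    std::memcpy(work_area.data(), input.data() + idx,
                sizeof(int16_t) * input_length);
    // Zero pad the frame up to the FFT length.
    std::memset(work_area.data() + input_length, 0,
                sizeof(int16_t) * (fft_length - input_length));
    results.push_back(SumFrame(work_area.data(), fft_length));
  }
  return results;
}
```

Calling `ProcessFrames({1, 2, 3, 4, 5, 6}, 3, 4)` processes two padded frames, `{1, 2, 3, 0}` and `{4, 5, 6, 0}`.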
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_RFFT_H_
#define SIGNAL_MICRO_KERNELS_RFFT_H_
#include "tensorflow/lite/micro/micro_common.h"
namespace tflite {
namespace tflm_signal {
TFLMRegistration* Register_RFFT();
TFLMRegistration* Register_RFFT_FLOAT();
TFLMRegistration* Register_RFFT_INT16();
TFLMRegistration* Register_RFFT_INT32();
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_RFFT_H_
......@@ -3,6 +3,25 @@ package(
licenses = ["notice"],
)
cc_library(
name = "complex",
hdrs = ["complex.h"],
)
cc_library(
name = "rfft",
srcs = [
"rfft_float.cc",
"rfft_int16.cc",
"rfft_int32.cc",
],
hdrs = ["rfft.h"],
deps = [
":complex",
"//signal/src/kiss_fft_wrappers",
],
)
cc_library(
name = "window",
srcs = ["window.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_COMPLEX_H_
#define SIGNAL_SRC_COMPLEX_H_
#include <stdint.h>
// We would use the standard complex type in complex.h, but there's
// no guarantee that all architectures will support it.
template <typename T>
struct Complex {
T real;
T imag;
};
#endif // SIGNAL_SRC_COMPLEX_H_
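The kernels in this change store complex results in tensors of plain real elements, which relies on a real/imag pair occupying exactly two consecutive scalars. The sketch below illustrates that layout assumption with a hypothetical `ComplexSketch` struct that mirrors `Complex<T>`. The `static_assert` checks document the assumption (it holds on common ABIs but is not guaranteed by the C++ standard), and `Flatten` is an illustrative helper, not library code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical mirror of the Complex<T> struct above, for illustration only.
template <typename T>
struct ComplexSketch {
  T real;
  T imag;
};

// Layout assumption: no padding, so N complex values are 2 * N scalars.
static_assert(sizeof(ComplexSketch<int16_t>) == 2 * sizeof(int16_t),
              "expected a complex int16 to be two packed int16 values");
static_assert(offsetof(ComplexSketch<int16_t>, imag) == sizeof(int16_t),
              "expected imag to directly follow real");

// Copies `count` complex values into an interleaved real/imag scalar array.
// memcpy is used so that no pointer aliasing rules are violated.
inline void Flatten(const ComplexSketch<int16_t>* in, size_t count,
                    int16_t* out) {
  assert(in != nullptr && out != nullptr);
  std::memcpy(out, in, count * sizeof(ComplexSketch<int16_t>));
}
```

Flattening `{{1, 2}, {3, 4}}` yields the interleaved sequence `1, 2, 3, 4`, which is the order in which the real-typed output tensors hold the spectrum.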
package(
default_visibility = ["//signal/src:__subpackages__"],
licenses = ["notice"],
)
cc_library(
name = "kiss_fft_wrappers",
srcs = [
"kiss_fft_float.cc",
"kiss_fft_int16.cc",
"kiss_fft_int32.cc",
],
hdrs = [
"kiss_fft_common.h",
"kiss_fft_float.h",
"kiss_fft_int16.h",
"kiss_fft_int32.h",
],
deps = [
"@kissfft//:kiss_fftr",
],
)
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_COMMON_H_
#define SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_COMMON_H_
// This header file should be included in all variants of kiss_fft_$type.{h,cc}
// so that their sub-included source files do not mistakenly wrap libc header
// files within their kiss_fft_$type namespaces.
// E.g., This header avoids kiss_fft_int16.h containing:
// namespace kiss_fft_int16 {
// #include "kiss_fft.h"
// }
// where kiss_fft.h contains:
// #include <math.h>
//
// TRICK: By including the following header files here, their preprocessor
// header guards prevent them being re-defined inside of the kiss_fft_$type
// namespaces declared within the kiss_fft_$type.{h,cc} sources.
// Note that the original kiss_fft*.h files are untouched since they
// may be used in libraries that include them directly.
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef FIXED_POINT
#include <sys/types.h>
#endif
#ifdef USE_SIMD
#include <xmmintrin.h>
#endif
#endif // SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_COMMON_H_
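The include-guard trick described in the comment above can be simulated in a single file. In this sketch the guarded block `FAKE_LIBC_H_` and the function `FakeAbs` are hypothetical stand-ins for a libc header and one of its declarations. The second "inclusion" sits inside a namespace, but the guard turns it into a no-op, so the declaration stays in the global namespace.

```cpp
#include <cassert>

// First "inclusion" of the stand-in libc header, at global scope.
#ifndef FAKE_LIBC_H_
#define FAKE_LIBC_H_
inline int FakeAbs(int x) { return x < 0 ? -x : x; }
#endif  // FAKE_LIBC_H_

namespace wrapped {
// Second "inclusion", as would happen when a wrapped source file includes the
// same header. The guard is already defined, so nothing is declared here and
// wrapped::FakeAbs never comes into existence.
#ifndef FAKE_LIBC_H_
#define FAKE_LIBC_H_
inline int FakeAbs(int x) { return x < 0 ? -x : x; }
#endif  // FAKE_LIBC_H_

// Code inside the namespace still finds ::FakeAbs through normal lookup.
inline int Distance(int a, int b) { return FakeAbs(a - b); }
}  // namespace wrapped
```

Without the first inclusion, `FakeAbs` would be declared as `wrapped::FakeAbs`, which is the situation the common header is meant to prevent for real libc declarations.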
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
#undef FIXED_POINT
namespace kiss_fft_float {
#include "kiss_fft.c"
#include "tools/kiss_fftr.c"
} // namespace kiss_fft_float
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_FLOAT_H_
#define SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_FLOAT_H_
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
// Wrap floating point kiss fft in its own namespace. Enables us to link an
// application with different kiss fft resolutions
// (16/32 bit integer, float, double) without getting a linker error.
#undef FIXED_POINT
namespace kiss_fft_float {
#include "kiss_fft.h"
#include "tools/kiss_fftr.h"
} // namespace kiss_fft_float
#undef FIXED_POINT
#endif // SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_FLOAT_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
#define FIXED_POINT 16
namespace kiss_fft_fixed16 {
#include "kiss_fft.c"
#include "tools/kiss_fftr.c"
} // namespace kiss_fft_fixed16
#undef FIXED_POINT
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT16_H_
#define SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT16_H_
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
// Wrap 16-bit fixed point kiss fft in its own namespace. Enables us to link an
// application with different kiss fft resolutions
// (16/32 bit integer, float, double) without getting a linker error.
#define FIXED_POINT 16
namespace kiss_fft_fixed16 {
#include "kiss_fft.h"
#include "tools/kiss_fftr.h"
} // namespace kiss_fft_fixed16
#undef FIXED_POINT
#endif // SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT16_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
#define FIXED_POINT 32
namespace kiss_fft_fixed32 {
#include "kiss_fft.c"
#include "tools/kiss_fftr.c"
} // namespace kiss_fft_fixed32
#undef FIXED_POINT
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT32_H_
#define SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT32_H_
#include "signal/src/kiss_fft_wrappers/kiss_fft_common.h"
// Wrap 32-bit fixed point kiss fft in its own namespace. Enables us to link an
// application with different kiss fft resolutions
// (16/32 bit integer, float, double) without getting a linker error.
#define FIXED_POINT 32
namespace kiss_fft_fixed32 {
#include "kiss_fft.h"
#include "tools/kiss_fftr.h"
} // namespace kiss_fft_fixed32
#undef FIXED_POINT
#endif // SIGNAL_SRC_KISS_FFT_WRAPPERS_KISS_FFT_INT32_H_
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_RFFT_H_
#define SIGNAL_SRC_RFFT_H_
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
// RFFT (Real Fast Fourier Transform)
// FFT for real valued time domain inputs.
// 16-bit Integer input/output
// Returns the size of the memory that an RFFT of `fft_length` needs
size_t RfftInt16GetNeededMemory(int32_t fft_length);
// Initialize the state of an RFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// RfftInt16GetNeededMemory(fft_length).
// Return the value of `state` on success or nullptr on failure
void* RfftInt16Init(int32_t fft_length, void* state, size_t state_size);
// Applies RFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see RfftInt16Init)
// * `output` must be of size (`fft_length` / 2) + 1 complex elements
void RfftInt16Apply(void* state, const int16_t* input,
Complex<int16_t>* output);
// 32-bit Integer input/output
// Returns the size of the memory that an RFFT of `fft_length` needs
size_t RfftInt32GetNeededMemory(int32_t fft_length);
// Initialize the state of an RFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// RfftInt32GetNeededMemory(fft_length).
// Return the value of `state` on success or nullptr on failure
void* RfftInt32Init(int32_t fft_length, void* state, size_t state_size);
// Applies RFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see RfftInt32Init)
// * `output` must be of size (`fft_length` / 2) + 1 complex elements
void RfftInt32Apply(void* state, const int32_t* input,
Complex<int32_t>* output);
// Floating point input/output
// Returns the size of the memory that an RFFT of `fft_length` needs
size_t RfftFloatGetNeededMemory(int32_t fft_length);
// Initialize the state of an RFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// RfftFloatGetNeededMemory(fft_length).
// Return the value of `state` on success or nullptr on failure
void* RfftFloatInit(int32_t fft_length, void* state, size_t state_size);
// Applies RFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see RfftFloatInit)
// * `output` must be of size (`fft_length` / 2) + 1 complex elements
void RfftFloatApply(void* state, const float* input, Complex<float>* output);
#endif // SIGNAL_SRC_RFFT_H_
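The header above documents a two-step setup: first ask how much memory a transform needs, then initialize the state inside a buffer the caller owns. The following sketch shows that calling pattern with hypothetical toy functions (`ToyGetNeededMemory`, `ToyInit`, `ToyState`). It does not call the real RFFT functions, and the size formula is made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <new>
#include <vector>

// Hypothetical opaque state; a real transform would keep twiddle factors etc.
struct ToyState {
  int32_t length;
};

// Step 1: report how many bytes a transform of `length` needs.
// The formula is invented for this sketch.
inline size_t ToyGetNeededMemory(int32_t length) {
  if (length <= 0) return 0;
  return sizeof(ToyState) + static_cast<size_t>(length) * sizeof(float);
}

// Step 2: construct the state inside caller-owned memory.
// Returns `state` on success, or nullptr if the buffer is missing or too
// small, mirroring the contract documented in the header above.
inline void* ToyInit(int32_t length, void* state, size_t state_size) {
  const size_t needed = ToyGetNeededMemory(length);
  if (state == nullptr || needed == 0 || state_size < needed) return nullptr;
  new (state) ToyState{length};
  return state;
}
```

A caller would size a buffer with `ToyGetNeededMemory(8)`, allocate it once (for example as a `std::vector<uint8_t>`), and pass it to `ToyInit`. This keeps all allocation decisions outside the transform code, which suits arena-style allocators.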
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_float.h"
#include "signal/src/rfft.h"
size_t RfftFloatGetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_float::kiss_fftr_alloc(fft_length, 0, nullptr, &state_size);
return state_size;
}
void* RfftFloatInit(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_float::kiss_fftr_alloc(fft_length, 0, state, &state_size);
}
void RfftFloatApply(void* state, const float* input, Complex<float>* output) {
kiss_fft_float::kiss_fftr(
static_cast<kiss_fft_float::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_scalar*>(input),
reinterpret_cast<kiss_fft_float::kiss_fft_cpx*>(output));
}
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int16.h"
#include "signal/src/rfft.h"
size_t RfftInt16GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 0, nullptr, &state_size);
return state_size;
}
void* RfftInt16Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 0, state, &state_size);
}
void RfftInt16Apply(void* state, const int16_t* input,
Complex<int16_t>* output) {
kiss_fft_fixed16::kiss_fftr(
static_cast<kiss_fft_fixed16::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_scalar*>(input),
reinterpret_cast<kiss_fft_fixed16::kiss_fft_cpx*>(output));
}
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int32.h"
#include "signal/src/rfft.h"
size_t RfftInt32GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 0, nullptr, &state_size);
return state_size;
}
void* RfftInt32Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 0, state, &state_size);
}
void RfftInt32Apply(void* state, const int32_t* input,
Complex<int32_t>* output) {
kiss_fft_fixed32::kiss_fftr(
static_cast<kiss_fft_fixed32::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_scalar*>(input),
reinterpret_cast<kiss_fft_fixed32::kiss_fft_cpx*>(output));
}
\ No newline at end of file
......@@ -5,6 +5,15 @@ package(
licenses = ["notice"],
)
tflm_signal_kernel_library(
name = "fft_kernel",
srcs = ["fft_kernels.cc"],
deps = [
"//signal/src:rfft",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "window_kernel",
srcs = ["window_kernel.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include "signal/src/rfft.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace signal {
// get_needed_memory_func(), init_func(), apply_func()
// are type specific implementations of the RFFT functions.
// See rfft.h included above for documentation
template <typename T, DataType E, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t),
void (*apply_func)(void*, const T* input, Complex<T>*)>
class RfftOp : public tensorflow::OpKernel {
public:
explicit RfftOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("fft_length", &fft_length_));
OP_REQUIRES_OK(context,
context->allocate_temp(E, TensorShape({fft_length_}),
&work_area_tensor_));
work_area_ = work_area_tensor_.flat<T>().data();
// Subband array size is the number of subbands * 2 because each coefficient
// is complex.
subband_array_size_ = ((fft_length_ / 2) + 1) * 2;
size_t state_size = (*get_needed_memory_func)(fft_length_);
OP_REQUIRES_OK(context,
context->allocate_temp(
DT_INT8, TensorShape({static_cast<int32_t>(state_size)}),
&state_tensor_));
state_ = state_tensor_.flat<int8_t>().data();
(*init_func)(fft_length_, state_, state_size);
}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const T* input = input_tensor.flat<T>().data();
TensorShape output_shape = input_tensor.shape();
output_shape.set_dim(output_shape.dims() - 1, subband_array_size_);
// Create an output tensor
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
T* output = output_tensor->flat<T>().data();
int outer_dims = input_tensor.flat_inner_dims<T, 2>().dimensions().at(0);
int frame_size = input_tensor.flat_inner_dims<T, 2>().dimensions().at(1);
for (int i = 0; i < outer_dims; i++) {
auto input_in_work_end =
std::copy_n(&input[i * frame_size], frame_size, work_area_);
// Zero pad input to FFT length
std::fill(input_in_work_end, &work_area_[fft_length_], 0);
(*apply_func)(
state_, work_area_,
reinterpret_cast<Complex<T>*>(&output[i * subband_array_size_]));
}
}
private:
int fft_length_;
int subband_array_size_;
int8_t* state_;
T* work_area_;
Tensor work_area_tensor_;
Tensor state_tensor_;
};
// TODO(b/286250473): change back name after name clash resolved
REGISTER_KERNEL_BUILDER(Name("SignalRfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<float>("T"),
RfftOp<float, DT_FLOAT, RfftFloatGetNeededMemory,
RfftFloatInit, RfftFloatApply>);
REGISTER_KERNEL_BUILDER(Name("SignalRfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int16>("T"),
RfftOp<int16_t, DT_INT16, RfftInt16GetNeededMemory,
RfftInt16Init, RfftInt16Apply>);
REGISTER_KERNEL_BUILDER(Name("SignalRfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int32>("T"),
RfftOp<int32_t, DT_INT32, RfftInt32GetNeededMemory,
RfftInt32Init, RfftInt32Apply>);
} // namespace signal
} // namespace tensorflow
\ No newline at end of file
......@@ -5,6 +5,14 @@ package(
licenses = ["notice"],
)
tflm_signal_kernel_library(
name = "fft_ops",
srcs = ["fft_ops.cc"],
deps = [
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "window_op",
srcs = ["window_op.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;
namespace tensorflow {
namespace signal {
Status RfftShape(InferenceContext* c) {
ShapeHandle out;
int fft_length;
TF_RETURN_IF_ERROR(c->GetAttr<int>("fft_length", &fft_length));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &out));
auto dim = ((fft_length / 2) + 1) * 2; // * 2 for complex
TF_RETURN_IF_ERROR(c->ReplaceDim(out, -1, c->MakeDim(dim), &out));
c->set_output(0, out);
return OkStatus();
}
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalRfft")
.Attr("T: {float, int16, int32}")
.Attr("fft_length: int >= 2")
.Input("input: T")
.Output("output: T")
.SetShapeFn(RfftShape)
.Doc(R"doc(
Computes the 1-dimensional discrete Fourier transform of a real-valued signal
over the inner-most dimension of input. Since the DFT of a real signal is
Hermitian-symmetric, RFFT only returns the fft_length / 2 + 1 unique complex
components of the FFT: the zero-frequency term, followed by the fft_length / 2
positive-frequency terms. Along the axis RFFT is computed on, if fft_length is
larger than the corresponding dimension of input, the dimension is padded with
zeros.
input: A Tensor. Must be one of the following types: float32, int16, int32
output: A tensor containing ((fft_length / 2) + 1) complex spectral
components along its innermost dimension.
Since there's no TF integer complex type, the array is represented using
((fft_length / 2) + 1) * 2 real elements.
For integer input (int16, int32), the output is scaled by 1 / fft_length
relative to the theoretical DFT, to avoid overflowing.
For floating point (float32) input, the output isn't scaled.
fft_length: The length of the FFT operation. An input signal that's shorter
will be zero padded to fft_length.
)doc");
} // namespace signal
} // namespace tensorflow
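The shape function and op documentation above can be illustrated with a small standalone sketch. It is not part of the commit and does not use the kiss_fft-based kernels: `RfftOutputInnerDim` and `NaiveRfft` are hypothetical names, the interleaved `{re, im}` output layout is an assumption (the doc string only states that `((fft_length / 2) + 1) * 2` real elements are used), and the naive O(n^2) DFT is a reference for the unscaled float case only.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Number of real elements along the innermost output dimension, mirroring the
// expression in RfftShape: (fft_length / 2 + 1) complex bins, two reals each.
int RfftOutputInnerDim(int fft_length) {
  assert(fft_length >= 2);  // matches the "fft_length: int >= 2" attribute
  return ((fft_length / 2) + 1) * 2;
}

// Hypothetical reference: naive DFT of a real signal. Inputs shorter than
// fft_length are zero padded; longer inputs are truncated in this sketch.
// Output layout {re0, im0, re1, im1, ...} is an assumption.
std::vector<double> NaiveRfft(const std::vector<double>& input,
                              int fft_length) {
  const double kPi = std::acos(-1.0);
  const int bins = fft_length / 2 + 1;
  std::vector<double> out(
      static_cast<std::size_t>(RfftOutputInnerDim(fft_length)), 0.0);
  for (int k = 0; k < bins; ++k) {
    for (int n = 0; n < fft_length; ++n) {
      const double x = n < static_cast<int>(input.size()) ? input[n] : 0.0;
      const double angle = -2.0 * kPi * k * n / fft_length;
      out[2 * k] += x * std::cos(angle);      // real part of bin k
      out[2 * k + 1] += x * std::sin(angle);  // imaginary part of bin k
    }
  }
  return out;
}
```

For `fft_length = 512`, `RfftOutputInnerDim` gives 514, so an input of shape `[..., 512]` would produce an output of shape `[..., 514]`. Per the doc string, the integer kernels additionally scale the result by `1 / fft_length`, which this sketch does not model.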
# Test data for signal tests. This is used to store large arrays which would make tests less readable.
package(
default_visibility = ["//signal/micro/kernels:__subpackages__"],
licenses = ["notice"],
)
cc_library(
name = "fft_test_data",
srcs = [
"fft_test_data.cc",
],
hdrs = [
"fft_test_data.h",
],
)
This diff is collapsed.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_TESTDATA_FFT_TEST_DATA_H_
#define SIGNAL_TESTDATA_FFT_TEST_DATA_H_
#include <cstdint>
namespace tflite {
/* These arrays are generated using random data. They serve to detect changes
* in the kernels. They do not test correctness.
*/
extern const int16_t kRfftInt16Length512Input[];
extern const int16_t kRfftInt16Length512Golden[];
extern const int32_t kRfftInt32Length512Input[];
extern const int32_t kRfftInt32Length512Golden[];
extern const float kRfftFloatLength512Input[];
extern const float kRfftFloatLength512Golden[];
} // namespace tflite
#endif // SIGNAL_TESTDATA_FFT_TEST_DATA_H_
......@@ -333,6 +333,7 @@ tflm_kernel_cc_library(
":activation_utils",
":kernel_util",
":micro_tensor_utils",
"//signal/micro/kernels:register_signal_ops",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
......@@ -351,7 +352,6 @@ tflm_kernel_cc_library(
"//tensorflow/lite/micro:micro_log",
"//tensorflow/lite/micro:micro_utils",
"//tensorflow/lite/schema:schema_fbs",
"//signal/micro/kernels:register_signal_ops",
"@flatbuffers//:runtime_cc",
] + select({
xtensa_fusion_f1_config(): ["//third_party/xtensa/nnlib_hifi4:nnlib_hifi4_lib"],
......
......@@ -48,6 +48,12 @@ $(eval $(call microlite_test,unidirectional_sequence_lstm_test,\
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/testdata/lstm_test_data.cc,\
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/testdata/lstm_test_data.h))
$(eval $(call microlite_test,kernel_fft_test,\
$(TENSORFLOW_ROOT)signal/micro/kernels/fft_test.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/fft_flexbuffers_generated_data.cc \
$(TENSORFLOW_ROOT)signal/testdata/fft_test_data.cc, \
$(TENSORFLOW_ROOT)signal/micro/kernels/fft_flexbuffers_generated_data.h))
$(eval $(call microlite_test,kernel_window_test,\
$(TENSORFLOW_ROOT)signal/micro/kernels/window_test.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/window_flexbuffers_generated_data.cc, \
......
......@@ -96,7 +96,7 @@ TfLiteStatus CalculateArithmeticOpDataLogistic(TfLiteContext* context,
data->input_multiplier = static_cast<int32_t>(multiplier);
}
TFLITE_DCHECK_LE(data->input_multiplier, 32767);
int output_scale_log2_rounded;
TF_LITE_ENSURE(
context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
......
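The new `TFLITE_DCHECK_LE(data->input_multiplier, 32767)` guards the normalization loop quoted in the commit description. The following standalone sketch reproduces that loop to show when the multiplier stops fitting in 16 bits. `InputScaling`, `ComputeInputScaling` and `FitsInInt16` are hypothetical names, and the sketch covers only the loop shown in the commit message, not any other branch of the real prepare function.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical holder for the two values the prepare function stores.
struct InputScaling {
  int32_t multiplier;
  int left_shift;
};

// Mirrors the loop from the commit description: start from
// scale * 4096 * 3 and double until the value exceeds 32767 / 2.
InputScaling ComputeInputScaling(double input_scale) {
  // Upper bound keeps the double-to-int32_t cast well defined in this sketch.
  assert(input_scale > 0.0 && input_scale < 1.0e5);
  double multiplier = input_scale * 4096.0 * 3.0;
  int left_shift = 0;
  while (multiplier <= 32767.0 / 2.0 && left_shift <= 30) {
    ++left_shift;
    multiplier *= 2.0;
  }
  return {static_cast<int32_t>(multiplier), left_shift};
}

// The condition asserted by the added DCHECK.
bool FitsInInt16(int32_t multiplier) { return multiplier <= 32767; }
```

With a scale of `1.0 / 32768` the loop normalizes the multiplier to 24576 with a left shift of 16. With a scale of 3.0 the initial value is already 36864, the loop never runs, and the check fails. This matches the commit's observation that scales above roughly 2.67 break the 16-bit assumption.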
......@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MICRO_OPS_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_MICRO_OPS_H_
#include "signal/micro/kernels/rfft.h"
#include "tensorflow/lite/c/common.h"
// Forward declaration of all micro op kernel registration methods. These
......@@ -127,11 +128,12 @@ TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
TFLMRegistration Register_UNPACK();
TFLMRegistration Register_VAR_HANDLE();
TFLMRegistration Register_WHILE();
TFLMRegistration Register_ZEROS_LIKE();
// TODO(b/160234179): Change custom OPs to also return by value.
namespace tflm_signal {
TFLMRegistration* Register_WINDOW();
}
TFLMRegistration Register_ZEROS_LIKE();
} // namespace tflm_signal
namespace ops {
namespace micro {
......
......@@ -118,7 +118,7 @@ TfLiteStatus CalculateArithmeticOpData(TfLiteContext* context, TfLiteNode* node,
data->input_multiplier = static_cast<int32_t>(multiplier);
}
TFLITE_DCHECK_LE(data->input_multiplier, 32767);
int output_scale_log2_rounded;
TF_LITE_ENSURE(
context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
......
......@@ -84,7 +84,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
// function is called again for a previously added Custom Operator, the
// MicroOpResolver will be unchanged and this function will return
// kTfLiteError.
TfLiteStatus AddCustom(const char* name, TFLMRegistration* registration) {
TfLiteStatus AddCustom(const char* name,
const TFLMRegistration* registration) {
if (registrations_len_ >= tOpCount) {
MicroPrintf(
"Couldn't register custom op '%s', resolver size is too"
......@@ -443,6 +444,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseResizeNearestNeighbor);
}
TfLiteStatus AddRfft(const TFLMRegistration* registration =
tflite::tflm_signal::Register_RFFT()) {
// TODO(b/286250473): change back name and remove namespace
return AddCustom("SignalRfft", registration);
}
TfLiteStatus AddRound() {
return AddBuiltin(BuiltinOperator_ROUND,
tflite::ops::micro::Register_ROUND(), ParseRound);
......
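The `AddCustom` contract described in the comment above (error when the resolver is full, error and no change when the name was already added) can be sketched without the TFLM headers. Everything below is a hypothetical stand-in: `Registration` replaces `TFLMRegistration`, `Status` replaces `TfLiteStatus`, and `MiniResolver` is not the real `MicroMutableOpResolver`.

```cpp
#include <cassert>
#include <cstring>

struct Registration {};       // hypothetical stand-in for TFLMRegistration
enum Status { kOk, kError };  // hypothetical stand-in for TfLiteStatus

template <unsigned int kOpCount>
class MiniResolver {
 public:
  // Takes a pointer to const, as the updated AddCustom signature does.
  Status AddCustom(const char* name, const Registration* registration) {
    if (len_ >= kOpCount) return kError;       // resolver is full
    if (Find(name) != nullptr) return kError;  // duplicate: leave unchanged
    names_[len_] = name;
    registrations_[len_] = registration;
    ++len_;
    return kOk;
  }

  // Returns the registration stored under name, or nullptr if absent.
  const Registration* Find(const char* name) const {
    for (unsigned int i = 0; i < len_; ++i) {
      if (std::strcmp(names_[i], name) == 0) return registrations_[i];
    }
    return nullptr;
  }

 private:
  const char* names_[kOpCount] = {};
  const Registration* registrations_[kOpCount] = {};
  unsigned int len_ = 0;
};
```

In this model, `AddRfft()` corresponds to calling `AddCustom("SignalRfft", registration)`. Accepting a pointer to const lets a caller pass a registration object it cannot modify, which is presumably why the signature changed in this hunk.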
......@@ -82,6 +82,7 @@ INCLUDES := \
-I. \
-I$(DOWNLOADS_DIR)/gemmlowp \
-I$(DOWNLOADS_DIR)/flatbuffers/include \
-I$(DOWNLOADS_DIR)/kissfft \
-I$(DOWNLOADS_DIR)/ruy
ifneq ($(TENSORFLOW_ROOT),)
......@@ -311,6 +312,15 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/linear_memory_planner_tes
$(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/non_persistent_buffer_planner_shim_test.cc
MICROLITE_CC_KERNEL_SRCS := \
$(TENSORFLOW_ROOT)signal/micro/kernels/rfft.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/window.cc \
$(TENSORFLOW_ROOT)signal/src/kiss_fft_wrappers/kiss_fft_float.cc \
$(TENSORFLOW_ROOT)signal/src/kiss_fft_wrappers/kiss_fft_int16.cc \
$(TENSORFLOW_ROOT)signal/src/kiss_fft_wrappers/kiss_fft_int32.cc \
$(TENSORFLOW_ROOT)signal/src/rfft_float.cc \
$(TENSORFLOW_ROOT)signal/src/rfft_int16.cc \
$(TENSORFLOW_ROOT)signal/src/rfft_int32.cc \
$(TENSORFLOW_ROOT)signal/src/window.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/activations.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/activations_common.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/add.cc \
......@@ -413,8 +423,6 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unidirectional_sequence_lstm.cc
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/unpack.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/var_handle.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/while.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/window.cc \
$(TENSORFLOW_ROOT)signal/src/window.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/zeros_like.cc
MICROLITE_TEST_HDRS := \
......@@ -446,14 +454,15 @@ $(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/tflite_bridge/*.cc) \
$(TFL_CC_SRCS)
MICROLITE_CC_HDRS := \
$(wildcard $(TENSORFLOW_ROOT)signal/micro/kernels/*.h) \
$(wildcard $(TENSORFLOW_ROOT)signal/src/*.h) \
$(wildcard $(TENSORFLOW_ROOT)signal/src/kiss_fft_wrappers/*.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/*.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/benchmarks/*model_data.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/*.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/arena_allocator/*.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/memory_planner/*.h) \
$(wildcard $(TENSORFLOW_ROOT)tensorflow/lite/micro/tflite_bridge/*.h) \
$(wildcard $(TENSORFLOW_ROOT)signal/micro/kernels/*.h) \
$(wildcard $(TENSORFLOW_ROOT)signal/src/*.h) \
$(TENSORFLOW_ROOT)LICENSE \
$(TFL_CC_HDRS)
......