...
 
Commits (13)
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/e687c60812bc9998bd92295f08de24ce4b157f7f Initialize test output variables (#2124) 2023-07-13T22:33:57+00:00 RJ Ascani rjascani@google.com BUG=b/290989355
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/c3fddae75235b7added31cde97f2fd7183de1a09 Move kTfLiteAbort to micro_context.h (#2072) 2023-07-14T17:10:16+00:00 TANMAY DAS 16020637+tanmaydas82@users.noreply.github.com BUG=http://b/149795762
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/6da07de9c8f2fec2e40e724e392fe8a11e848bb4 Generic benchmarking. (#2125) 2023-07-14T18:37:18+00:00 TANMAY DAS 16020637+tanmaydas82@users.noreply.github.com BUG=https://b.corp.google.com/277097397
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/2548022e06de03558d50893c025a295959c7a729 Filter_bank ops (#2123) 2023-07-14T20:17:33+00:00 Steven Toribio 34755817+turbotoribio@users.noreply.github.com Port the C++ filter_bank ops to their new open-source location in tflm_signal. BUG=b/289422411 (https://b.corp.google.com/issues/289422411)
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/699c5178b9bb4724a7d76731da67be7bb710bdf1 Explicit load for py_library. (#2127) 2023-07-14T21:59:36+00:00 Advait Jain advaitjain@users.noreply.github.com BUG=http://b/291306662
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/15154aa6f8382afacc3c303d34b7f3096256ae50 Add RISC-V as a requirement for the issue on error. (#2128) 2023-07-14T15:36:26-07:00 Advait Jain advaitjain@users.noreply.github.com This will (hopefully) ensure that RISC-V failures result in a GitHub issue being created. BUG=http://b/291335234
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/52007f6a479a37c6513fa5b6523791572902c25c Fix benchmarking utils on RISC-V (#2130) 2023-07-15T16:31:36+00:00 RJ Ascani rjascani@google.com The generic benchmarking utilities added in PR 2125 did not compile on RISC-V due to two separate issues: an incorrect format specifier for int32_t (resolved by using PRId32 instead of %ld or %d), and fileno() being undeclared by the RISC-V toolchain, as it is with the embedded ARM toolchain. The fix drops the fstat() check on file size (which was what used fileno()) and instead reads as many bytes as fit in the model buffer, then uses feof() to verify that the entire model was read, erroring out if it was not. This also removes the need to handle different file types, such as the Xtensa simulator treating files from the host system as character devices. BUG=2129
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/70aed11c9545897976b046655ba45c8bbae07075 Adds FFT Auto Scale Op (#2134) 2023-07-19T00:08:47+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com Adds the FFT Auto Scale operation to the Signal library's FFT functionality. Tests added to the existing `fft_test.cc` and `fft_ops_test.py`. BUG=b/287346710
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/61bc56d574aac7a08f81b09653b29443a9223306 Fix numpy array compare for Numpy 1.25 (#2136) 2023-07-19T18:08:19+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com BUG=cl/549343727
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/ed11500ab9761c101a918a4aaf9f44c4ac7056b3 Delay OP python extension (#2138) 2023-07-19T22:52:06+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com Extends the Signal library Delay op to be usable from Python. Can be tested via `bazel run python/tflite_micro/signal:delay_op_test`. BUG=b/287346710
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/55037d2d5e82fa0786b7b51a4f436e57b45cfcbe Adds IRFFT Op to Signal Library (#2137) 2023-07-20T01:17:26+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com Adds inverse RFFT (IRFFT) to the Signal library ops. Tested via the existing FFT op tests. BUG=b/287346710
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/40a6936e9461fa8db157b5b3faf89ae9a0384554 Increase tolerance for fft_test (#2139) 2023-07-20T19:03:09+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com The Cortex-M55 build is failing fft_test.cc due to a strict tolerance in the IRFFT tests; this increases the tolerance slightly, as was done for RFFT. Can be tested locally with ```make -f tensorflow/lite/micro/tools/make/Makefile -j24 test_kernel_signal_fft_test TARGET=cortex_m_corstone_300 TARGET_ARCH=cortex-m55``` BUG=b/287518815
https://gitcode.net/xusiwei1236/tflite-micro/-/commit/bb3fda3a9c9aa8d048c48a63fa7cd9da41e25d91 Python extensions for energy, framer, overlap_add Signal OPs (#2142) 2023-07-24T17:21:48+00:00 suleshahid 110432064+suleshahid@users.noreply.github.com With this PR, these ops can be used directly from Python, including in TF graphs. Test with `bazel run python/tflite_micro/signal:framer_op_test`, etc. BUG=b/287346710
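The Delay, Energy, Framer, Overlap Add, and FFT Auto Scale commits above expose the Signal ops directly from Python. The sketch below is a minimal illustration of that usage (not taken from the PRs): it calls two of the new wrappers eagerly, assumes the bazel-built custom op shared libraries (delay_op.so, energy_op.so) can be loaded, and its shapes, values, and bin semantics are inferred from the wrapper and test sources shown later in this diff.

```python
# Minimal sketch: calling the newly exposed Signal ops eagerly from Python.
# Assumes this runs inside the tflite-micro bazel environment so that the
# custom op shared libraries load; values and shapes are illustrative only.
import numpy as np

from tflite_micro.python.tflite_micro.signal.ops import delay_op
from tflite_micro.python.tflite_micro.signal.ops import energy_op

# Delay: shift an int16 frame by 3 samples. Per delay_op_test.py, the first
# frame comes back zero-filled at the front and the tail is flushed on
# subsequent calls.
frame = np.arange(1, 11, dtype=np.int16)
delayed = delay_op.delay(frame, delay_length=3)

# Energy: rank-1 int16 input holding interleaved (real, imag) pairs; the
# tests compare per-bin energy over bins in [start_index, end_index).
spectrum = np.array([3, 4, 1, 2, 5, 12, 0, 0], dtype=np.int16)
energy = energy_op.energy(spectrum, start_index=0, end_index=3)

print(delayed.numpy())
print(energy.numpy())
```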
......@@ -30,7 +30,7 @@ jobs:
tflm-bot-token: ${{ secrets.TFLM_BOT_PACKAGE_READ_TOKEN }}
issue_on_error:
needs: [xtensa_postmerge]
needs: [riscv_postmerge,xtensa_postmerge]
if: ${{ always() && contains(needs.*.result, 'failure') &&
!contains(github.event.pull_request.labels.*.name, 'ci:run_full') }}
uses: ./.github/workflows/issue_on_error.yml
......
"""BUILD rules for generating flatbuffer files."""
load("@build_bazel_rules_android//android:rules.bzl", "android_library")
load("@rules_python//python:defs.bzl", "py_library")
flatc_path = "@flatbuffers//:flatc"
zip_files = "//tensorflow/lite/tools:zip_files"
......@@ -449,7 +450,7 @@ def flatbuffer_py_library(
":{}".format(all_srcs_no_include),
],
)
native.py_library(
py_library(
name = name,
srcs = [
":{}".format(concat_py_srcs),
......
......@@ -51,6 +51,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddEthosU();
AddExp();
AddExpandDims();
AddFftAutoScale();
AddFill();
AddFloor();
AddFloorDiv();
......@@ -63,6 +64,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddGreaterEqual();
AddHardSwish();
AddIf();
AddIrfft();
AddL2Normalization();
AddL2Pool2D();
AddLeakyRelu();
......
......@@ -15,7 +15,11 @@ cc_library(
name = "ops_lib",
visibility = [":signal_friends"],
deps = [
":delay_op_cc",
":energy_op_cc",
":fft_ops_cc",
":framer_op_cc",
":overlap_add_op_cc",
":window_op_cc",
],
)
......@@ -29,11 +33,64 @@ py_library(
srcs_version = "PY3",
visibility = ["//python/tflite_micro/signal/utils:__subpackages__"],
deps = [
":delay_op",
":energy_op",
":fft_ops",
":framer_op",
":overlap_add_op",
":window_op",
],
)
py_tflm_signal_library(
name = "delay_op",
srcs = ["ops/delay_op.py"],
cc_op_defs = ["//signal/tensorflow_core/ops:delay_op"],
cc_op_kernels = [
"//signal/tensorflow_core/kernels:delay_kernel",
],
)
py_test(
name = "delay_op_test",
size = "small",
srcs = ["ops/delay_op_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":delay_op",
"//python/tflite_micro/signal/utils:util",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
)
py_tflm_signal_library(
name = "energy_op",
srcs = ["ops/energy_op.py"],
cc_op_defs = ["//signal/tensorflow_core/ops:energy_op"],
cc_op_kernels = [
"//signal/tensorflow_core/kernels:energy_kernel",
],
)
py_test(
name = "energy_op_test",
size = "small",
srcs = ["ops/energy_op_test.py"],
data = [
"//python/tflite_micro/signal/ops/testdata:energy_test1.txt",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":energy_op",
"//python/tflite_micro/signal/utils:util",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
)
py_tflm_signal_library(
name = "fft_ops",
srcs = ["ops/fft_ops.py"],
......@@ -46,6 +103,10 @@ py_tflm_signal_library(
py_test(
name = "fft_ops_test",
srcs = ["ops/fft_ops_test.py"],
data = [
"//python/tflite_micro/signal/ops/testdata:fft_auto_scale_test1.txt",
"//python/tflite_micro/signal/ops/testdata:rfft_test1.txt",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
......@@ -56,6 +117,56 @@ py_test(
],
)
py_tflm_signal_library(
name = "framer_op",
srcs = ["ops/framer_op.py"],
cc_op_defs = ["//signal/tensorflow_core/ops:framer_op"],
cc_op_kernels = [
"//signal/tensorflow_core/kernels:framer_kernel",
],
)
py_test(
name = "framer_op_test",
size = "small",
srcs = ["ops/framer_op_test.py"],
data = [
"//python/tflite_micro/signal/ops/testdata:framer_test1.txt",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":framer_op",
"//python/tflite_micro/signal/utils:util",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
)
py_tflm_signal_library(
name = "overlap_add_op",
srcs = ["ops/overlap_add_op.py"],
cc_op_defs = ["//signal/tensorflow_core/ops:overlap_add_op"],
cc_op_kernels = [
"//signal/tensorflow_core/kernels:overlap_add_kernel",
],
)
py_test(
name = "overlap_add_op_test",
size = "small",
srcs = ["ops/overlap_add_op_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":overlap_add_op",
"//python/tflite_micro/signal/utils:util",
"@absl_py//absl/testing:parameterized",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
)
py_tflm_signal_library(
name = "window_op",
srcs = ["ops/window_op.py"],
......
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use overlap add op in python."""
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
gen_delay_op = util.load_custom_op('delay_op.so')
def _delay_wrapper(delay_fn, default_name):
"""Wrapper around gen_delay_op.delay*."""
def _delay(input_tensor, delay_length, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
return delay_fn(input_tensor, delay_length=delay_length, name=name)
return _delay
# TODO(b/286250473): change back name after name clash resolved
delay = _delay_wrapper(gen_delay_op.signal_delay, "signal_delay")
tf.no_gradient("signal_delay")
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for delay op."""
import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.ops import delay_op
from tflite_micro.python.tflite_micro.signal.utils import util
class DelayOpTest(tf.test.TestCase):
def TestHelper(self, input_signal, delay_length, frame_size):
inner_dim_size = input_signal.shape[-1]
input_signal_rank = len(input_signal.shape)
frame_num = int(np.ceil((inner_dim_size + delay_length) / frame_size))
# We need to continue feeding the op with zeros until the delay line is
# flushed. Pad the input signal to a multiple of frame_size.
padded_size = frame_num * frame_size
pad_size = int(padded_size - inner_dim_size)
# Axes to pass to np.pad. All axes have no padding except the innermost one.
pad_outer_axes = np.zeros([input_signal_rank - 1, 2], dtype=int)
pad_input_signal = np.vstack([pad_outer_axes, [0, pad_size]])
input_signal_padded = np.pad(input_signal, pad_input_signal)
delay_exp_signal = np.vstack(
[pad_outer_axes, [delay_length, pad_size - delay_length]])
delay_exp = np.pad(input_signal, delay_exp_signal)
delay_out = np.zeros(input_signal_padded.shape)
in_frame_shape = input_signal.shape[:-1] + (frame_size, )
func = tf.function(delay_op.delay)
concrete_function = func.get_concrete_function(tf.TensorSpec(
in_frame_shape, dtype=tf.int16),
delay_length=delay_length)
interpreter = util.get_tflm_interpreter(concrete_function, func)
for i in range(frame_num):
in_frame = input_signal_padded[..., i * frame_size:(i + 1) * frame_size]
# TFLM
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame_tflm = interpreter.get_output(0)
# TF
out_frame = self.evaluate(
delay_op.delay(in_frame, delay_length=delay_length))
delay_out[..., i * frame_size:(i + 1) * frame_size] = out_frame
self.assertAllEqual(out_frame, out_frame_tflm)
self.assertAllEqual(delay_out, delay_exp)
def testFrameLargerThanDelay(self):
self.TestHelper(np.arange(0, 30, dtype=np.int16), 7, 10)
def testFrameSmallerThanDelay(self):
self.TestHelper(np.arange(0, 70, dtype=np.int16), 21, 3)
def testZeroDelay(self):
self.TestHelper(np.arange(0, 20, dtype=np.int16), 0, 3)
def testNegativeDelay(self):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.TestHelper(np.arange(1, 20, dtype=np.int16), -21, 3)
def testMultiDimensionalDelay(self):
input_signal = np.reshape(np.arange(0, 120, dtype=np.int16), [2, 3, 20])
self.TestHelper(input_signal, 4, 6)
input_signal = np.reshape(np.arange(0, 72, dtype=np.int16),
[2, 2, 3, 3, 2])
self.TestHelper(input_signal, 7, 3)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use energy op in python."""
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
gen_energy_op = util.load_custom_op('energy_op.so')
def _energy_wrapper(energy_fn, default_name):
"""Wrapper around gen_energy_op.energy*."""
def _energy(input_tensor, start_index=0, end_index=-1, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
dim_list = input_tensor.shape.as_list()
if len(dim_list) != 1:
raise ValueError("Input tensor must have a rank of 1")
if end_index == -1:
end_index = dim_list[0] - 1
return energy_fn(input_tensor,
start_index=start_index,
end_index=end_index,
name=name)
return _energy
# TODO(b/286250473): change back name after name clash resolved
energy = _energy_wrapper(gen_energy_op.signal_energy, "signal_energy")
tf.no_gradient("signal_energy")
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for energy op."""
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import resource_loader
from tflite_micro.python.tflite_micro.signal.ops import energy_op
from tflite_micro.python.tflite_micro.signal.utils import util
class EnergyOpTest(tf.test.TestCase):
_PREFIX_PATH = resource_loader.get_path_to_datafile('')
def GetResource(self, filepath):
full_path = os.path.join(self._PREFIX_PATH, filepath)
with open(full_path, 'rt') as f:
file_text = f.read()
return file_text
def SingleEnergyTest(self, filename):
lines = self.GetResource(filename).splitlines()
args = lines[0].split()
start_index = int(args[0])
end_index = int(args[1])
func = tf.function(energy_op.energy)
input_size = len(lines[1].split())
concrete_function = func.get_concrete_function(tf.TensorSpec(
input_size, dtype=tf.int16),
start_index=start_index,
end_index=end_index)
interpreter = util.get_tflm_interpreter(concrete_function, func)
# Skip line 0, which contains the configuration params.
# Read lines in pairs <input, expected>
i = 1
while i < len(lines):
in_frame = np.array([int(j) for j in lines[i].split()], dtype='int16')
out_frame_exp = [int(j) for j in lines[i + 1].split()]
# TFLM
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
for j in range(start_index, end_index):
self.assertEqual(out_frame_exp[j], out_frame[j])
# TF
out_frame = self.evaluate(
energy_op.energy(in_frame,
start_index=start_index,
end_index=end_index))
for j in range(start_index, end_index):
self.assertEqual(out_frame_exp[j], out_frame[j])
i += 2
def testSingleFrame(self):
start_index = 5
end_index = 250
energy_in = [
-56, 0, 26, 49, 144, -183, -621, 16, 544, 605, 11, -581, -26, 245,
-210, -273, 200, 541, 268, -319, -43, -544, -747, 356, 415, 356, 174,
-133, 4, -278, -487, 104, 449, 560, 223, -691, -451, 130, 132, 202, 86,
-91, 170, -85, -454, -123, 330, 125, -434, 104, 422, 89, -14, -113,
-123, -63, 125, 142, 40, -218, -183, -10, 3, 154, 95, -64, -108, -55,
55, 216, 47, -358, -297, 391, 437, 5, -59, -252, -102, -25, -60, 76,
-46, 6, 128, 113, -4, -101, 20, -75, -154, 88, 144, -50, -163, 58, 112,
38, 31, 2, -38, -80, 77, 63, -136, -83, 83, 89, 32, 27, 6, -237, -247,
250, 292, -13, -55, 4, 58, -182, -120, 63, -33, -40, -88, 152, 246, 41,
-99, -178, -11, 68, -10, 3, 14, 39, 30, -94, -29, 79, -6, -84, -65, 55,
138, 71, -141, -151, 150, 149, -159, -106, 203, 55, -207, -153, -37,
231, 187, -6, 54, -66, -85, -258, -244, 271, 157, 24, 117, 144, 144,
-202, -66, -320, -478, 340, 510, 46, -152, -185, -199, -19, 139, 282,
-15, -140, 129, 45, -124, -26, 145, -36, -79, -17, -85, -29, 104, 82,
-84, -7, 127, -96, -210, 60, 114, 67, 40, -3, -1, -101, -76, 77, 55,
-5, 19, 13, 13, -36, -40, -34, 20, 63, 7, -66, -44, -6, -22, 66, 40,
-20, 13, 21, -15, -45, 6, 38, 19, -40, -46, -3, 2, 41, 41, -17, -37,
-11, 15, 13, -4, -5, 0, 1, 2, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 1,
1, -1, -2, -1, 0, -1, 0, 0, -1, 0, -1, 0, 0, 1, 0, -1, 0, 1, -1, -1, 0,
0, 0, -1, 0, -1, 0, 1, 0, 0, -1, -1, 1, -1, -1, 0, 0, 0, 0, -2, -1, -1,
0, 0, -1, -1, 0, 0, -1, -1, -1, 1, 0, -1, 0, 0, 0, -1, 0, 0, 0, 1, -1,
-1, 0, 1, 0, -1, -1, 0, -2, 0, 0, 0, 0, -1, -3, 1, 2, 0, 0, 1, 2, -1,
-1, -1, -1, -1, -1, 0, 0, 1, 0, -1, -1, 1, 0, 0, 1, -1, 0, 0, 0, 0, 0,
-1, 0, -1, 0, 0, -1, -1, 0, 0, 0, 0, -1, -1, 0, 1, -1, -1, 0, -2, 0, 1,
0, 0, 0, 0, 1, -1, -1, 1, 0, -1, 0, 0, -1, 0, 2, 1, -2, -1, 1, 0, 0,
-2, 0, 0, -1, -1, 0, 0, 0, 0, -1, -2, -1, 1, 1, 0, 0, 0, 0, -1, -1, 0,
1, 1, 0, 0, 0, 0, -1, 0, 0, 0, 0, -1, 0, -1, 0, -1, -1, 0, 0, 0, 1, -1,
0, 1, 0, -1, 0, -1, 0, 0, 0, 0, 0, -1, -2, 0, 1, 0, 1, 1, -1, -1, -1,
0, 0, 0, 0, -1, -1, -1, 0, 1, 0, -2, -1, 1, 1, 1, -1, -3, -1, 1, -1,
-2, 0, -1, -2, -1, 0, 0, 0, -3, -1, 0, 0, -1, 0, 0, -2, 0
]
energy_exp = [
0, 0, 0, 0, 0, 337682, 60701, 118629, 332681, 173585, 297785, 684745,
298961, 47965, 77300, 247985, 515201, 527210, 220301, 58228, 15677,
36125, 221245, 124525, 199172, 186005, 12965, 19098, 35789, 49124,
33589, 23725, 13121, 14689, 49681, 130373, 241090, 190994, 66985,
11029, 9376, 2152, 29153, 10217, 6025, 31460, 23236, 29933, 13988, 965,
7844, 9898, 25385, 14810, 1753, 56205, 123509, 85433, 3041, 36488,
18369, 2689, 30848, 62197, 41485, 4745, 109, 1717, 9736, 7082, 7092,
7250, 24085, 42682, 44701, 36517, 44234, 66258, 54730, 35005, 7272,
73789, 132977, 25225, 34425, 61540, 106756, 344084, 262216, 57329,
39962, 98845, 19825, 18666, 16052, 22321, 6530, 8066, 17540, 7105,
25345, 47700, 17485, 1609, 10202, 11705, 3050, 530, 1465, 2756, 4369,
4405, 1972, 4840, 2000, 610, 2250, 1480, 1961, 2125, 1685, 1970, 1490,
394, 41, 1, 4, 0, 0, 1, 0, 0, 1, 2, 5, 1, 0, 1, 1, 1, 1, 1, 2, 0, 1, 1,
1, 0, 2, 2, 1, 0, 4, 2, 0, 2, 0, 2, 2, 1, 0, 1, 0, 1, 2, 1, 1, 1, 4, 0,
1, 10, 4, 1, 5, 2, 2, 1, 1, 1, 2, 0, 2, 0, 0, 1, 1, 0, 2, 0, 0, 2, 1,
2, 4, 1, 0, 0, 2, 2, 1, 0, 1, 5, 5, 1, 4, 0, 2, 0, 0, 5, 2, 1, 0, 1, 1,
2, 0, 0, 1, 0, 1, 1, 1, 1, 0, 2, 1, 1, 1, 0, 0, 1, 4, 1, 2, 2, 1, 0, 1,
2, 1, 4, 2, 2, 10, 2, 5, 1, 0, 0, 0, 0, 0, 0, 0
]
energy_out = energy_op.energy(energy_in,
start_index=start_index,
end_index=end_index)
for j in range(start_index, end_index):
self.assertEqual(energy_exp[j], energy_out[j])
def testEnergy(self):
self.SingleEnergyTest('testdata/energy_test1.txt')
if __name__ == '__main__':
tf.test.main()
......@@ -65,5 +65,24 @@ def _fft_wrapper(fft_fn, default_name):
return _fft
def _fft_auto_scale_wrapper(fft_auto_scale_fn, default_name):
"""Wrapper around gen_fft_ops.fft_auto_scale*."""
def _fft_auto_scale(input_tensor, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
dim_list = input_tensor.shape.as_list()
if len(dim_list) != 1:
raise ValueError("Input tensor must have a rank of 1")
return fft_auto_scale_fn(input_tensor, name=name)
return _fft_auto_scale
rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
irfft = _fft_wrapper(gen_fft_ops.signal_irfft, "signal_irfft")
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
"signal_fft_auto_scale")
tf.no_gradient("signal_rfft")
tf.no_gradient("signal_irfft")
tf.no_gradient("signal_fft_auto_scale")
......@@ -33,6 +33,31 @@ class RfftOpTest(tf.test.TestCase):
file_text = f.read()
return file_text
def SingleFftAutoScaleTest(self, filename):
lines = self.GetResource(filename).splitlines()
func = tf.function(fft_ops.fft_auto_scale)
input_size = len(lines[0].split())
concrete_function = func.get_concrete_function(
tf.TensorSpec(input_size, dtype=tf.int16))
interpreter = util.get_tflm_interpreter(concrete_function, func)
i = 0
while i < len(lines):
in_frame = np.array([int(j) for j in lines[i].split()], dtype=np.int16)
out_frame_exp = [int(j) for j in lines[i + 1].split()]
scale_exp = [int(j) for j in lines[i + 2].split()]
# TFLM
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
scale = interpreter.get_output(1)
self.assertAllEqual(out_frame_exp, out_frame)
self.assertEqual(scale_exp, scale)
# TF
out_frame, scale = self.evaluate(fft_ops.fft_auto_scale(in_frame))
self.assertAllEqual(out_frame_exp, out_frame)
self.assertEqual(scale_exp, scale)
i += 3
def SingleRfftTest(self, filename):
lines = self.GetResource(filename).splitlines()
args = lines[0].split()
......@@ -43,8 +68,6 @@ class RfftOpTest(tf.test.TestCase):
tf.TensorSpec(input_size, dtype=tf.int16), fft_length)
# TODO(b/286252893): make test more robust (vs scipy)
interpreter = util.get_tflm_interpreter(concrete_function, func)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Skip line 0, which contains the configuration params.
# Read lines in pairs <input, expected>
i = 1
......@@ -53,9 +76,9 @@ class RfftOpTest(tf.test.TestCase):
out_frame_exp = [int(j) for j in lines[i + 1].split()]
# Compare TFLM inference against the expected golden values
# TODO(b/286252893): validate usage of testing vs interpreter here
interpreter.set_tensor(input_details[0]['index'], in_frame)
interpreter.set_input(in_frame, 0)
interpreter.invoke()
out_frame = interpreter.get_tensor(output_details[0]['index'])
out_frame = interpreter.get_output(0)
self.assertAllEqual(out_frame_exp, out_frame)
# TF
out_frame = self.evaluate(fft_ops.rfft(in_frame, fft_length))
......@@ -83,11 +106,9 @@ class RfftOpTest(tf.test.TestCase):
concrete_function = func.get_concrete_function(
tf.TensorSpec(np.shape(in_frames), dtype=tf.int16), fft_length)
interpreter = util.get_tflm_interpreter(concrete_function, func)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], in_frames)
interpreter.set_input(in_frames, 0)
interpreter.invoke()
out_frame = interpreter.get_tensor(output_details[0]['index'])
out_frame = interpreter.get_output(0)
self.assertAllEqual(out_frames_exp, out_frame)
# TF
out_frames = self.evaluate(fft_ops.rfft(in_frames, fft_length))
......@@ -204,6 +225,12 @@ class RfftOpTest(tf.test.TestCase):
delta=1)
fft_length = 2 * fft_length
def testRfft(self):
self.SingleRfftTest('testdata/rfft_test1.txt')
def testRfftLargeOuterDimension(self):
self.MultiDimRfftTest('testdata/rfft_test1.txt')
def testFftTooLarge(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_input = np.zeros(round(fft_ops._MAX_FFT_LENGTH * 2), dtype=dtype)
......@@ -224,6 +251,64 @@ class RfftOpTest(tf.test.TestCase):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(fft_ops.rfft(fft_input, 127))
def testIrfftTest(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random(fft_length).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, fft_length).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[0], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[0], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 6500)
elif dtype == np.int32:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 7875)
else:
self.assertArrayNear(fft_input, ifft_output, 5e-7)
fft_length = 2 * fft_length
def testIrfftLargeOuterDimension(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random([2, 5, fft_length
]).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, [2, 5, fft_length]).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[-1], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[-1], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
elif dtype == np.int32:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
else:
self.assertAllClose(fft_input, ifft_output, rtol=5e-7, atol=5e-7)
fft_length = 2 * fft_length
def testAutoScale(self):
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')
def testPow2FftLengthTest(self):
fft_length, fft_bits = fft_ops.get_pow2_fft_length(131)
self.assertEqual(fft_length, 256)
......
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use framer op in python."""
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
gen_framer_op = util.load_custom_op('framer_op.so')
def _framer_wrapper(framer_fn, default_name):
"""Wrapper around gen_framer_op.framer*."""
def _framer(input_tensor,
frame_size,
frame_step,
prefill=False,
name=default_name):
if frame_step > frame_size:
raise ValueError("frame_step must not be greater than frame_size.")
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
dim_list = input_tensor.shape.as_list()
if dim_list[-1] % frame_step != 0:
raise ValueError(
"Innermost input dimenion size must be a multiple of %d elements" %
frame_step)
return framer_fn(input_tensor,
frame_size=frame_size,
frame_step=frame_step,
prefill=prefill,
name=name)
return _framer
# TODO(b/286250473): change back name after name clash resolved
framer = _framer_wrapper(gen_framer_op.signal_framer, "signal_framer")
tf.no_gradient("signal_framer")
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for framer op."""
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import resource_loader
from tflite_micro.python.tflite_micro.signal.ops import framer_op
from tflite_micro.python.tflite_micro.signal.utils import util
class FramerOpTest(tf.test.TestCase):
_PREFIX_PATH = resource_loader.get_path_to_datafile('')
def GetResource(self, filepath):
full_path = os.path.join(self._PREFIX_PATH, filepath)
with open(full_path, 'rt') as f:
file_text = f.read()
return file_text
def SingleFramerTest(self, filename):
lines = self.GetResource(filename).splitlines()
args = lines[0].split()
frame_size = int(args[0])
frame_step = int(args[1])
prefill = bool(int(args[2]))
func = tf.function(framer_op.framer)
input_size = len(lines[1].split())
concrete_function = func.get_concrete_function(
tf.TensorSpec(input_size, dtype=tf.int16), frame_size, frame_step,
prefill)
interpreter = util.get_tflm_interpreter(concrete_function, func)
# Skip line 0, which contains the configuration params.
# Read lines in triplets <input, expected output, expected valid>
i = 1
while i < len(lines):
in_block = np.array([int(j) for j in lines[i].split()], dtype=np.int16)
out_frame_exp = [[int(j) for j in lines[i + 1].split()]]
out_valid_exp = [int(j) for j in lines[i + 2].split()]
# TFLM
interpreter.set_input(in_block, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
out_valid = interpreter.get_output(1)
self.assertEqual(out_valid, out_valid_exp)
if out_valid:
self.assertAllEqual(out_frame, out_frame_exp)
# TF
out_frame, out_valid = self.evaluate(
framer_op.framer(in_block, frame_size, frame_step, prefill))
self.assertEqual(out_valid, out_valid_exp)
if out_valid:
self.assertAllEqual(out_frame, out_frame_exp)
i += 3
def MultiFrameRandomInputFramerTest(self, n_frames):
# Terminology: input is in blocks, output is in frames
frame_step = 160
frame_size = 400
prefill = True
block_num = 10
block_size = frame_step * n_frames
test_input = np.random.randint(np.iinfo('int16').min,
np.iinfo('int16').max,
block_size * block_num,
dtype=np.int16)
expected_output = np.concatenate((np.zeros(frame_size - frame_step,
dtype=np.int16), test_input))
func = tf.function(framer_op.framer)
concrete_function = func.get_concrete_function(
tf.TensorSpec(block_size, dtype=tf.int16), frame_size, frame_step,
prefill)
interpreter = util.get_tflm_interpreter(concrete_function, func)
block_index = 0
frame_index = 0
while block_index < block_num:
in_block = test_input[(block_index * block_size):((block_index + 1) *
block_size)]
expected_valid = 1
expected_frame = [
expected_output[((frame_index + i) *
frame_step):((frame_index + i) * frame_step +
frame_size)] for i in range(n_frames)
]
# TFLM
interpreter.set_input(in_block, 0)
interpreter.invoke()
out_frame = interpreter.get_output(0)
out_valid = interpreter.get_output(1)
self.assertEqual(out_valid, expected_valid)
if out_valid:
self.assertAllEqual(out_frame, expected_frame)
# TF
out_frame, out_valid = self.evaluate(
framer_op.framer(in_block, frame_size, frame_step, prefill))
frame_index += n_frames
self.assertEqual(out_valid, expected_valid)
self.assertAllEqual(out_frame, expected_frame)
block_index += 1
def testFramerVectors(self):
self.SingleFramerTest('testdata/framer_test1.txt')
def testFramerRandomInput(self):
self.MultiFrameRandomInputFramerTest(1)
def testFramerRandomInputNframes2(self):
self.MultiFrameRandomInputFramerTest(2)
def testFramerRandomInputNframes4(self):
self.MultiFrameRandomInputFramerTest(4)
def testStepSizeTooLarge(self):
framer_input = np.zeros(160, dtype=np.int16)
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(framer_op.framer(framer_input, 128, 129))
def testStepSizeNotEqualInputSize(self):
framer_input = np.zeros(122, dtype=np.int16)
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(framer_op.framer(framer_input, 321, 123))
if __name__ == '__main__':
np.random.seed(0)
tf.test.main()
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Use overlap add op in python."""
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.utils import util
gen_overlap_add_op = util.load_custom_op('overlap_add_op.so')
def _overlap_add_wrapper(overlap_add_fn, default_name):
"""Wrapper around gen_overlap_add_op.overlap_add*."""
def _overlap_add(input_tensor, frame_step, name=default_name):
with tf.name_scope(name) as name:
input_tensor = tf.convert_to_tensor(input_tensor)
dim_list = input_tensor.shape.as_list()
if frame_step > dim_list[-1]:
raise ValueError(
"Frame_step must not exceed innermost input dimension")
return overlap_add_fn(input_tensor, frame_step=frame_step, name=name)
return _overlap_add
# TODO(b/286250473): change back name after name clash resolved
overlap_add = _overlap_add_wrapper(gen_overlap_add_op.signal_overlap_add,
"signal_overlap_add")
tf.no_gradient("signal_overlap_add")
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for overlap add op."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tflite_micro.python.tflite_micro.signal.ops import overlap_add_op
from tflite_micro.python.tflite_micro.signal.utils import util
class OverlapAddOpTest(parameterized.TestCase, tf.test.TestCase):
def RunOverlapAdd(self, interpreter, input_frames, frame_step,
expected_output_frames, dtype):
input_frames = tf.convert_to_tensor(input_frames, dtype=dtype)
# TFLM
interpreter.set_input(input_frames, 0)
interpreter.invoke()
output_frame = interpreter.get_output(0)
self.assertAllEqual(output_frame, expected_output_frames)
# TF
output_frame = self.evaluate(
overlap_add_op.overlap_add(input_frames, frame_step))
self.assertAllEqual(output_frame, expected_output_frames)
@parameterized.named_parameters(('_FLOAT32InputOutput', tf.float32),
('_INT16InputOutput', tf.int16))
def testOverlapAddValidInput(self, dtype):
input_frames = np.array([[1, -5, 4, 2, 7], [4, 15, -44, 27, -16],
[66, -19, 79, 8, -12], [-122, 17, 65, 18, -101],
[3, 33, -66, -19, 55]])
expected_output_frames_step_1 = np.array([[1], [-1], [85], [-183], [133]])
expected_output_frames_step_2 = np.array([[1, -5], [8, 17], [29, 8],
[-59, 25], [56, 51]])
expected_output_frames_step_3 = np.array([[1, -5, 4], [6, 22, -44],
[93, -35, 79], [-114, 5, 65],
[21, -68, -66]])
expected_output_frames_step_4 = np.array([[1, -5, 4, 2], [11, 15, -44, 27],
[50, -19, 79, 8],
[-134, 17, 65, 18],
[-98, 33, -66, -19]])
expected_output_frames_step_5 = np.array([[1, -5, 4, 2, 7],
[4, 15, -44, 27, -16],
[66, -19, 79, 8, -12],
[-122, 17, 65, 18, -101],
[3, 33, -66, -19, 55]])
func = tf.function(overlap_add_op.overlap_add)
# Initialize an interpreter for each step size
# TODO(b/263020764): use a parameterized test instead
interpreters = [None] * 6
for i in range(5):
interpreters[i] = util.get_tflm_interpreter(
func.get_concrete_function(
tf.TensorSpec(np.shape([input_frames[0]]), dtype=dtype), i + 1),
func)
frame_num = input_frames.shape[0]
frame_index = 0
while frame_index < frame_num:
self.RunOverlapAdd(interpreters[0], [input_frames[frame_index]],
1,
expected_output_frames_step_1[frame_index],
dtype=dtype)
self.RunOverlapAdd(interpreters[1], [input_frames[frame_index]],
2,
expected_output_frames_step_2[frame_index],
dtype=dtype)
self.RunOverlapAdd(interpreters[2], [input_frames[frame_index]],
3,
expected_output_frames_step_3[frame_index],
dtype=dtype)
self.RunOverlapAdd(interpreters[3], [input_frames[frame_index]],
4,
expected_output_frames_step_4[frame_index],
dtype=dtype)
self.RunOverlapAdd(interpreters[4], [input_frames[frame_index]],
5,
expected_output_frames_step_5[frame_index],
dtype=dtype)
frame_index += 1
@parameterized.named_parameters(('_FLOAT32InputOutput', tf.float32),
('_INT16InputOutput', tf.int16))
def testOverlapAddNframes5(self, dtype):
input_frames = np.array([[1, -5, 4, 2, 7], [4, 15, -44, 27, -16],
[66, -19, 79, 8, -12], [-122, 17, 65, 18, -101],
[3, 33, -66, -19, 55]])
expected_output_frames_step_1 = np.array([1, -1, 85, -183, 133])
expected_output_frames_step_2 = np.array(
[1, -5, 8, 17, 29, 8, -59, 25, 56, 51])
expected_output_frames_step_3 = np.array(
[1, -5, 4, 6, 22, -44, 93, -35, 79, -114, 5, 65, 21, -68, -66])
expected_output_frames_step_4 = np.array([
1, -5, 4, 2, 11, 15, -44, 27, 50, -19, 79, 8, -134, 17, 65, 18, -98,
33, -66, -19
])
expected_output_frames_step_5 = np.array([
1, -5, 4, 2, 7, 4, 15, -44, 27, -16, 66, -19, 79, 8, -12, -122, 17, 65,
18, -101, 3, 33, -66, -19, 55
])
func = tf.function(overlap_add_op.overlap_add)
# Initialize an interpreter for each step size
# TODO(b/263020764): use a parameterized test instead
interpreters = [None] * 6
for i in range(5):
interpreters[i] = util.get_tflm_interpreter(
func.get_concrete_function(
tf.TensorSpec(np.shape(input_frames), dtype=dtype), i + 1), func)
self.RunOverlapAdd(interpreters[0],
input_frames,
1,
expected_output_frames_step_1,
dtype=dtype)
self.RunOverlapAdd(interpreters[1],
input_frames,
2,
expected_output_frames_step_2,
dtype=dtype)
self.RunOverlapAdd(interpreters[2],
input_frames,
3,
expected_output_frames_step_3,
dtype=dtype)
self.RunOverlapAdd(interpreters[3],
input_frames,
4,
expected_output_frames_step_4,
dtype=dtype)
self.RunOverlapAdd(interpreters[4],
input_frames,
5,
expected_output_frames_step_5,
dtype=dtype)
@parameterized.named_parameters(('_FLOAT32InputOutput', tf.float32),
('_INT16InputOutput', tf.int16))
def testOverlapAddNframes5Channels2(self, dtype):
input_frames = np.array([[[1, -5, 4, 2, 7], [4, 15, -44, 27, -16],
[66, -19, 79, 8, -12], [-122, 17, 65, 18, -101],
[3, 33, -66, -19, 55]],
[[1, -5, 4, 2, 7], [4, 15, -44, 27, -16],
[66, -19, 79, 8, -12], [-122, 17, 65, 18, -101],
[3, 33, -66, -19, 55]]])
expected_output_frames_step_1 = np.array([[1, -1, 85, -183, 133],
[1, -1, 85, -183, 133]])
expected_output_frames_step_2 = np.array(
[[1, -5, 8, 17, 29, 8, -59, 25, 56, 51],
[1, -5, 8, 17, 29, 8, -59, 25, 56, 51]])
expected_output_frames_step_3 = np.array(
[[1, -5, 4, 6, 22, -44, 93, -35, 79, -114, 5, 65, 21, -68, -66],
[1, -5, 4, 6, 22, -44, 93, -35, 79, -114, 5, 65, 21, -68, -66]])
expected_output_frames_step_4 = np.array([[
1, -5, 4, 2, 11, 15, -44, 27, 50, -19, 79, 8, -134, 17, 65, 18, -98,
33, -66, -19
],
[
1, -5, 4, 2, 11, 15, -44, 27,
50, -19, 79, 8, -134, 17, 65,
18, -98, 33, -66, -19
]])
expected_output_frames_step_5 = np.array([[
1, -5, 4, 2, 7, 4, 15, -44, 27, -16, 66, -19, 79, 8, -12, -122, 17, 65,
18, -101, 3, 33, -66, -19, 55
],
[
1, -5, 4, 2, 7, 4, 15, -44,
27, -16, 66, -19, 79, 8, -12,
-122, 17, 65, 18, -101, 3,
33, -66, -19, 55
]])
func = tf.function(overlap_add_op.overlap_add)
# Initialize an interpreter for each step size
# TODO(b/263020764): use a parameterized test instead
interpreters = [None] * 6
for i in range(5):
interpreters[i] = util.get_tflm_interpreter(
func.get_concrete_function(
tf.TensorSpec(np.shape(input_frames), dtype=dtype), i + 1), func)
self.RunOverlapAdd(interpreters[0],
input_frames,
1,
expected_output_frames_step_1,
dtype=dtype)
self.RunOverlapAdd(interpreters[1],
input_frames,
2,
expected_output_frames_step_2,
dtype=dtype)
self.RunOverlapAdd(interpreters[2],
input_frames,
3,
expected_output_frames_step_3,
dtype=dtype)
self.RunOverlapAdd(interpreters[3],
input_frames,
4,
expected_output_frames_step_4,
dtype=dtype)
self.RunOverlapAdd(interpreters[4],
input_frames,
5,
expected_output_frames_step_5,
dtype=dtype)
def testStepSizeTooLarge(self):
overlap_add_input = np.zeros(160, dtype=np.int16)
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(overlap_add_op.overlap_add(overlap_add_input, 128, 129))
def testStepSizeNotEqualOutputSize(self):
overlap_add_input = np.zeros(122, dtype=np.int16)
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(overlap_add_op.overlap_add(overlap_add_input, 321, 123))
if __name__ == '__main__':
tf.test.main()
......@@ -7,6 +7,9 @@ package(
)
exports_files([
"energy_test1.txt",
"fft_auto_scale_test1.txt",
"framer_test1.txt",
"rfft_test1.txt",
"window_test1.txt",
])
(The source diff for this file is too large to display; view the blob instead.)
"""Build rule for wrapping a custom TF OP from .cc to python."""
load("@rules_python//python:defs.bzl", "py_library")
# TODO(b/286890280): refactor to be more generic build target for any custom OP
def py_tflm_signal_library(
name,
......@@ -61,7 +63,7 @@ def py_tflm_signal_library(
] + select({"//conditions:default": []}),
)
native.py_library(
py_library(
name = name,
srcs = srcs,
srcs_version = "PY2AND3",
......
......@@ -10,13 +10,20 @@ cc_library(
srcs = [
"delay.cc",
"energy.cc",
"fft_auto_scale.cc",
"filter_bank.cc",
"filter_bank_log.cc",
"filter_bank_spectral_subtraction.cc",
"filter_bank_square_root.cc",
"framer.cc",
"irfft.cc",
"overlap_add.cc",
"rfft.cc",
"stacker.cc",
"window.cc",
],
hdrs = [
"irfft.h",
"rfft.h",
],
copts = micro_copts(),
......@@ -26,6 +33,12 @@ cc_library(
deps = [
"//signal/src:circular_buffer",
"//signal/src:energy",
"//signal/src:fft_auto_scale",
"//signal/src:filter_bank",
"//signal/src:filter_bank_log",
"//signal/src:filter_bank_spectral_subtraction",
"//signal/src:filter_bank_square_root",
"//signal/src:irfft",
"//signal/src:overlap_add",
"//signal/src:rfft",
"//signal/src:window",
......@@ -222,3 +235,92 @@ cc_test(
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_library(
name = "filter_bank_flexbuffers_generated_data",
srcs = [
"filter_bank_flexbuffers_generated_data.cc",
],
hdrs = [
"filter_bank_flexbuffers_generated_data.h",
],
)
cc_library(
name = "filter_bank_log_flexbuffers_generated_data",
srcs = [
"filter_bank_log_flexbuffers_generated_data.cc",
],
hdrs = [
"filter_bank_log_flexbuffers_generated_data.h",
],
)
cc_library(
name = "filter_bank_spectral_subtraction_flexbuffers_generated_data",
srcs = [
"filter_bank_spectral_subtraction_flexbuffers_generated_data.cc",
],
hdrs = [
"filter_bank_spectral_subtraction_flexbuffers_generated_data.h",
],
)
cc_test(
name = "filter_bank_test",
srcs = [
"filter_bank_test.cc",
],
deps = [
":filter_bank_flexbuffers_generated_data",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:kernel_runner",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "filter_bank_log_test",
srcs = [
"filter_bank_log_test.cc",
],
deps = [
":filter_bank_log_flexbuffers_generated_data",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:kernel_runner",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "filter_bank_spectral_subtraction_test",
srcs = [
"filter_bank_spectral_subtraction_test.cc",
],
deps = [
":filter_bank_spectral_subtraction_flexbuffers_generated_data",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:kernel_runner",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "filter_bank_square_root_test",
srcs = [
"filter_bank_square_root_test.cc",
],
deps = [
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/kernels:kernel_runner",
"//tensorflow/lite/micro/testing:micro_test",
],
)
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kScaleBitTensor = 1;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteTensor* scale_bit =
micro_context->AllocateTempOutputTensor(node, kScaleBitTensor);
TF_LITE_ENSURE(context, scale_bit != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(scale_bit), 0);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
TF_LITE_ENSURE_TYPES_EQ(context, scale_bit->type, kTfLiteInt32);
micro_context->DeallocateTempTfLiteTensor(scale_bit);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* scale_bit =
tflite::micro::GetEvalOutput(context, node, kScaleBitTensor);
const int16_t* input_data = tflite::micro::GetTensorData<int16_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
int32_t* scale_bit_data = tflite::micro::GetTensorData<int32_t>(scale_bit);
*scale_bit_data =
tflm_signal::FftAutoScale(input_data, output->dims->data[0], output_data);
return kTfLiteOk;
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_FFT_AUTO_SCALE() {
static TFLMRegistration r = tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
......@@ -85,6 +85,43 @@ TfLiteStatus TestFFT(int* input_dims_data, const T* input_data,
return kTfLiteOk;
}
TfLiteStatus TestFFTAutoScale(int* input_dims_data, const int16_t* input_data,
int* output_dims_data, const int16_t* golden,
int* scale_bit_dims_data,
const int32_t scale_bit_golden,
const TFLMRegistration registration,
const uint8_t* flexbuffers_data,
const int flexbuffers_data_len,
int16_t* output_data, int32_t* scale_bit) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
TfLiteIntArray* scale_bit_dims = IntArrayFromInts(scale_bit_dims_data);
constexpr int kInputsSize = 1;
constexpr int kOutputsSize = 2;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
CreateTensor(scale_bit, scale_bit_dims),
};
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TF_LITE_ENSURE_STATUS(ValidateFFTGoldens<int16_t>(
tensors, kTensorsSize, inputs_array, outputs_array, output_len, golden,
registration, flexbuffers_data, flexbuffers_data_len, output_data, 0));
TF_LITE_MICRO_EXPECT_EQ(scale_bit_golden, *scale_bit);
return kTfLiteOk;
}
} // namespace
} // namespace testing
......@@ -266,4 +303,219 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Float) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const float input[] = {256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const float golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_float,
g_gen_data_size_fft_length_64_int16, output, 1e-7));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int16) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int16_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int16_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int16,
g_gen_data_size_fft_length_64_int16, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int32) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int32_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int32OuterDims4) {
constexpr int kOutputLen = 64;
constexpr int kOuterDim = 2;
int input_shape[] = {3, kOuterDim, kOuterDim, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
const int32_t golden[] = {
256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOuterDim * kOuterDim * kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Float) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, tflite::kIrfftFloatLength512Input,
output_shape, tflite::kIrfftFloatLength512Golden,
*registration, g_gen_data_fft_length_512_float,
g_gen_data_size_fft_length_512_float, output, 1e-6));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Int16) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int16_t>(
input_shape, tflite::kIrfftInt16Length512Input,
output_shape, tflite::kIrfftInt16Length512Golden,
*registration, g_gen_data_fft_length_512_int16,
g_gen_data_size_fft_length_512_int16, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Int32) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int32_t>(
input_shape, tflite::kIrfftInt32Length512Input,
output_shape, tflite::kIrfftInt32Length512Golden,
*registration, g_gen_data_fft_length_512_int32,
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
const int16_t input[] = {0x0000, 0x1111, 0x2222, 0x3333,
0x3333, 0x2222, 0x1111, 0x0000};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
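  // The largest input magnitude is 0x3333; one left shift keeps it within
  // int16 range while a second would overflow, so the expected scale_bit is 1
  // and every golden sample is the corresponding input doubled.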
const int16_t golden[] = {0x0000, 0x2222, 0x4444, 0x6666,
0x6666, 0x4444, 0x2222, 0x0000};
const int32_t scale_bit_golden = 1;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, input, shape, golden, scale_bit_shape, scale_bit_golden,
*registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestScaleBit) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
const int16_t input[] = {238, 113, -88, -243, -5, -130, 159, -70};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
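  // The peak magnitude here is 243, which can be shifted left 7 times (x128)
  // without exceeding 32767, so the expected scale_bit is 7 and each golden
  // sample is the corresponding input times 128.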
const int16_t golden[] = {30464, 14464, -11264, -31104,
-640, -16640, 20352, -8960};
const int32_t scale_bit_golden = 7;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, input, shape, golden, scale_bit_shape, scale_bit_golden,
*registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestLarge) {
constexpr int kTensorsSize = 400;
int shape[] = {1, kTensorsSize};
int16_t output[kTensorsSize];
int scale_bit_shape[] = {0};
int32_t scale_bit;
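  // A scale_bit of 0 means the 512-sample input already uses enough of the
  // int16 range that any left shift would overflow, so the data is expected to
  // pass through unchanged.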
const int32_t scale_bit_golden = 0;
const TFLMRegistration* registration =
tflite::tflm_signal::Register_FFT_AUTO_SCALE();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFFTAutoScale(
shape, tflite::kFftAutoScaleLength512Input, shape,
tflite::kFftAutoScaleLength512Golden, scale_bit_shape,
scale_bit_golden, *registration, nullptr, 0, output, &scale_bit));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank.h"
#include <stdint.h>
#include <string.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kWeightTensor = 1;
constexpr int kUnweightTensor = 2;
constexpr int kChFreqStartsTensor = 3;
constexpr int kChWeightStartsTensor = 4;
constexpr int kChannelWidthsTensor = 5;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kNumChannelsIndex = 0; // 'num_channels'
struct TFLMSignalFilterBankParams {
tflm_signal::FilterbankConfig config;
uint64_t* work_area;
};
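// Init reads the single 'num_channels' entry from the flexbuffer and allocates
// a persistent work area of num_channels + 1 uint64 accumulators; the extra
// leading slot is scratch space that Eval discards when copying to the output.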
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalFilterBankParams*>(
context->AllocatePersistentBuffer(context,
sizeof(TFLMSignalFilterBankParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->config.num_channels = fbw.ElementAsInt32(kNumChannelsIndex);
params->work_area = static_cast<uint64_t*>(context->AllocatePersistentBuffer(
context, (params->config.num_channels + 1) * sizeof(uint64_t)));
if (params->work_area == nullptr) {
return nullptr;
}
return params;
}
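// Prepare validates the six 1-D inputs (the uint32 spectrum plus five int16
// filter-bank tables) and the single 1-D uint64 output.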
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 6);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kWeightTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kUnweightTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChFreqStartsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChWeightStartsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
input = micro_context->AllocateTempInputTensor(node, kChannelWidthsTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt64);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalFilterBankParams*>(node->user_data);
const TfLiteEvalTensor* input0 =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kWeightTensor);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kUnweightTensor);
const TfLiteEvalTensor* input3 =
tflite::micro::GetEvalInput(context, node, kChFreqStartsTensor);
const TfLiteEvalTensor* input4 =
tflite::micro::GetEvalInput(context, node, kChWeightStartsTensor);
const TfLiteEvalTensor* input5 =
tflite::micro::GetEvalInput(context, node, kChannelWidthsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
params->config.weights = tflite::micro::GetTensorData<int16_t>(input1);
params->config.unweights = tflite::micro::GetTensorData<int16_t>(input2);
params->config.channel_frequency_starts =
tflite::micro::GetTensorData<int16_t>(input3);
params->config.channel_weight_starts =
tflite::micro::GetTensorData<int16_t>(input4);
params->config.channel_widths = tflite::micro::GetTensorData<int16_t>(input5);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input0);
uint64_t* output_data = tflite::micro::GetTensorData<uint64_t>(output);
tflm_signal::FilterbankAccumulateChannels(&params->config, input_data,
params->work_area);
size_t output_size;
TfLiteTypeSizeOf(output->type, &output_size);
output_size *= ElementCount(*output->dims);
// Discard channel 0, which is just scratch
memcpy(output_data, params->work_area + 1, output_size);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file is generated. See:
// tensorflow/lite/micro/kernels/test_data_generation/README.md
#include "signal/micro/kernels/filter_bank_flexbuffers_generated_data.h"
const int g_gen_data_size_filter_bank_32_channel = 23;
const unsigned char g_gen_data_filter_bank_32_channel[] = {
0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x73,
0x00, 0x01, 0x0e, 0x01, 0x01, 0x01, 0x20, 0x04, 0x02, 0x24, 0x01,
};
const int g_gen_data_size_filter_bank_16_channel = 23;
const unsigned char g_gen_data_filter_bank_16_channel[] = {
0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x73,
0x00, 0x01, 0x0e, 0x01, 0x01, 0x01, 0x10, 0x04, 0x02, 0x24, 0x01,
};
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_32_channel;
extern const unsigned char g_gen_data_filter_bank_32_channel[];
extern const int g_gen_data_size_filter_bank_16_channel;
extern const unsigned char g_gen_data_filter_bank_16_channel[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_log.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
constexpr int kInputCorrectionBitsIndex = 0; // 'input_correction_bits'
constexpr int kOutputScaleIndex = 1; // 'output_scale'
struct TFLMSignalLogParams {
int input_correction_bits;
int output_scale;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalLogParams*>(
context->AllocatePersistentBuffer(context, sizeof(TFLMSignalLogParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->input_correction_bits = fbw.ElementAsInt32(kInputCorrectionBitsIndex);
params->output_scale = fbw.ElementAsInt32(kOutputScaleIndex);
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
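// Eval delegates to tflm_signal::FilterbankLog, which maps each uint32 channel
// energy to a log-scaled int16 value using the output_scale and
// input_correction_bits parsed in Init.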
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TFLMSignalLogParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input);
int16_t* output_data = tflite::micro::GetTensorData<int16_t>(output);
int num_channels = input->dims->data[0];
tflm_signal::FilterbankLog(input_data, num_channels, params->output_scale,
params->input_correction_bits, output_data);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_LOG() {
static TFLMRegistration r = tflite::micro::RegisterOp(Init, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file is generated. See:
// tensorflow/lite/micro/kernels/test_data_generation/README.md
#include "signal/micro/kernels/filter_bank_log_flexbuffers_generated_data.h"
const int g_gen_data_size_filter_bank_log_scale_1600_correction_bits_3 = 53;
const unsigned char g_gen_data_filter_bank_log_scale_1600_correction_bits_3[] =
{
0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x73, 0x63, 0x61, 0x6c,
0x65, 0x00, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x63, 0x6f, 0x72,
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x62, 0x69, 0x74,
0x73, 0x00, 0x02, 0x17, 0x25, 0x02, 0x00, 0x01, 0x00, 0x02, 0x00,
0x03, 0x00, 0x40, 0x06, 0x05, 0x05, 0x06, 0x25, 0x01,
};
const int g_gen_data_size_filter_bank_log_scale_32768_correction_bits_5 = 65;
const unsigned char g_gen_data_filter_bank_log_scale_32768_correction_bits_5[] =
{
0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x73, 0x63, 0x61, 0x6c,
0x65, 0x00, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x63, 0x6f, 0x72,
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x62, 0x69, 0x74,
0x73, 0x00, 0x02, 0x17, 0x25, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00,
0x00, 0x00, 0x80, 0x00, 0x00, 0x06, 0x06, 0x0a, 0x26, 0x01,
};
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_log_scale_1600_correction_bits_3;
extern const unsigned char
g_gen_data_filter_bank_log_scale_1600_correction_bits_3[];
extern const int g_gen_data_size_filter_bank_log_scale_32768_correction_bits_5;
extern const unsigned char
g_gen_data_filter_bank_log_scale_32768_correction_bits_5[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_LOG_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "signal/micro/kernels/filter_bank_log_flexbuffers_generated_data.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
TfLiteStatus TestFilterBankLog(int* input_dims_data, const uint32_t* input_data,
int* output_dims_data, const int16_t* golden,
const uint8_t* flexbuffers_data,
const int flexbuffers_data_len,
int16_t* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
constexpr int kInputsSize = 1;
  constexpr int kOutputsSize = 1;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
};
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TFLMRegistration* registration =
tflite::tflm_signal::Register_FILTER_BANK_LOG();
micro::KernelRunner runner(*registration, tensors, kTensorsSize, inputs_array,
outputs_array,
/*builtin_data=*/nullptr);
// TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned
// char*. This small discrepancy results in compiler warnings unless we
// reinterpret_cast right before passing in the flexbuffer bytes to the
// KernelRunner.
TF_LITE_ENSURE_STATUS(runner.InitAndPrepare(
reinterpret_cast<const char*>(flexbuffers_data), flexbuffers_data_len));
TF_LITE_ENSURE_STATUS(runner.Invoke());
for (int i = 0; i < output_len; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]);
}
return kTfLiteOk;
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(FilterBankLogTest32Channel) {
int input_shape[] = {1, 32};
int output_shape[] = {1, 32};
const uint32_t input[] = {29, 21, 29, 40, 19, 11, 13, 23, 13, 11, 25,
17, 5, 4, 46, 14, 17, 14, 20, 14, 10, 10,
15, 11, 17, 12, 15, 16, 19, 18, 6, 2};
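  // With output_scale == 1600 and input_correction_bits == 3, each golden
  // value is approximately 1600 * ln(input[i] << 3), e.g.
  // 1600 * ln(29 << 3) ~= 8715.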
const int16_t golden[] = {8715, 8198, 8715, 9229, 8038, 7164, 7431, 8344,
7431, 7164, 8477, 7860, 5902, 5545, 9453, 7550,
7860, 7550, 8120, 7550, 7011, 7011, 7660, 7164,
7860, 7303, 7660, 7763, 8038, 7952, 6194, 4436};
int16_t output[32];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankLog(
input_shape, input, output_shape, golden,
g_gen_data_filter_bank_log_scale_1600_correction_bits_3,
g_gen_data_size_filter_bank_log_scale_1600_correction_bits_3,
output));
}
TF_LITE_MICRO_TEST(FilterBankLogTest16Channel) {
int input_shape[] = {1, 16};
int output_shape[] = {1, 16};
const uint32_t input[] = {48, 20, 19, 24, 35, 47, 23, 30,
31, 10, 48, 21, 46, 14, 18, 27};
const int16_t golden[] = {32767, 15121, 13440, 21095, 32767, 32767,
19701, 28407, 29482, 32767, 32767, 16720,
32767, 3434, 11669, 24955};
int16_t output[16];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankLog(
input_shape, input, output_shape, golden,
g_gen_data_filter_bank_log_scale_32768_correction_bits_5,
g_gen_data_size_filter_bank_log_scale_32768_correction_bits_5,
output));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_spectral_subtraction.h"
#include <stdint.h>
#include <string.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
constexpr int kNoiseEstimateTensor = 1;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'alternate_one_minus_smoothing'
constexpr int kAlternateOneMinusSmoothingIndex = 0;
constexpr int kAlternateSmoothingIndex = 1; // 'alternate_smoothing'
constexpr int kClampingIndex = 2; // 'clamping'
constexpr int kMinSignalRemainingIndex = 3; // 'min_signal_remaining'
constexpr int kNumChannelsIndex = 4; // 'num_channels'
constexpr int kOneMinusSmoothingIndex = 5; // 'one_minus_smoothing'
constexpr int kSmoothingIndex = 6; // 'smoothing'
constexpr int kSmoothingBitsIndex = 7; // 'smoothing_bits'
constexpr int kSpectralSubtractionBitsIndex = 8; // 'spectral_subtraction_bits'
struct TFLMSignalSpectralSubtractionParams {
tflm_signal::SpectralSubtractionConfig config;
uint32_t* noise_estimate;
size_t noise_estimate_size;
};
void ResetState(TFLMSignalSpectralSubtractionParams* params) {
memset(params->noise_estimate, 0,
sizeof(uint32_t) * params->config.num_channels);
}
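// Init parses the nine configuration fields (stored alphabetically in the
// flexbuffer) and allocates a persistent per-channel noise-estimate buffer,
// which ResetState zeroes from both Prepare and Reset.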
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TFLMSignalSpectralSubtractionParams*>(
context->AllocatePersistentBuffer(
context, sizeof(TFLMSignalSpectralSubtractionParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->config.alternate_one_minus_smoothing =
fbw.ElementAsInt32(kAlternateOneMinusSmoothingIndex);
params->config.alternate_smoothing =
fbw.ElementAsInt32(kAlternateSmoothingIndex);
params->config.clamping = fbw.ElementAsBool(kClampingIndex);
params->config.min_signal_remaining =
fbw.ElementAsInt32(kMinSignalRemainingIndex);
params->config.num_channels = fbw.ElementAsInt32(kNumChannelsIndex);
  params->config.one_minus_smoothing =
      fbw.ElementAsInt32(kOneMinusSmoothingIndex);
params->config.smoothing = fbw.ElementAsInt32(kSmoothingIndex);
params->config.smoothing_bits = fbw.ElementAsInt32(kSmoothingBitsIndex);
params->config.spectral_subtraction_bits =
fbw.ElementAsInt32(kSpectralSubtractionBitsIndex);
params->noise_estimate =
static_cast<uint32_t*>(context->AllocatePersistentBuffer(
context, params->config.num_channels * sizeof(uint32_t)));
if (params->noise_estimate == nullptr) {
return nullptr;
}
return params;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TfLiteTensor* noise_estimate =
micro_context->AllocateTempOutputTensor(node, kNoiseEstimateTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, noise_estimate != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(noise_estimate), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
TF_LITE_ENSURE_TYPES_EQ(context, noise_estimate->type, kTfLiteUInt32);
auto* params =
reinterpret_cast<TFLMSignalSpectralSubtractionParams*>(node->user_data);
TfLiteTypeSizeOf(output->type, &params->noise_estimate_size);
params->noise_estimate_size *= ElementCount(*noise_estimate->dims);
ResetState(params);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(noise_estimate);
return kTfLiteOk;
}
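// Eval runs the spectral subtraction against the persistent noise estimate and
// then copies the updated estimate into the second output tensor so downstream
// ops can consume it.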
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TFLMSignalSpectralSubtractionParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TfLiteEvalTensor* noise_estimate =
tflite::micro::GetEvalOutput(context, node, kNoiseEstimateTensor);
const uint32_t* input_data = tflite::micro::GetTensorData<uint32_t>(input);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
uint32_t* noise_estimate_data =
tflite::micro::GetTensorData<uint32_t>(noise_estimate);
FilterbankSpectralSubtraction(&params->config, input_data, output_data,
params->noise_estimate);
memcpy(noise_estimate_data, params->noise_estimate,
params->noise_estimate_size);
return kTfLiteOk;
}
void Reset(TfLiteContext* context, void* buffer) {
ResetState(static_cast<TFLMSignalSpectralSubtractionParams*>(buffer));
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_SPECTRAL_SUBTRACTION() {
static TFLMRegistration r =
tflite::micro::RegisterOp(Init, Prepare, Eval, /*Free*/ nullptr, Reset);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file is generated. See:
// tensorflow/lite/micro/kernels/test_data_generation/README.md
#include "signal/micro/kernels/filter_bank_spectral_subtraction_flexbuffers_generated_data.h"
const int g_gen_data_size_filter_bank_spectral_subtraction_32_channel = 210;
const unsigned char g_gen_data_filter_bank_spectral_subtraction_32_channel[] = {
0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x73,
0x00, 0x73, 0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x6f,
0x6e, 0x65, 0x5f, 0x6d, 0x69, 0x6e, 0x75, 0x73, 0x5f, 0x73, 0x6d, 0x6f,
0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x61, 0x6c, 0x74, 0x65, 0x72,
0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69,
0x6e, 0x67, 0x00, 0x61, 0x6c, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x65,
0x5f, 0x6f, 0x6e, 0x65, 0x5f, 0x6d, 0x69, 0x6e, 0x75, 0x73, 0x5f, 0x73,
0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x73, 0x6d, 0x6f,
0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x62, 0x69, 0x74, 0x73, 0x00,
0x6d, 0x69, 0x6e, 0x5f, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x5f, 0x72,
0x65, 0x6d, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x00, 0x63, 0x6c, 0x61,
0x6d, 0x70, 0x69, 0x6e, 0x67, 0x00, 0x73, 0x70, 0x65, 0x63, 0x74, 0x72,
0x61, 0x6c, 0x5f, 0x73, 0x75, 0x62, 0x74, 0x72, 0x61, 0x63, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x62, 0x69, 0x74, 0x73, 0x00, 0x09, 0x66, 0x7b, 0x26,
0x3c, 0xa9, 0x93, 0x9e, 0x4f, 0x23, 0x09, 0x00, 0x01, 0x00, 0x09, 0x00,
0x71, 0x3d, 0x8f, 0x02, 0x00, 0x00, 0x33, 0x03, 0x20, 0x00, 0x71, 0x3d,
0x8f, 0x02, 0x00, 0x00, 0x0e, 0x00, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05,
0x05, 0x05, 0x05, 0x1b, 0x25, 0x01,
};
const int g_gen_data_size_filter_bank_spectral_subtraction_16_channel = 210;
const unsigned char g_gen_data_filter_bank_spectral_subtraction_16_channel[] = {
0x6e, 0x75, 0x6d, 0x5f, 0x63, 0x68, 0x61, 0x6e, 0x6e, 0x65, 0x6c, 0x73,
0x00, 0x73, 0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x6f,
0x6e, 0x65, 0x5f, 0x6d, 0x69, 0x6e, 0x75, 0x73, 0x5f, 0x73, 0x6d, 0x6f,
0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x61, 0x6c, 0x74, 0x65, 0x72,
0x6e, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69,
0x6e, 0x67, 0x00, 0x61, 0x6c, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x65,
0x5f, 0x6f, 0x6e, 0x65, 0x5f, 0x6d, 0x69, 0x6e, 0x75, 0x73, 0x5f, 0x73,
0x6d, 0x6f, 0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x00, 0x73, 0x6d, 0x6f,
0x6f, 0x74, 0x68, 0x69, 0x6e, 0x67, 0x5f, 0x62, 0x69, 0x74, 0x73, 0x00,
0x6d, 0x69, 0x6e, 0x5f, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x5f, 0x72,
0x65, 0x6d, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x00, 0x63, 0x6c, 0x61,
0x6d, 0x70, 0x69, 0x6e, 0x67, 0x00, 0x73, 0x70, 0x65, 0x63, 0x74, 0x72,
0x61, 0x6c, 0x5f, 0x73, 0x75, 0x62, 0x74, 0x72, 0x61, 0x63, 0x74, 0x69,
0x6f, 0x6e, 0x5f, 0x62, 0x69, 0x74, 0x73, 0x00, 0x09, 0x66, 0x7b, 0x26,
0x3c, 0xa9, 0x93, 0x9e, 0x4f, 0x23, 0x09, 0x00, 0x01, 0x00, 0x09, 0x00,
0x71, 0x3d, 0x8f, 0x02, 0x00, 0x00, 0x33, 0x03, 0x10, 0x00, 0x71, 0x3d,
0x8f, 0x02, 0x00, 0x00, 0x0e, 0x00, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05,
0x05, 0x05, 0x05, 0x1b, 0x25, 0x01,
};
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
#define SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
extern const int g_gen_data_size_filter_bank_spectral_subtraction_32_channel;
extern const unsigned char
g_gen_data_filter_bank_spectral_subtraction_32_channel[];
extern const int g_gen_data_size_filter_bank_spectral_subtraction_16_channel;
extern const unsigned char
g_gen_data_filter_bank_spectral_subtraction_16_channel[];
#endif // SIGNAL_MICRO_KERNELS_TEST_DATA_GENERATION_GENERATE_FILTER_BANK_SPECTRAL_SUBTRACTION_FLEXBUFFERS_DATA_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "signal/micro/kernels/filter_bank_spectral_subtraction_flexbuffers_generated_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
constexpr int kInputsSize = 1;
constexpr int kOutputsSize = 2;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
// Specialized KernelRunner for testing the Filter Bank Spectral Subtraction
// op. It owns the tensors and index arrays so the same runner can be invoked,
// reset, and invoked again by the *Reset tests below.
class FilterBankSpectralSubtractKernelRunner {
public:
explicit FilterBankSpectralSubtractKernelRunner(int* input_dims_data,
const uint32_t* input_data,
int* output_dims_data,
uint32_t* output_data1,
uint32_t* output_data2)
: inputs_array_(IntArrayFromInts(inputs_array_data_)),
outputs_array_(IntArrayFromInts(outputs_array_data_)),
kernel_runner_(*registration_, tensors_, kTensorsSize, inputs_array_,
outputs_array_, nullptr) {
tensors_[0] = tflite::testing::CreateTensor(
input_data, tflite::testing::IntArrayFromInts(input_dims_data));
tensors_[1] = tflite::testing::CreateTensor(
output_data1, tflite::testing::IntArrayFromInts(output_dims_data));
tensors_[2] = tflite::testing::CreateTensor(
output_data2, tflite::testing::IntArrayFromInts(output_dims_data));
}
tflite::micro::KernelRunner& kernel_runner() { return kernel_runner_; }
private:
int inputs_array_data_[kInputsSize + 1] = {1, 0};
int outputs_array_data_[kOutputsSize + 1] = {2, 1, 2};
TfLiteTensor tensors_[kTensorsSize] = {};
TfLiteIntArray* inputs_array_ = nullptr;
TfLiteIntArray* outputs_array_ = nullptr;
TFLMRegistration* registration_ =
tflite::tflm_signal::Register_FILTER_BANK_SPECTRAL_SUBTRACTION();
tflite::micro::KernelRunner kernel_runner_;
};
TfLiteStatus TestFilterBankSpectralSubtractionInvoke(
int* output_dims_data, const uint32_t* golden1, const uint32_t* golden2,
uint32_t* output1_data, uint32_t* output2_data,
tflite::micro::KernelRunner& kernel_runner) {
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
const int output_len = ElementCount(*output_dims);
TF_LITE_ENSURE_STATUS(kernel_runner.Invoke());
for (int i = 0; i < output_len; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden1[i], output1_data[i]);
TF_LITE_MICRO_EXPECT_EQ(golden2[i], output2_data[i]);
}
return kTfLiteOk;
}
TfLiteStatus TestFilterBankSpectralSubtraction(
int* input_dims_data, const uint32_t* input_data, int* output_dims_data,
const uint32_t* golden1, const uint32_t* golden2,
const uint8_t* flexbuffers_data, const int flexbuffers_data_len,
uint32_t* output1_data, uint32_t* output2_data) {
FilterBankSpectralSubtractKernelRunner filter_bank_spectral_subtract_runner(
input_dims_data, input_data, output_dims_data, output1_data,
output2_data);
// TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned
// char*. This small discrepancy results in compiler warnings unless we
// reinterpret_cast right before passing in the flexbuffer bytes to the
// KernelRunner.
TF_LITE_ENSURE_STATUS(
filter_bank_spectral_subtract_runner.kernel_runner().InitAndPrepare(
reinterpret_cast<const char*>(flexbuffers_data),
flexbuffers_data_len));
TF_LITE_ENSURE_STATUS(TestFilterBankSpectralSubtractionInvoke(
output_dims_data, golden1, golden2, output1_data, output2_data,
filter_bank_spectral_subtract_runner.kernel_runner()));
return kTfLiteOk;
}
TfLiteStatus TestFilterBankSpectralSubtractionReset(
int* input_dims_data, const uint32_t* input_data, int* output_dims_data,
const uint32_t* golden1, const uint32_t* golden2,
const uint8_t* flexbuffers_data, const int flexbuffers_data_len,
uint32_t* output1_data, uint32_t* output2_data) {
FilterBankSpectralSubtractKernelRunner filter_bank_spectral_subtract_runner(
input_dims_data, input_data, output_dims_data, output1_data,
output2_data);
// TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned
// char*. This small discrepancy results in compiler warnings unless we
// reinterpret_cast right before passing in the flexbuffer bytes to the
// KernelRunner.
TF_LITE_ENSURE_STATUS(
filter_bank_spectral_subtract_runner.kernel_runner().InitAndPrepare(
reinterpret_cast<const char*>(flexbuffers_data),
flexbuffers_data_len));
TF_LITE_ENSURE_STATUS(TestFilterBankSpectralSubtractionInvoke(
output_dims_data, golden1, golden2, output1_data, output2_data,
filter_bank_spectral_subtract_runner.kernel_runner()));
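  // Resetting clears the persistent noise estimate, so invoking again on the
  // same input must reproduce the same goldens as the first frame.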
filter_bank_spectral_subtract_runner.kernel_runner().Reset();
TF_LITE_ENSURE_STATUS(TestFilterBankSpectralSubtractionInvoke(
output_dims_data, golden1, golden2, output1_data, output2_data,
filter_bank_spectral_subtract_runner.kernel_runner()));
return kTfLiteOk;
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(FilterBankSpectralSubtractionTest32Channel) {
int input_shape[] = {1, 32};
int output_shape[] = {1, 32};
const uint32_t input[] = {322, 308, 210, 212, 181, 251, 403, 259, 65, 48, 76,
48, 50, 46, 53, 52, 112, 191, 136, 59, 70, 51,
39, 64, 33, 44, 41, 49, 74, 107, 262, 479};
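  // The noise estimate starts at zero, so on this first frame the subtracted
  // output (golden1) and the reported noise estimate (golden2) sum back to the
  // input energies.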
const uint32_t golden1[] = {310, 296, 202, 204, 174, 241, 387, 249,
63, 47, 73, 47, 49, 45, 51, 50,
108, 184, 131, 57, 68, 49, 38, 62,
32, 43, 40, 48, 72, 103, 252, 460};
const uint32_t golden2[] = {12, 12, 8, 8, 7, 10, 16, 10, 2, 1, 3,
1, 1, 1, 2, 2, 4, 7, 5, 2, 2, 2,
1, 2, 1, 1, 1, 1, 2, 4, 10, 19};
uint32_t output1[32];
uint32_t output2[32];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankSpectralSubtraction(
input_shape, input, output_shape, golden1, golden2,
g_gen_data_filter_bank_spectral_subtraction_32_channel,
g_gen_data_size_filter_bank_spectral_subtraction_32_channel, output1,
output2));
}
TF_LITE_MICRO_TEST(FilterBankSpectralSubtractionTest16Channel) {
int input_shape[] = {1, 16};
int output_shape[] = {1, 16};
const uint32_t input[] = {393, 213, 408, 1, 361, 385, 386, 326,
170, 368, 368, 305, 152, 322, 213, 319};
const uint32_t golden1[] = {378, 205, 392, 1, 347, 370, 371, 313,
164, 354, 354, 293, 146, 310, 205, 307};
const uint32_t golden2[] = {15, 8, 16, 0, 14, 15, 15, 13,
6, 14, 14, 12, 6, 12, 8, 12};
uint32_t output1[32];
uint32_t output2[32];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankSpectralSubtraction(
input_shape, input, output_shape, golden1, golden2,
g_gen_data_filter_bank_spectral_subtraction_16_channel,
g_gen_data_size_filter_bank_spectral_subtraction_16_channel, output1,
output2));
}
TF_LITE_MICRO_TEST(FilterBankSpectralSubtractionTest32ChannelReset) {
int input_shape[] = {1, 32};
int output_shape[] = {1, 32};
const uint32_t input[] = {322, 308, 210, 212, 181, 251, 403, 259, 65, 48, 76,
48, 50, 46, 53, 52, 112, 191, 136, 59, 70, 51,
39, 64, 33, 44, 41, 49, 74, 107, 262, 479};
const uint32_t golden1[] = {310, 296, 202, 204, 174, 241, 387, 249,
63, 47, 73, 47, 49, 45, 51, 50,
108, 184, 131, 57, 68, 49, 38, 62,
32, 43, 40, 48, 72, 103, 252, 460};
const uint32_t golden2[] = {12, 12, 8, 8, 7, 10, 16, 10, 2, 1, 3,
1, 1, 1, 2, 2, 4, 7, 5, 2, 2, 2,
1, 2, 1, 1, 1, 1, 2, 4, 10, 19};
uint32_t output1[32];
uint32_t output2[32];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankSpectralSubtractionReset(
input_shape, input, output_shape, golden1, golden2,
g_gen_data_filter_bank_spectral_subtraction_32_channel,
g_gen_data_size_filter_bank_spectral_subtraction_32_channel, output1,
output2));
}
TF_LITE_MICRO_TEST(FilterBankSpectralSubtractionTest16ChannelReset) {
int input_shape[] = {1, 16};
int output_shape[] = {1, 16};
const uint32_t input[] = {393, 213, 408, 1, 361, 385, 386, 326,
170, 368, 368, 305, 152, 322, 213, 319};
const uint32_t golden1[] = {378, 205, 392, 1, 347, 370, 371, 313,
164, 354, 354, 293, 146, 310, 205, 307};
const uint32_t golden2[] = {15, 8, 16, 0, 14, 15, 15, 13,
6, 14, 14, 12, 6, 12, 8, 12};
uint32_t output1[32];
uint32_t output2[32];
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBankSpectralSubtractionReset(
input_shape, input, output_shape, golden1, golden2,
g_gen_data_filter_bank_spectral_subtraction_16_channel,
g_gen_data_size_filter_bank_spectral_subtraction_16_channel, output1,
output2));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_square_root.h"
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kScaleBitsTensor = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TfLiteTensor* scale_bits =
micro_context->AllocateTempInputTensor(node, kScaleBitsTensor);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TF_LITE_ENSURE(context, scale_bits != nullptr);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(scale_bits), 0);
TF_LITE_ENSURE_EQ(context, NumDimensions(output), 1);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt64);
TF_LITE_ENSURE_TYPES_EQ(context, scale_bits->type, kTfLiteInt32);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteUInt32);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(scale_bits);
return kTfLiteOk;
}
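// Eval takes the integer square root of each uint64 channel energy and shifts
// the result right by the scalar scale_bits input before writing it out as
// uint32.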
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
const TfLiteEvalTensor* scale_bits =
tflite::micro::GetEvalInput(context, node, kScaleBitsTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const uint64_t* input_data = tflite::micro::GetTensorData<uint64_t>(input);
const int32_t* scale_bits_data =
tflite::micro::GetTensorData<int32_t>(scale_bits);
uint32_t* output_data = tflite::micro::GetTensorData<uint32_t>(output);
int32_t num_channels = input->dims->data[0];
tflm_signal::FilterbankSqrt(input_data, num_channels, *scale_bits_data,
output_data);
return kTfLiteOk;
}
} // namespace
namespace tflm_signal {
TFLMRegistration* Register_FILTER_BANK_SQUARE_ROOT() {
static TFLMRegistration r = tflite::micro::RegisterOp(nullptr, Prepare, Eval);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
TfLiteStatus TestFilterBankSquareRoot(
int* input1_dims_data, const uint64_t* input1_data, int* input2_dims_data,
const int32_t* input2_data, int* output_dims_data, const uint32_t* golden,
uint32_t* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
  constexpr int kInputsSize = 2;
  constexpr int kOutputsSize = 1;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input1_data, input1_dims),
CreateTensor(input2_data, input2_dims),
CreateTensor(output_data, output_dims),
};
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TFLMRegistration* registration =
tflite::tflm_signal::Register_FILTER_BANK_SQUARE_ROOT();
micro::KernelRunner runner(*registration, tensors, kTensorsSize, inputs_array,
outputs_array,
/*builtin_data=*/nullptr);
  // FILTER_BANK_SQUARE_ROOT has no Init step and takes no flexbuffer
  // parameters, so InitAndPrepare is called with a null buffer.
TF_LITE_ENSURE_STATUS(runner.InitAndPrepare(nullptr, 0));
TF_LITE_ENSURE_STATUS(runner.Invoke());
for (int i = 0; i < output_len; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]);
}
return kTfLiteOk;
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(FilterBankSquareRoot32Channel) {
int input1_shape[] = {1, 32};
int input2_shape[] = {0};
int output_shape[] = {1, 32};
const uint64_t input1[] = {
10528000193, 28362909357, 47577133750, 8466055850, 5842710800, 2350911449,
2989811430, 2646718839, 515262774, 276394561, 469831522, 55815334,
28232446, 11591835, 40329249, 67658028, 183446654, 323189165,
117473797, 41339272, 25846050, 12428673, 18670978, 22521722,
78477733, 54207503, 25150296, 43098592, 28211625, 15736687,
20990296, 17907031};
const int32_t input2[] = {7};
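  // With scale_bits == 7, each expected value is the integer square root of
  // the corresponding input shifted right by 7 bits, e.g.
  // sqrt(10528000193) ~= 102606 and 102606 >> 7 == 801.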
const uint32_t golden[] = {801, 1315, 1704, 718, 597, 378, 427, 401,
177, 129, 169, 58, 41, 26, 49, 64,
105, 140, 84, 50, 39, 27, 33, 37,
69, 57, 39, 51, 41, 30, 35, 33};
uint32_t output[32];
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, tflite::testing::TestFilterBankSquareRoot(
input1_shape, input1, input2_shape,
input2, output_shape, golden, output));
}
TF_LITE_MICRO_TEST(FilterBankSquareRoot16Channel) {
int input1_shape[] = {1, 16};
int input2_shape[] = {0};
int output_shape[] = {1, 16};
const uint64_t input1[] = {
13051415151, 14932650877, 18954728418, 8730126017,
6529665275, 12952546517, 10314975609, 8919697835,
8053663348, 17231208421, 7366899760, 1372112200,
19953434807, 17012385332, 4710443222, 17765594053};
const int32_t input2[] = {5};
const uint32_t golden[] = {3570, 3818, 4302, 2919, 2525, 3556, 3173, 2951,
2804, 4102, 2682, 1157, 4414, 4076, 2144, 4165};
uint32_t output[16];
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, tflite::testing::TestFilterBankSquareRoot(
input1_shape, input1, input2_shape,
input2, output_shape, golden, output));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "signal/micro/kernels/filter_bank_flexbuffers_generated_data.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
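// Runs the FILTER_BANK kernel with its six 1-D inputs (the uint32 spectral
// energies followed by the int16 weights, unweights, channel frequency starts,
// channel weight starts, and channel widths) and compares the uint64
// accumulator output against the golden.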
TfLiteStatus TestFilterBank(int* input1_dims_data, const uint32_t* input1_data,
int* input2_dims_data, const int16_t* input2_data,
int* input3_dims_data, const int16_t* input3_data,
int* input4_dims_data, const int16_t* input4_data,
int* input5_dims_data, const int16_t* input5_data,
int* input6_dims_data, const int16_t* input6_data,
int* output_dims_data, const uint64_t* golden,
const uint8_t* flexbuffers_data,
const int flexbuffers_data_len,
uint64_t* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* input3_dims = IntArrayFromInts(input3_dims_data);
TfLiteIntArray* input4_dims = IntArrayFromInts(input4_dims_data);
TfLiteIntArray* input5_dims = IntArrayFromInts(input5_dims_data);
TfLiteIntArray* input6_dims = IntArrayFromInts(input6_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
constexpr int kInputsSize = 6;
constexpr int kOutputsSize = 1;
constexpr int kTensorsSize = kInputsSize + kOutputsSize;
TfLiteTensor tensors[kTensorsSize] = {
CreateTensor(input1_data, input1_dims),
CreateTensor(input2_data, input2_dims),
CreateTensor(input3_data, input3_dims),
CreateTensor(input4_data, input4_dims),
CreateTensor(input5_data, input5_dims),
CreateTensor(input6_data, input6_dims),
CreateTensor(output_data, output_dims),
};
int inputs_array_data[] = {6, 0, 1, 2, 3, 4, 5};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 6};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int output_len = ElementCount(*output_dims);
TFLMRegistration* registration = tflite::tflm_signal::Register_FILTER_BANK();
micro::KernelRunner runner(*registration, tensors, kTensorsSize, inputs_array,
outputs_array,
/*builtin_data=*/nullptr);
// TfLite uses a char* for the raw bytes whereas flexbuffers use an unsigned
// char*. This small discrepancy results in compiler warnings unless we
// reinterpret_cast right before passing in the flexbuffer bytes to the
// KernelRunner.
TF_LITE_ENSURE_STATUS(runner.InitAndPrepare(
reinterpret_cast<const char*>(flexbuffers_data), flexbuffers_data_len));
TF_LITE_ENSURE_STATUS(runner.Invoke());
for (int i = 0; i < output_len; ++i) {
TF_LITE_MICRO_EXPECT_EQ(golden[i], output_data[i]);
}
return kTfLiteOk;
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(FilterBankTest32Channel) {
int input1_shape[] = {1, 257};
int input2_shape[] = {1, 117};
int input3_shape[] = {1, 117};
int input4_shape[] = {1, 33};
int input5_shape[] = {1, 33};
int input6_shape[] = {1, 33};
int output_shape[] = {1, 32};
uint64_t output[32];
const uint32_t input1[] = {
65451, 11468838, 4280615122, 4283105055, 30080683, 969970,
1168164, 192770, 344209, 1811809, 1740724, 586130,
305045, 17981, 169273, 103321, 85277, 529901,
524660, 116609, 29653, 64345, 13121, 273956,
593748, 463432, 348169, 77545, 2117, 19277,
13837, 85, 16322, 1325, 69584, 233930,
253273, 94180, 8642, 104245, 151937, 231970,
90405, 95849, 106285, 81938, 76226, 103337,
303250, 337705, 75140, 43874, 33730, 44761,
117608, 57322, 9945, 19816, 48674, 19465,
15696, 52229, 103738, 102541, 126421, 133157,
33680, 7738, 45029, 57122, 61605, 60138,
26170, 41444, 210994, 238338, 74324, 21460,
33125, 3940, 15481, 7709, 24929, 17714,
170993, 91978, 45965, 214133, 96832, 1800,
16717, 42341, 87421, 114341, 65161, 26260,
135077, 245000, 122117, 81188, 107753, 74125,
86432, 91460, 29648, 2069, 3161, 5002,
784, 1152, 1424, 277, 452, 2696,
3610, 2120, 2617, 562, 1153, 4610,
2906, 65, 786450, 4293722107, 0, 393208,
2, 196608, 65539, 65537, 4294967295, 65537,
4294901762, 65535, 4294770689, 65533, 131073, 4294901761,
131071, 131071, 65535, 4294901764, 4294967295, 0,
4294901758, 4294901761, 196607, 4294836224, 131070, 4294901762,
4294901759, 196608, 4294901761, 131071, 131070, 65538,
0, 4294901761, 65536, 4294836225, 65536, 4294836225,
4294901757, 65535, 4294901760, 196607, 4294967295, 0,
131071, 4294901762, 4294836221, 196608, 65536, 1,
131074, 4294770690, 4294967291, 196611, 4294770687, 262143,
4294901759, 131071, 1, 4294901759, 196607, 4294705153,
196607, 4294967294, 65536, 1, 4294901759, 65536,
0, 65536, 65537, 4294901759, 65536, 3,
4294836222, 65534, 65536, 65538, 4294836225, 4294901760,
4294901761, 4294967293, 0, 65534, 131070, 65537,
4294901762, 65536, 2, 4294836224, 1, 4294901760,
0, 4294967294, 131073, 4294901760, 65535, 131073,
4294836224, 65536, 4294901760, 4294901760, 4294967295, 4294901761,
131071, 4294901760, 131071, 4294836224, 2, 4294901758,
4294967292, 131073, 0, 65535, 0, 4294901760,
4294967295, 131073, 4294901764, 4294836223, 4294967295, 65535,
65537, 65533, 3, 131072, 4294836224, 65537,
1, 4294967293, 196611, 4294901759, 1};
const int16_t input2[] = {
1133, 2373, 3712, 1047, 2564, 66, 1740, 3486, 1202, 3079, 919, 2913,
865, 2964, 1015, 3210, 1352, 3633, 1859, 123, 2520, 856, 3323, 1726,
161, 2722, 1215, 3833, 2382, 956, 3652, 2276, 923, 3689, 2380, 1093,
3922, 2676, 1448, 239, 3144, 1970, 814, 3770, 2646, 1538, 445, 3463,
2399, 1349, 313, 3386, 2376, 1379, 394, 3517, 2556, 1607, 668, 3837,
2920, 2013, 1117, 231, 3450, 2583, 1725, 877, 37, 3302, 2480, 1666,
861, 63, 3369, 2588, 1813, 1046, 287, 3630, 2885, 2147, 1415, 690,
4067, 3355, 2650, 1950, 1257, 569, 3984, 3308, 2638, 1973, 1314, 661,
12, 3465, 2827, 2194, 1566, 943, 325, 3808, 3199, 2595, 1996, 1401,
810, 224, 3738, 3160, 2586, 2017, 1451, 890, 332};
const int16_t input3[] = {
2962, 1722, 383, 3048, 1531, 4029, 2355, 609, 2893, 1016, 3176, 1182,
3230, 1131, 3080, 885, 2743, 462, 2236, 3972, 1575, 3239, 772, 2369,
3934, 1373, 2880, 262, 1713, 3139, 443, 1819, 3172, 406, 1715, 3002,
173, 1419, 2647, 3856, 951, 2125, 3281, 325, 1449, 2557, 3650, 632,
1696, 2746, 3782, 709, 1719, 2716, 3701, 578, 1539, 2488, 3427, 258,
1175, 2082, 2978, 3864, 645, 1512, 2370, 3218, 4058, 793, 1615, 2429,
3234, 4032, 726, 1507, 2282, 3049, 3808, 465, 1210, 1948, 2680, 3405,
28, 740, 1445, 2145, 2838, 3526, 111, 787, 1457, 2122, 2781, 3434,
4083, 630, 1268, 1901, 2529, 3152, 3770, 287, 896, 1500, 2099, 2694,
3285, 3871, 357, 935, 1509, 2078, 2644, 3205, 3763};
const int16_t input4[] = {5, 6, 7, 9, 11, 12, 14, 16, 18, 20, 22,
25, 27, 30, 32, 35, 38, 41, 45, 48, 52, 56,
60, 64, 69, 74, 79, 84, 89, 95, 102, 108, 115};
const int16_t input5[] = {0, 1, 2, 4, 6, 7, 9, 11, 13, 15, 17,
20, 22, 25, 27, 30, 33, 36, 40, 43, 47, 51,
55, 59, 64, 69, 74, 79, 84, 90, 97, 103, 110};
const int16_t input6[] = {1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3,
4, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 7, 6, 7, 7};
const uint64_t golden[] = {
5645104312, 3087527471, 5883346002, 10807122775, 2465336182, 853935004,
1206905130, 3485828019, 1134726750, 832725041, 4442875878, 2122064365,
178483220, 151483681, 1742660113, 1309124116, 1954305288, 1323857378,
2750861165, 1340947482, 792522630, 669257768, 1659699572, 940652856,
1957080469, 1034203505, 1541805928, 1710818326, 2432875876, 2254716277,
275382345, 57293224};
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBank(
input1_shape, input1, input2_shape, input2, input3_shape, input3,
input4_shape, input4, input5_shape, input5, input6_shape, input6,
output_shape, golden, g_gen_data_filter_bank_32_channel,
g_gen_data_size_filter_bank_32_channel, output));
}
TF_LITE_MICRO_TEST(FilterBankTest16Channel) {
int input1_shape[] = {1, 129};
int input2_shape[] = {1, 59};
int input3_shape[] = {1, 59};
int input4_shape[] = {1, 17};
int input5_shape[] = {1, 17};
int input6_shape[] = {1, 17};
int output_shape[] = {1, 16};
uint64_t output[16];
const uint32_t input1[] = {
645050, 4644, 3653, 24262, 56660, 43260, 50584, 57902, 31702, 5401,
45555, 34852, 8518, 43556, 13358, 19350, 40221, 18017, 27284, 64491,
60099, 17863, 11001, 29076, 32666, 65268, 50947, 28694, 32377, 30014,
25607, 22547, 45086, 10654, 46797, 8622, 47348, 43085, 5747, 51544,
50364, 6208, 20696, 59782, 14429, 60125, 37079, 32673, 63457, 60142,
34042, 11280, 1874, 33734, 62118, 13766, 54398, 47818, 50976, 46930,
25906, 59441, 25958, 59136, 1756, 18652, 29213, 13379, 51845, 1207,
55626, 27108, 43771, 35236, 3374, 40959, 47707, 41540, 34282, 27094,
36329, 13593, 65257, 47006, 46857, 1114, 37106, 18738, 25969, 15461,
2842, 36470, 32489, 61622, 23613, 29624, 32820, 30438, 9543, 6767,
23037, 52896, 12059, 32264, 11575, 42400, 43344, 27511, 16712, 6877,
4910, 50047, 61569, 57237, 48558, 2310, 22192, 7874, 46141, 64056,
61997, 7298, 31372, 25316, 683, 58940, 18755, 17898, 19196};
const int16_t input2[] = {
-2210, 1711, 3237, 1247, 2507, 61, 1019, 899, 206, 146, 2849, 2756,
1260, 1280, 1951, 213, 617, 2047, 211, 347, 2821, 3747, 150, 1924,
3962, 942, 1430, 2678, 993, 308, 3364, 2491, 954, 1308, 879, 3950,
1, 3556, 3628, 2104, 78, 1298, 1080, 342, 1337, 1639, 2352, 829,
1358, 2498, 1647, 2507, 3816, 3767, 3735, 1155, 2221, 2196, 1160};
const int16_t input3[] = {
408, 3574, 1880, 2561, 2011, 3394, 1019, 445, 3901, 343, 1874, 3846,
3566, 1830, 327, 111, 623, 1037, 2803, 1947, 1518, 661, 3239, 2351,
1257, 269, 1574, 3431, 3972, 2487, 2181, 1458, 552, 717, 679, 1031,
1738, 1782, 128, 2242, 353, 1460, 3305, 1424, 3813, 2895, 164, 272,
3886, 3135, 141, 747, 3233, 1478, 2612, 3837, 3271, 73, 1746};
const int16_t input4[] = {5, 6, 7, 9, 11, 12, 14, 16, 18,
20, 22, 25, 27, 30, 32, 35, 33};
const int16_t input5[] = {0, 1, 2, 4, 6, 7, 9, 11, 13,
15, 17, 20, 22, 25, 27, 30, 33};
const int16_t input6[] = {1, 1, 2, 2, 1, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 3, 3};
const uint64_t golden[] = {104199304, 407748384, 206363744, 200989269,
52144406, 230780884, 174394190, 379684049,
94840835, 57788823, 531528204, 318265707,
263149795, 188110467, 501443259, 200747781};
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk,
tflite::testing::TestFilterBank(
input1_shape, input1, input2_shape, input2, input3_shape, input3,
input4_shape, input4, input5_shape, input5, input6_shape, input6,
output_shape, golden, g_gen_data_filter_bank_16_channel,
g_gen_data_size_filter_bank_16_channel, output));
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/irfft.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'T' is added implicitly by the TensorFlow framework when the type is resolved
// during graph construction.
// constexpr int kTypeIndex = 0; // 'T' (unused)
constexpr int kFftLengthIndex = 1; // 'fft_length'
struct TfLiteAudioFrontendIrfftParams {
int32_t fft_length;
int32_t input_size;
int32_t input_length;
int32_t output_length;
TfLiteType fft_type;
int8_t* state;
};
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t)>
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TfLiteAudioFrontendIrfftParams*>(
context->AllocatePersistentBuffer(
context, sizeof(TfLiteAudioFrontendIrfftParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->fft_length = fbw.ElementAsInt32(kFftLengthIndex);
params->fft_type = typeToTfLiteType<T>();
size_t state_size = (*get_needed_memory_func)(params->fft_length);
params->state = reinterpret_cast<int8_t*>(
context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t)));
if (params->state == nullptr) {
return nullptr;
}
(*init_func)(params->fft_length, params->state, state_size);
return params;
}
template <TfLiteType TfLiteTypeEnum>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum);
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
// Divide by 2 because input is complex.
params->input_length =
input_shape.Dims(input_shape.DimensionsCount() - 1) / 2;
params->input_size = input_shape.FlatSize() / 2;
params->output_length = output_shape.Dims(output_shape.DimensionsCount() - 1);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T, void (*apply_func)(void*, const Complex<T>* input, T*)>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const Complex<T>* input_data =
tflite::micro::GetTensorData<Complex<T>>(input);
T* output_data = tflite::micro::GetTensorData<T>(output);
for (int input_idx = 0, output_idx = 0; input_idx < params->input_size;
input_idx += params->input_length, output_idx += params->output_length) {
(*apply_func)(params->state, &input_data[input_idx],
&output_data[output_idx]);
}
return kTfLiteOk;
}
void* InitAll(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
auto tensor_type = static_cast<tflite::TensorType>(m["T"].AsInt32());
switch (tensor_type) {
case TensorType_INT16: {
return Init<int16_t, tflm_signal::IrfftInt16GetNeededMemory,
tflm_signal::IrfftInt16Init>(context, buffer, length);
}
case TensorType_INT32: {
return Init<int32_t, tflm_signal::IrfftInt32GetNeededMemory,
tflm_signal::IrfftInt32Init>(context, buffer, length);
}
case TensorType_FLOAT32: {
return Init<float, tflm_signal::IrfftFloatGetNeededMemory,
tflm_signal::IrfftFloatInit>(context, buffer, length);
}
default:
return nullptr;
}
}
TfLiteStatus PrepareAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Prepare<kTfLiteInt16>(context, node);
}
case kTfLiteInt32: {
return Prepare<kTfLiteInt32>(context, node);
}
case kTfLiteFloat32: {
return Prepare<kTfLiteFloat32>(context, node);
}
default:
return kTfLiteError;
}
}
TfLiteStatus EvalAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Eval<int16_t, tflm_signal::IrfftInt16Apply>(context, node);
}
case kTfLiteInt32: {
return Eval<int32_t, tflm_signal::IrfftInt32Apply>(context, node);
}
case kTfLiteFloat32: {
return Eval<float, tflm_signal::IrfftFloatApply>(context, node);
}
default:
return kTfLiteError;
}
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_IRFFT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(InitAll, PrepareAll, EvalAll);
return &r;
}
TFLMRegistration* Register_IRFFT_FLOAT() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<float, IrfftFloatGetNeededMemory, IrfftFloatInit>,
Prepare<kTfLiteFloat32>, Eval<float, IrfftFloatApply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT16() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int16_t, IrfftInt16GetNeededMemory, IrfftInt16Init>,
Prepare<kTfLiteInt16>, Eval<int16_t, IrfftInt16Apply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT32() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int32_t, IrfftInt32GetNeededMemory, IrfftInt32Init>,
Prepare<kTfLiteInt32>, Eval<int32_t, IrfftInt32Apply>);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
\ No newline at end of file
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_IRFFT_H_
#define SIGNAL_MICRO_KERNELS_IRFFT_H_
#include "tensorflow/lite/micro/micro_common.h"
namespace tflite {
namespace tflm_signal {
TFLMRegistration* Register_IRFFT();
TFLMRegistration* Register_IRFFT_FLOAT();
TFLMRegistration* Register_IRFFT_INT16();
TFLMRegistration* Register_IRFFT_INT32();
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_IRFFT_H_
@@ -8,6 +8,72 @@ cc_library(
hdrs = ["complex.h"],
)
cc_library(
name = "fft_auto_scale",
srcs = ["fft_auto_scale.cc"],
hdrs = ["fft_auto_scale.h"],
deps = [
":max_abs",
":msb_32",
],
)
cc_library(
name = "irfft",
srcs = [
"irfft_float.cc",
"irfft_int16.cc",
"irfft_int32.cc",
],
hdrs = ["irfft.h"],
deps = [
":complex",
"//signal/src/kiss_fft_wrappers",
],
)
cc_library(
name = "max_abs",
srcs = ["max_abs.cc"],
hdrs = ["max_abs.h"],
)
cc_library(
name = "square_root_32",
srcs = ["square_root_32.cc"],
hdrs = ["square_root.h"],
deps = [":msb_32"],
)
cc_library(
name = "square_root_64",
srcs = ["square_root_64.cc"],
hdrs = ["square_root.h"],
deps = [
":msb_64",
":square_root_32",
],
)
cc_library(
name = "log",
srcs = ["log.cc"],
hdrs = ["log.h"],
deps = [":msb_32"],
)
cc_library(
name = "msb_32",
srcs = ["msb_32.cc"],
hdrs = ["msb.h"],
)
cc_library(
name = "msb_64",
srcs = ["msb_64.cc"],
hdrs = ["msb.h"],
)
cc_library(
name = "rfft",
srcs = [
@@ -46,3 +112,33 @@ cc_library(
hdrs = ["energy.h"],
deps = [":complex"],
)
cc_library(
name = "filter_bank",
srcs = ["filter_bank.cc"],
hdrs = ["filter_bank.h"],
)
cc_library(
name = "filter_bank_log",
srcs = ["filter_bank_log.cc"],
hdrs = ["filter_bank_log.h"],
deps = [
":log",
],
)
cc_library(
name = "filter_bank_spectral_subtraction",
srcs = ["filter_bank_spectral_subtraction.cc"],
hdrs = ["filter_bank_spectral_subtraction.h"],
)
cc_library(
name = "filter_bank_square_root",
srcs = ["filter_bank_square_root.cc"],
hdrs = ["filter_bank_square_root.h"],
deps = [
":square_root_64",
],
)
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include <stddef.h>
#include <stdint.h>
#include "signal/src/max_abs.h"
#include "signal/src/msb.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
int FftAutoScale(const int16_t* input, int size, int16_t* output) {
const int16_t max = MaxAbs16(input, size);
int scale_bits = (sizeof(int16_t) * 8) - MostSignificantBit32(max) - 1;
if (scale_bits <= 0) {
scale_bits = 0;
}
for (int i = 0; i < size; i++) {
// (input[i] << scale_bits) is undefined if input[i] is negative.
// Multiply explicitly to make the code portable.
output[i] = input[i] * (1 << scale_bits);
}
return scale_bits;
}
} // namespace tflm_signal
} // namespace tflite
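To make the scaling math above concrete, here is a minimal standalone sketch (input values are illustrative; the include path follows the `signal/src` layout shown in the BUILD fragment above). A peak magnitude of 1000 occupies 10 bits, so `scale_bits` comes out as 16 - 10 - 1 = 5 and every sample is multiplied by 32:
```
#include <cstdint>
#include <cstdio>

#include "signal/src/fft_auto_scale.h"

int main() {
  // Peak |x| is 1000 (10 significant bits), so scale_bits = 16 - 10 - 1 = 5.
  const int16_t input[4] = {1000, -250, 3, -1000};
  int16_t output[4] = {};
  const int scale_bits = tflite::tflm_signal::FftAutoScale(input, 4, output);
  // Expect scale_bits == 5 and output == {32000, -8000, 96, -32000}.
  std::printf("scale_bits=%d, output[0]=%d\n", scale_bits, output[0]);
  return 0;
}
```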
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FFT_AUTO_SCALE_H_
#define SIGNAL_SRC_FFT_AUTO_SCALE_H_
#include <stddef.h>
#include <stdint.h>
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// Auto scales `input` and writes the result to `output`.
// Elements in `input` are left shifted to maximize the amplitude without
// clipping. Returns the number of bits by which the elements were shifted.
// * both `input` and `output` must be of size `size`
int FftAutoScale(const int16_t* input, int size, int16_t* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FFT_AUTO_SCALE_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank.h"
namespace tflite {
namespace tflm_signal {
void FilterbankAccumulateChannels(const FilterbankConfig* config,
const uint32_t* input, uint64_t* output) {
// With a log mel filterbank, the energy at each frequency gets added to
// two adjacent filterbank filters/channels.
// For the first filter bank channel, its energy is first multiplied by
// some weight 'w', then gets accumulated.
// For the subsequent filter bank channel, its energy is first multiplied by
// 1-'w' (called the unweight here), then gets accumulated.
// For this reason, we need to calculate (config->num_channels + 1) outputs,
// where element 0 is only used as scratch storage for the unweights of
// element 1 (channel 0). The caller should discard element 0.
// Writing the code like this doesn't save multiplications, but it lends
// itself better to optimization, because input[freq_start + j] only needs
// to be loaded once.
uint64_t weight_accumulator = 0;
uint64_t unweight_accumulator = 0;
for (int i = 0; i < config->num_channels + 1; i++) {
const int16_t freq_start = config->channel_frequency_starts[i];
const int16_t weight_start = config->channel_weight_starts[i];
for (int j = 0; j < config->channel_widths[i]; ++j) {
weight_accumulator += config->weights[weight_start + j] *
static_cast<uint64_t>(input[freq_start + j]);
unweight_accumulator += config->unweights[weight_start + j] *
static_cast<uint64_t>(input[freq_start + j]);
}
output[i] = weight_accumulator;
weight_accumulator = unweight_accumulator;
unweight_accumulator = 0;
}
}
} // namespace tflm_signal
} // namespace tflite
\ No newline at end of file
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FILTER_BANK_H_
#define SIGNAL_SRC_FILTER_BANK_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
struct FilterbankConfig {
// Number of filterbank channels
int32_t num_channels;
// Each of the following three arrays is of size num_channels + 1
// An extra channel is needed for scratch. See implementation of
// FilterbankAccumulateChannels() for more details
// For each channel, the index in the input (spectrum) where its band starts
const int16_t* channel_frequency_starts;
// For each channel, the index in the weights/unweights arrays where
// its filter weights start
const int16_t* channel_weight_starts;
// For each channel, the number of bins in the input (spectrum) that span
// its band
const int16_t* channel_widths;
// The weights array holds the triangular filter weights of all the filters
// in the bank. The output of each filter in the bank is calculated by
// multiplying the elements in the input spectrum that are in its band
// (see above: channel_frequency_starts, channel_widths) by the filter weights
// then accumulating. Each element in the unweights array holds 1 minus the
// corresponding element in the weights array and is used to make this
// operation more efficient. For more details, see the documentation in
// FilterbankAccumulateChannels()
const int16_t* weights;
const int16_t* unweights;
int32_t output_scale;
int32_t input_correction_bits;
};
// Accumulate the energy spectrum bins in `input` into filter bank channels
// contained in `output`.
// * `input` - Spectral energy array
// * `output` - of size `config.num_channels` + 1.
// Elements [1:num_channels] contain the filter bank channels.
// Element 0 is used as scratch and should be ignored
void FilterbankAccumulateChannels(const FilterbankConfig* config,
const uint32_t* input, uint64_t* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FILTER_BANK_H_
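A minimal usage sketch of the API documented above. The weights are illustrative placeholders (a flat 0.5 in a Q12-style scale), not real mel filter coefficients; note that FilterbankAccumulateChannels() only reads the fields set here, and that element 0 of the output is scratch to be discarded:
```
#include <cstdint>

#include "signal/src/filter_bank.h"

int main() {
  // Two channels over an 8-bin spectrum. The config arrays have
  // num_channels + 1 entries; entry 0 feeds the scratch output element.
  const int16_t starts[3] = {0, 2, 4};         // channel_frequency_starts
  const int16_t weight_starts[3] = {0, 2, 4};  // channel_weight_starts
  const int16_t widths[3] = {2, 2, 2};         // channel_widths
  const int16_t weights[6] = {2048, 2048, 2048, 2048, 2048, 2048};
  const int16_t unweights[6] = {2048, 2048, 2048, 2048, 2048, 2048};

  tflite::tflm_signal::FilterbankConfig config = {};
  config.num_channels = 2;
  config.channel_frequency_starts = starts;
  config.channel_weight_starts = weight_starts;
  config.channel_widths = widths;
  config.weights = weights;
  config.unweights = unweights;

  const uint32_t spectrum[8] = {10, 20, 30, 40, 50, 60, 70, 80};
  uint64_t channels[3];  // num_channels + 1 outputs
  tflite::tflm_signal::FilterbankAccumulateChannels(&config, spectrum,
                                                    channels);
  // channels[1] and channels[2] hold the two accumulated energies;
  // channels[0] is scratch and should be discarded.
  return 0;
}
```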
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_log.h"
#include "signal/src/log.h"
namespace tflite {
namespace tflm_signal {
void FilterbankLog(const uint32_t* input, int num_channels,
int32_t output_scale, uint32_t correction_bits,
int16_t* output) {
for (int i = 0; i < num_channels; ++i) {
const uint32_t scaled = input[i] << correction_bits;
if (scaled > 1) {
const uint32_t log_value = Log32(scaled, output_scale);
output[i] = ((log_value < static_cast<uint32_t>(INT16_MAX))
? log_value
: static_cast<uint32_t>(INT16_MAX));
} else {
output[i] = 0;
}
}
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FILTER_BANK_LOG_H_
#define SIGNAL_SRC_FILTER_BANK_LOG_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// Apply natural log to each element in array `input` of size `num_channels`
// with pre-shift and post scaling.
// The operation is roughly equivalent to:
// `output` = min(Log(`input` << `correction_bits`) * `output_scale`, INT16_MAX)
// Where:
// if (`input` << `correction_bits`) is 1 or 0, the corresponding output
// element is set to 0.
void FilterbankLog(const uint32_t* input, int num_channels,
int32_t output_scale, uint32_t correction_bits,
int16_t* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FILTER_BANK_LOG_H_
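A small sketch of the formula above; `output_scale` and `correction_bits` are illustrative choices, and the expected value is only approximate (the log is lookup-table based):
```
#include <cstdint>
#include <cstdio>

#include "signal/src/filter_bank_log.h"

int main() {
  const uint32_t input[3] = {0, 1, 50000};
  int16_t output[3] = {};
  tflite::tflm_signal::FilterbankLog(input, 3, /*output_scale=*/64,
                                     /*correction_bits=*/0, output);
  // Shifted values of 0 or 1 map to 0; the last element is roughly
  // ln(50000) * 64 (about 692), clamped to INT16_MAX if it ever exceeds it.
  std::printf("%d %d %d\n", output[0], output[1], output[2]);
  return 0;
}
```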
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_spectral_subtraction.h"
namespace tflite {
namespace tflm_signal {
void FilterbankSpectralSubtraction(const SpectralSubtractionConfig* config,
const uint32_t* input, uint32_t* output,
uint32_t* noise_estimate) {
const bool data_clamping = config->clamping;
const int smoothing_bits = config->smoothing_bits;
const int num_channels = config->num_channels;
for (int i = 0; i < num_channels; ++i) {
uint32_t smoothing;
uint32_t one_minus_smoothing;
if ((i & 1) == 0) {
smoothing = config->smoothing;
one_minus_smoothing = config->one_minus_smoothing;
} else { // Use alternate smoothing coefficient on odd-index channels.
smoothing = config->alternate_smoothing;
one_minus_smoothing = config->alternate_one_minus_smoothing;
}
// Scale up signal[i] for smoothing filter computation.
const uint32_t signal_scaled_up = input[i] << smoothing_bits;
noise_estimate[i] =
((static_cast<uint64_t>(signal_scaled_up) * smoothing) +
(static_cast<uint64_t>(noise_estimate[i]) * one_minus_smoothing)) >>
config->spectral_subtraction_bits;
uint32_t estimate_scaled_up = noise_estimate[i];
// Make sure that we can't get a negative value for the signal - estimate.
if (estimate_scaled_up > signal_scaled_up) {
estimate_scaled_up = signal_scaled_up;
if (data_clamping) {
noise_estimate[i] = estimate_scaled_up;
}
}
const uint32_t floor =
(static_cast<uint64_t>(input[i]) * config->min_signal_remaining) >>
config->spectral_subtraction_bits;
const uint32_t subtracted =
(signal_scaled_up - estimate_scaled_up) >> smoothing_bits;
output[i] = subtracted > floor ? subtracted : floor;
}
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FILTER_BANK_SPECTRAL_SUBTRACTION_H_
#define SIGNAL_SRC_FILTER_BANK_SPECTRAL_SUBTRACTION_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
struct SpectralSubtractionConfig {
// Number of filterbank channels in input and output
int32_t num_channels;
// The constant used for the lowpass filter for finding the noise.
// Higher values correspond to more aggressively adapting estimates
// of the noise.
// Scale is 1 << spectral_subtraction_bits
uint32_t smoothing;
// One minus smoothing constant for low pass filter.
// Scale is 1 << spectral_subtraction_bits
uint32_t one_minus_smoothing;
// The maximum cap to subtract away from the signal (ie, if this is
// 0.2, then the result of spectral subtraction will not go below
// 0.2 * signal).
// Scale is 1 << spectral_subtraction_bits
uint32_t min_signal_remaining;
// If positive, specifies the filter coefficient for odd-index
// channels, while 'smoothing' is used as the coefficient for even-
// index channels. Otherwise, the same filter coefficient is
// used on all channels.
// Scale is 1 << spectral_subtraction_bits
uint32_t alternate_smoothing;
// One minus the alternate smoothing constant for the low pass filter.
// Scale is 1 << spectral_subtraction_bits
uint32_t alternate_one_minus_smoothing;
// Extra fractional bits for the noise_estimate smoothing filter.
uint32_t smoothing_bits;
// Scaling bits for some members of this struct
uint32_t spectral_subtraction_bits;
// If true, when the filterbank level drops below the output,
// the noise estimate will be forced down to the new noise level.
// If false, the noise estimate will remain above the current
// filterbank output (but the subtraction will still keep the
// output non-negative).
bool clamping;
};
// Apply spectral subtraction to each element in `input`, then write the result
// to `output` and `noise_estimate`. `input`, `output` and `noise_estimate`
// must all be of size `config.num_channels`. `config` holds the
// parameters of the spectral subtraction algorithm.
void FilterbankSpectralSubtraction(const SpectralSubtractionConfig* config,
const uint32_t* input, uint32_t* output,
uint32_t* noise_estimate);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FILTER_BANK_SPECTRAL_SUBTRACTION_H_
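A sketch of how the config might be populated, mirroring the Q-format notes above. Every numeric value here is illustrative, not a tuned front-end setting; the smoothing pairs are chosen so that each pair sums to 1 << spectral_subtraction_bits:
```
#include <cstdint>

#include "signal/src/filter_bank_spectral_subtraction.h"

int main() {
  constexpr int kNumChannels = 4;
  tflite::tflm_signal::SpectralSubtractionConfig config = {};
  config.num_channels = kNumChannels;
  config.spectral_subtraction_bits = 14;
  config.smoothing = 655;                // ~0.04 in Q14
  config.one_minus_smoothing = 15729;    // (1 << 14) - 655
  config.alternate_smoothing = 655;      // same coefficient on odd channels
  config.alternate_one_minus_smoothing = 15729;
  config.min_signal_remaining = 819;     // ~0.05 in Q14
  config.smoothing_bits = 10;
  config.clamping = false;

  const uint32_t input[kNumChannels] = {100, 200, 300, 400};
  uint32_t output[kNumChannels] = {};
  uint32_t noise_estimate[kNumChannels] = {};  // carried across frames
  tflite::tflm_signal::FilterbankSpectralSubtraction(&config, input, output,
                                                     noise_estimate);
  return 0;
}
```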
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/filter_bank_square_root.h"
#include "signal/src/square_root.h"
namespace tflite {
namespace tflm_signal {
void FilterbankSqrt(const uint64_t* input, int num_channels,
int scale_down_bits, uint32_t* output) {
for (int i = 0; i < num_channels; ++i) {
output[i] = Sqrt64(input[i]) >> scale_down_bits;
}
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_FILTER_BANK_SQUARE_ROOT_H_
#define SIGNAL_SRC_FILTER_BANK_SQUARE_ROOT_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// Apply square root to each element in `input`, then shift right by
// `scale_down_bits` before writing the result to `output`.
// `input` and `output` must both be of size `num_channels`.
void FilterbankSqrt(const uint64_t* input, int num_channels,
int scale_down_bits, uint32_t* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_FILTER_BANK_SQUARE_ROOT_H_
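A quick numeric check of the contract above: sqrt(10^6) = 1000, and shifting right by 2 gives 250:
```
#include <cassert>
#include <cstdint>

#include "signal/src/filter_bank_square_root.h"

int main() {
  const uint64_t input[2] = {1000000, 1000000};
  uint32_t output[2] = {};
  // Sqrt64(1000000) == 1000, then >> 2 gives 250.
  tflite::tflm_signal::FilterbankSqrt(input, 2, /*scale_down_bits=*/2, output);
  assert(output[0] == 250 && output[1] == 250);
  return 0;
}
```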
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_IRFFT_H_
#define SIGNAL_SRC_IRFFT_H_
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// IRFFT (Inverse Real Fast Fourier Transform)
// IFFT for real valued time domain outputs.
// 16-bit Integer input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftInt16GetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftInt16GetNeededMemory(fft_length). Fails if it isn't.
void* IrfftInt16Init(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see IrfftInt16Init)
// * `output` must be of size `fft_length` elements
void IrfftInt16Apply(void* state, const Complex<int16_t>* input,
int16_t* output);
// 32-bit Integer input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftInt32GetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftInt32GetNeededMemory(fft_length). Fails if it isn't.
void* IrfftInt32Init(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see IrfftInt32Init)
// * `output` must be of size `fft_length` elements
void IrfftInt32Apply(void* state, const Complex<int32_t>* input,
int32_t* output);
// Floating point input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftFloatGetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftFloatGetNeededMemory(fft_length). Fails if it isn't.
void* IrfftFloatInit(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must be of size `fft_length` elements (see IrfftFloatInit)
// * `output` must be of size `fft_length` elements
void IrfftFloatApply(void* state, const Complex<float>* input, float* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_IRFFT_H_
\ No newline at end of file
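Putting the three-step API above together, a minimal float IRFFT sketch. It assumes `Complex<float>` from signal/src/complex.h is the plain real/imag aggregate used throughout these files, and it mirrors the micro kernel wrapper earlier in this diff: the same state buffer passed to Init() is what gets handed to Apply(), and the Init() return value is only checked for null:
```
#include <cstdint>
#include <vector>

#include "signal/src/complex.h"
#include "signal/src/irfft.h"

int main() {
  constexpr int32_t kFftLength = 512;
  namespace ts = tflite::tflm_signal;
  // 1. Query the scratch size and allocate it.
  const size_t state_size = ts::IrfftFloatGetNeededMemory(kFftLength);
  std::vector<int8_t> state(state_size);
  // 2. Initialize once; a null return means state_size was too small.
  if (ts::IrfftFloatInit(kFftLength, state.data(), state_size) == nullptr) {
    return 1;
  }
  // 3. Run per frame, passing the same state buffer (buffers sized per the
  //    header comments above).
  std::vector<tflite::Complex<float>> input(kFftLength);
  std::vector<float> output(kFftLength);
  ts::IrfftFloatApply(state.data(), input.data(), output.data());
  return 0;
}
```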
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_float.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
struct IrfftFloatState {
int32_t fft_length;
kiss_fft_float::kiss_fftr_cfg cfg;
};
size_t IrfftFloatGetNeededMemory(int32_t fft_length) {
size_t cfg_size = 0;
kiss_fft_float::kiss_fftr_alloc(fft_length, 1, nullptr, &cfg_size);
return sizeof(IrfftFloatState) + cfg_size;
}
void* IrfftFloatInit(int32_t fft_length, void* state, size_t state_size) {
IrfftFloatState* irfft_float_state = static_cast<IrfftFloatState*>(state);
irfft_float_state->cfg =
reinterpret_cast<kiss_fft_float::kiss_fftr_cfg>(irfft_float_state + 1);
irfft_float_state->fft_length = fft_length;
size_t cfg_size = state_size - sizeof(IrfftFloatState);
return kiss_fft_float::kiss_fftr_alloc(fft_length, 1, irfft_float_state->cfg,
&cfg_size);
}
void IrfftFloatApply(void* state, const Complex<float>* input, float* output) {
IrfftFloatState* irfft_float_state = static_cast<IrfftFloatState*>(state);
kiss_fft_float::kiss_fftri(
static_cast<kiss_fft_float::kiss_fftr_cfg>(irfft_float_state->cfg),
reinterpret_cast<const kiss_fft_float::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
// KissFFT scales the IRFFT output by the FFT length.
// KissFFT's nfft is the complex FFT length, which is half the real FFT's
// length. Compensate.
const int fft_length = irfft_float_state->fft_length;
for (int i = 0; i < fft_length; i++) {
output[i] /= fft_length;
}
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int16.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
size_t IrfftInt16GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 1, nullptr, &state_size);
return state_size;
}
void* IrfftInt16Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 1, state, &state_size);
}
void IrfftInt16Apply(void* state, const Complex<int16_t>* input,
int16_t* output) {
kiss_fft_fixed16::kiss_fftri(
static_cast<kiss_fft_fixed16::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_fixed16::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int32.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
size_t IrfftInt32GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 1, nullptr, &state_size);
return state_size;
}
void* IrfftInt32Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 1, state, &state_size);
}
void IrfftInt32Apply(void* state, const Complex<int32_t>* input,
int32_t* output) {
kiss_fft_fixed32::kiss_fftri(
static_cast<kiss_fft_fixed32::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_fixed32::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/log.h"
#include "signal/src/msb.h"
namespace tflite {
namespace tflm_signal {
namespace {
const uint16_t kLogLut[] = {
0, 224, 442, 654, 861, 1063, 1259, 1450, 1636, 1817, 1992, 2163,
2329, 2490, 2646, 2797, 2944, 3087, 3224, 3358, 3487, 3611, 3732, 3848,
3960, 4068, 4172, 4272, 4368, 4460, 4549, 4633, 4714, 4791, 4864, 4934,
5001, 5063, 5123, 5178, 5231, 5280, 5326, 5368, 5408, 5444, 5477, 5507,
5533, 5557, 5578, 5595, 5610, 5622, 5631, 5637, 5640, 5641, 5638, 5633,
5626, 5615, 5602, 5586, 5568, 5547, 5524, 5498, 5470, 5439, 5406, 5370,
5332, 5291, 5249, 5203, 5156, 5106, 5054, 5000, 4944, 4885, 4825, 4762,
4697, 4630, 4561, 4490, 4416, 4341, 4264, 4184, 4103, 4020, 3935, 3848,
3759, 3668, 3575, 3481, 3384, 3286, 3186, 3084, 2981, 2875, 2768, 2659,
2549, 2437, 2323, 2207, 2090, 1971, 1851, 1729, 1605, 1480, 1353, 1224,
1094, 963, 830, 695, 559, 421, 282, 142, 0, 0};
// Number of segments in the log lookup table. The table will be kLogSegments+1
// in length (with some padding).
// constexpr int kLogSegments = 128;
constexpr int kLogSegmentsLog2 = 7;
// Scale used by lookup table.
constexpr int kLogScale = 65536;
constexpr int kLogScaleLog2 = 16;
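// kLogCoeff is ln(2) in the same Q16 scale: round(0.6931472 * 65536) = 45426,
// so (kLogCoeff * log2(x)) >> kLogScaleLog2 converts the base-2 log computed
// below into a natural log.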
constexpr int kLogCoeff = 45426;
uint32_t Log2FractionPart32(uint32_t x, uint32_t log2x) {
// Part 1
int32_t frac = x - (1LL << log2x);
if (log2x < kLogScaleLog2) {
frac <<= kLogScaleLog2 - log2x;
} else {
frac >>= log2x - kLogScaleLog2;
}
// Part 2
const uint32_t base_seg = frac >> (kLogScaleLog2 - kLogSegmentsLog2);
const uint32_t seg_unit = (UINT32_C(1) << kLogScaleLog2) >> kLogSegmentsLog2;
// ASSERT(base_seg < kLogSegments);
const int32_t c0 = kLogLut[base_seg];
const int32_t c1 = kLogLut[base_seg + 1];
const int32_t seg_base = seg_unit * base_seg;
const int32_t rel_pos = ((c1 - c0) * (frac - seg_base)) >> kLogScaleLog2;
return frac + c0 + rel_pos;
}
} // namespace
// Calculate integer logarithm, 32 Bit version
uint32_t Log32(uint32_t x, uint32_t out_scale) {
// ASSERT(x != 0);
const uint32_t integer = MostSignificantBit32(x) - 1;
const uint32_t fraction = Log2FractionPart32(x, integer);
const uint32_t log2 = (integer << kLogScaleLog2) + fraction;
const uint32_t round = kLogScale / 2;
const uint32_t loge =
((static_cast<uint64_t>(kLogCoeff)) * log2 + round) >> kLogScaleLog2;
// Finally scale to our output scale
const uint32_t loge_scaled = (out_scale * loge + round) >> kLogScaleLog2;
return loge_scaled;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_LOG_H_
#define SIGNAL_SRC_LOG_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// Natural logarithm of an integer. The result is multiplied by out_scale
uint32_t Log32(uint32_t x, uint32_t out_scale);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_LOG_H_
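A minimal check of the contract above: with `out_scale = 1 << 16` the result is ln(x) in Q16, up to lookup-table interpolation error (the reference value is just standard math, not taken from the library's tests):
```
#include <cmath>
#include <cstdint>
#include <cstdio>

#include "signal/src/log.h"

int main() {
  const uint32_t x = 1000;
  const uint32_t out_scale = 1u << 16;
  const uint32_t fixed = tflite::tflm_signal::Log32(x, out_scale);
  const double reference = std::log(static_cast<double>(x)) * out_scale;
  // fixed should be close to reference (about 452707 for x == 1000),
  // up to lookup-table interpolation error.
  std::printf("Log32=%u ln(x)*scale=%.0f\n", fixed, reference);
  return 0;
}
```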
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/max_abs.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
int16_t MaxAbs16(const int16_t* input, int size) {
int16_t max = 0;
for (int i = 0; i < size; i++) {
const int16_t value = input[i];
if (value > max) {
max = value;
} else if (-value > max) {
max = -value;
}
}
return max;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_MAX_ABS_H_
#define SIGNAL_SRC_MAX_ABS_H_
#include <stdint.h>
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// Returns the maximum absolute value of the `size` elements in `input`
int16_t MaxAbs16(const int16_t* input, int size);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_MAX_ABS_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_MSB_H_
#define SIGNAL_SRC_MSB_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// Index of the most significant bit
uint32_t MostSignificantBit32(uint32_t x);
uint32_t MostSignificantBit64(uint64_t x);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_MSB_H_
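The "index" returned above is 1-based, which a couple of asserts make explicit (values follow directly from the `32 - __builtin_clz(x)` / `64 - __builtin_clzll(x)` implementations below):
```
#include <cassert>
#include <cstdint>

#include "signal/src/msb.h"

int main() {
  // Bit 0 set reports 1, bit 31 set reports 32, bit 40 set reports 41.
  assert(tflite::tflm_signal::MostSignificantBit32(1u) == 1u);
  assert(tflite::tflm_signal::MostSignificantBit32(0x80000000u) == 32u);
  assert(tflite::tflm_signal::MostSignificantBit64(UINT64_C(1) << 40) == 41u);
  return 0;
}
```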
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/msb.h"
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// TODO(b/291167350): can allow __builtin_clz to be used in more cases here
uint32_t MostSignificantBit32(uint32_t x) {
#if defined(__GNUC__)
if (x) {
return 32 - __builtin_clz(x);
}
return 32;
#else
uint32_t temp = 0;
while (x) {
x = x >> 1;
++temp;
}
return temp;
#endif
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/msb.h"
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
uint32_t MostSignificantBit64(uint64_t x) {
#if defined(__GNUC__)
if (x) {
return 64 - __builtin_clzll(x);
}
return 64;
#else
uint32_t temp = 0;
while (x) {
x = x >> 1;
++temp;
}
return temp;
#endif
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_SQUARE_ROOT_H_
#define SIGNAL_SRC_SQUARE_ROOT_H_
#include <stdint.h>
namespace tflite {
namespace tflm_signal {
// TODO(b/286250473): remove namespace once de-duped libraries above
// Square root
uint16_t Sqrt32(uint32_t num);
uint32_t Sqrt64(uint64_t num);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_SQUARE_ROOT_H_
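A quick sanity check of the integer square roots declared above (exact for perfect squares; non-squares are rounded per the implementations that follow):
```
#include <cassert>
#include <cstdint>

#include "signal/src/square_root.h"

int main() {
  assert(tflite::tflm_signal::Sqrt32(1000000u) == 1000u);
  assert(tflite::tflm_signal::Sqrt64(UINT64_C(1000000000000)) == 1000000u);
  return 0;
}
```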
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/msb.h"
#include "signal/src/square_root.h"
namespace tflite {
namespace tflm_signal {
uint16_t Sqrt32(uint32_t num) {
if (num == 0) {
return 0;
};
uint32_t res = 0;
int max_bit_number = 32 - MostSignificantBit32(num);
max_bit_number |= 1;
uint32_t bit = 1u << (31 - max_bit_number);
int iterations = (31 - max_bit_number) / 2 + 1;
while (iterations--) {
if (num >= res + bit) {
num -= res + bit;
res = (res >> 1U) + bit;
} else {
res >>= 1U;
}
bit >>= 2U;
}
// Do rounding - if we have the bits.
if (num > res && res != 0xFFFF) ++res;
return res;
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/msb.h"
#include "signal/src/square_root.h"
namespace tflite {
namespace tflm_signal {
uint32_t Sqrt64(uint64_t num) {
// Take a shortcut and just use 32 bit operations if the upper word is all
// clear. This will cause a slight off by one issue for numbers close to 2^32,
// but it probably isn't going to matter (and gives us a big performance win).
if ((num >> 32) == 0) {
return Sqrt32(static_cast<uint32_t>(num));
}
uint64_t res = 0;
int max_bit_number = 64 - MostSignificantBit64(num);
max_bit_number |= 1;
uint64_t bit = UINT64_C(1) << (63 - max_bit_number);
int iterations = (63 - max_bit_number) / 2 + 1;
while (iterations--) {
if (num >= res + bit) {
num -= res + bit;
res = (res >> 1U) + bit;
} else {
res >>= 1U;
}
bit >>= 2U;
}
// Do rounding - if we have the bits.
if (num > res && res != 0xFFFFFFFFLL) ++res;
return res;
}
} // namespace tflm_signal
} // namespace tflite
@@ -5,15 +5,54 @@ package(
licenses = ["notice"],
)
tflm_signal_kernel_library(
name = "delay_kernel",
srcs = ["delay_kernel.cc"],
deps = [
"//signal/src:circular_buffer",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "energy_kernel",
srcs = ["energy_kernel.cc"],
deps = [
"//signal/src:complex",
"//signal/src:energy",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "fft_kernel",
srcs = ["fft_kernels.cc"],
deps = [
"//signal/src:fft_auto_scale",
"//signal/src:irfft",
"//signal/src:rfft",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "framer_kernel",
srcs = ["framer_kernel.cc"],
deps = [
"//signal/src:circular_buffer",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "overlap_add_kernel",
srcs = ["overlap_add_kernel.cc"],
deps = [
"//signal/src:overlap_add",
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "window_kernel",
srcs = ["window_kernel.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include "signal/src/circular_buffer.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace signal {
class DelayOp : public tensorflow::OpKernel {
public:
explicit DelayOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("delay_length", &delay_length_));
initialized_ = false;
}
~DelayOp() {}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
if (!initialized_) {
frame_size_ = input_tensor.flat_inner_dims<int16_t>().dimensions().at(1);
outer_dims_ = input_tensor.flat_inner_dims<int16_t>().dimensions().at(0);
state_tensors_.resize(outer_dims_);
circular_buffers_.resize(outer_dims_);
// Calculate the capacity of the circular buffer.
size_t capacity = frame_size_ + delay_length_;
size_t state_size =
tflite::tflm_signal::CircularBufferGetNeededMemory(capacity);
for (int i = 0; i < outer_dims_; i++) {
OP_REQUIRES_OK(
context,
context->allocate_temp(
DT_INT8, TensorShape({static_cast<int32_t>(state_size)}),
&state_tensors_[i]));
int8_t* state_ = state_tensors_[i].flat<int8_t>().data();
circular_buffers_[i] = tflite::tflm_signal::CircularBufferInit(
capacity, state_, state_size);
tflite::tflm_signal::CircularBufferWriteZeros(circular_buffers_[i],
delay_length_);
}
initialized_ = true;
}
TensorShape output_shape = input_tensor.shape();
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
for (int dim_index = 0, sample_index = 0; dim_index < outer_dims_;
dim_index++, sample_index += frame_size_) {
tflite::tflm_signal::CircularBufferWrite(
circular_buffers_[dim_index],
&input_tensor.flat<int16_t>().data()[sample_index], frame_size_);
tflite::tflm_signal::CircularBufferGet(
circular_buffers_[dim_index], frame_size_,
&(reinterpret_cast<int16_t*>(output_tensor->data()))[sample_index]);
tflite::tflm_signal::CircularBufferDiscard(circular_buffers_[dim_index],
frame_size_);
}
}
private:
bool initialized_;
int frame_size_;
int delay_length_;
int outer_dims_;
std::vector<Tensor> state_tensors_;
std::vector<struct tflite::tflm_signal::CircularBuffer*> circular_buffers_;
};
// TODO(b/286250473): change back name after name clash resolved
REGISTER_KERNEL_BUILDER(Name("SignalDelay").Device(tensorflow::DEVICE_CPU),
DelayOp);
} // namespace signal
} // namespace tensorflow
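For reference, the frame-delay mechanism used by `DelayOp` can be sketched outside of a TensorFlow kernel with the same circular-buffer calls in the same order (write, get, discard). Only the function names and call sequence are taken from the kernel above; the frame size, delay length, and sample values are illustrative.
```
#include <cstdint>
#include <vector>

#include "signal/src/circular_buffer.h"

// Sketch: delay a stream of int16 frames by kDelayLength samples, mirroring
// the per-frame sequence used by DelayOp.
void DelayFramesSketch() {
  constexpr int kFrameSize = 4;    // illustrative
  constexpr int kDelayLength = 6;  // illustrative
  const size_t capacity = kFrameSize + kDelayLength;

  size_t state_size =
      tflite::tflm_signal::CircularBufferGetNeededMemory(capacity);
  std::vector<int8_t> state(state_size);
  tflite::tflm_signal::CircularBuffer* buffer =
      tflite::tflm_signal::CircularBufferInit(capacity, state.data(),
                                              state_size);
  // Pre-fill with zeros so the first kDelayLength output samples are silence.
  tflite::tflm_signal::CircularBufferWriteZeros(buffer, kDelayLength);

  int16_t input[kFrameSize] = {1, 2, 3, 4};  // illustrative frame
  int16_t output[kFrameSize];
  // One step of the streaming loop: push a frame, pull the delayed frame,
  // then drop the samples that have been consumed.
  tflite::tflm_signal::CircularBufferWrite(buffer, input, kFrameSize);
  tflite::tflm_signal::CircularBufferGet(buffer, kFrameSize, output);
  tflite::tflm_signal::CircularBufferDiscard(buffer, kFrameSize);
  (void)output;  // delayed samples (all zeros on the very first frame)
}
```
Pre-filling the buffer with `delay_length` zeros is what shifts the stream: the first samples read back are those zeros, so every input sample emerges `delay_length` positions later.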
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/complex.h"
#include "signal/src/energy.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace signal {
class EnergyOp : public tensorflow::OpKernel {
public:
explicit EnergyOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("start_index", &start_index_));
OP_REQUIRES_OK(context, context->GetAttr("end_index", &end_index_));
}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const int16_t* input = input_tensor.flat<int16_t>().data();
tensorflow::Tensor* output_tensor = nullptr;
// The input is complex. The output is real.
int output_size = input_tensor.flat<int16>().size() >> 1;
OP_REQUIRES_OK(context,
context->allocate_output(0, {output_size}, &output_tensor));
uint32* output = output_tensor->flat<uint32>().data();
tflite::tflm_signal::SpectrumToEnergy(
reinterpret_cast<const Complex<int16_t>*>(input), start_index_,
end_index_, output);
}
private:
int start_index_;
int end_index_;
};
// TODO(b/286250473): change back name after name clash resolved
REGISTER_KERNEL_BUILDER(Name("SignalEnergy").Device(tensorflow::DEVICE_CPU),
EnergyOp);
} // namespace signal
} // namespace tensorflow
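`SpectrumToEnergy` itself is not part of this diff, but the kernel's comment ("The input is complex. The output is real.") implies the usual squared-magnitude computation per bin. The sketch below is a reference illustration of that math, not the library's implementation; whether `end_index` is inclusive is also not visible here.
```
#include <cstdint>

// Reference sketch of the energy computation over [start_index, end_index):
// energy = re^2 + im^2 for each complex int16 bin (real/imag interleaved).
// The actual implementation lives behind signal/src/energy.h.
void SpectrumToEnergySketch(const int16_t* complex_input, int start_index,
                            int end_index, uint32_t* energy_out) {
  for (int i = start_index; i < end_index; ++i) {
    const int32_t re = complex_input[2 * i];
    const int32_t im = complex_input[2 * i + 1];
    energy_out[i] = static_cast<uint32_t>(re * re + im * im);
  }
}
```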
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include "signal/src/irfft.h"
#include "signal/src/rfft.h"
#include "tensorflow/core/framework/op_kernel.h"
......@@ -81,7 +83,82 @@ class RfftOp : public tensorflow::OpKernel {
Tensor state_tensor_;
};
// get_needed_memory_func(), init_func(), apply_func()
// are type specific implementations of the IRFFT functions.
// See irfft.h included above for documentation
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t),
void (*apply_func)(void*, const Complex<T>* input, T*)>
class IrfftOp : public tensorflow::OpKernel {
public:
explicit IrfftOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("fft_length", &fft_length_));
// Subband array size is the number of subbands * 2 because each coefficient
// is complex.
subband_array_size_ = ((fft_length_ / 2) + 1) * 2;
size_t state_size = (*get_needed_memory_func)(fft_length_);
OP_REQUIRES_OK(context, context->allocate_temp(
DT_INT8, TensorShape({(int32_t)state_size}),
&state_handle_));
state_ = state_handle_.flat<int8_t>().data();
(*init_func)(fft_length_, state_, state_size);
}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const T* input = input_tensor.flat<T>().data();
TensorShape output_shape = input_tensor.shape();
output_shape.set_dim(output_shape.dims() - 1, fft_length_);
// Create an output tensor
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
T* output = output_tensor->flat<T>().data();
int outer_dims = input_tensor.flat_inner_dims<T, 2>().dimensions().at(0);
for (int i = 0; i < outer_dims; i++) {
(*apply_func)(
state_,
reinterpret_cast<const Complex<T>*>(&input[i * subband_array_size_]),
&output[i * fft_length_]);
}
}
private:
int fft_length_;
int subband_array_size_;
int8_t* state_;
Tensor state_handle_;
};
class FftAutoScaleOp : public tensorflow::OpKernel {
public:
explicit FftAutoScaleOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const int16_t* input = input_tensor.flat<int16_t>().data();
// Create an output tensor
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
int16_t* output = output_tensor->flat<int16_t>().data();
tensorflow::Tensor* scale_bit_tensor = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(1, {}, &scale_bit_tensor));
scale_bit_tensor->scalar<int32_t>()() = tflite::tflm_signal::FftAutoScale(
input, output_tensor->NumElements(), output);
}
};
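`FftAutoScaleOp` forwards directly to the library routine and exposes both the scaled samples and the returned scale value as outputs. A standalone call mirroring the kernel might look like the sketch below; the interpretation of the return value as a scale indicator is an assumption, not something shown in this diff.
```
#include <cstdint>

#include "signal/src/fft_auto_scale.h"

// Sketch mirroring FftAutoScaleOp above: scale a block of int16 samples in
// preparation for a fixed-point FFT and capture the reported scale value.
void FftAutoScaleSketch(const int16_t* samples, int num_samples,
                        int16_t* scaled) {
  // The kernel stores this return value in a scalar int32 output; it is
  // assumed here to describe how much the data was scaled.
  int32_t scale_bit =
      tflite::tflm_signal::FftAutoScale(samples, num_samples, scaled);
  (void)scale_bit;
}
```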
// TODO(b/286250473): change back name after name clash resolved
REGISTER_KERNEL_BUILDER(
Name("SignalFftAutoScale").Device(tensorflow::DEVICE_CPU), FftAutoScaleOp);
REGISTER_KERNEL_BUILDER(
Name("SignalRfft")
.Device(tensorflow::DEVICE_CPU)
......@@ -101,5 +178,27 @@ REGISTER_KERNEL_BUILDER(
RfftOp<int32_t, DT_INT32, ::tflm_signal::RfftInt32GetNeededMemory,
::tflm_signal::RfftInt32Init, ::tflm_signal::RfftInt32Apply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<float>("T"),
IrfftOp<float, tflite::tflm_signal::IrfftFloatGetNeededMemory,
tflite::tflm_signal::IrfftFloatInit,
tflite::tflm_signal::IrfftFloatApply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int16>("T"),
IrfftOp<int16_t, tflite::tflm_signal::IrfftInt16GetNeededMemory,
tflite::tflm_signal::IrfftInt16Init,
tflite::tflm_signal::IrfftInt16Apply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int32>("T"),
IrfftOp<int32_t, tflite::tflm_signal::IrfftInt32GetNeededMemory,
tflite::tflm_signal::IrfftInt32Init,
tflite::tflm_signal::IrfftInt32Apply>);
} // namespace signal
} // namespace tensorflow
\ No newline at end of file
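The `IrfftOp` template above is instantiated with type-specific get-memory/init/apply function pointers; the float registration, for example, wires in `IrfftFloatGetNeededMemory`, `IrfftFloatInit`, and `IrfftFloatApply`. The sketch below walks through that same sequence standalone, with buffer sizes taken from the kernel (`fft_length / 2 + 1` complex bins in, `fft_length` real samples out).
```
#include <cstdint>
#include <vector>

#include "signal/src/complex.h"
#include "signal/src/irfft.h"

// Sketch of the get-memory / init / apply sequence that IrfftOp wires up
// through its template parameters, using the float entry points named in the
// SignalIrfft registration above. As in the kernel, Complex<T> is used
// unqualified and the raw state buffer is handed to the apply call.
void IrfftFloatSketch(int32_t fft_length, const float* spectrum_interleaved,
                      float* time_domain) {
  size_t state_size =
      tflite::tflm_signal::IrfftFloatGetNeededMemory(fft_length);
  std::vector<int8_t> state(state_size);
  tflite::tflm_signal::IrfftFloatInit(fft_length, state.data(), state_size);
  tflite::tflm_signal::IrfftFloatApply(
      state.data(),
      reinterpret_cast<const Complex<float>*>(spectrum_interleaved),
      time_domain);
}
```
As in the kernel, the pointer returned by the init call is ignored and the state buffer itself is reused for every apply call.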
......@@ -5,6 +5,22 @@ package(
licenses = ["notice"],
)
tflm_signal_kernel_library(
name = "delay_op",
srcs = ["delay_op.cc"],
deps = [
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "energy_op",
srcs = ["energy_op.cc"],
deps = [
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "fft_ops",
srcs = ["fft_ops.cc"],
......@@ -13,6 +29,22 @@ tflm_signal_kernel_library(
],
)
tflm_signal_kernel_library(
name = "framer_op",
srcs = ["framer_op.cc"],
deps = [
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "overlap_add_op",
srcs = ["overlap_add_op.cc"],
deps = [
"@tensorflow_cc_deps//:cc_library",
],
)
tflm_signal_kernel_library(
name = "window_op",
srcs = ["window_op.cc"],
......
......@@ -99,8 +99,7 @@ TfLiteStatus CircularBufferEval(TfLiteContext* context, TfLiteNode* node) {
if (--data->cycles_until_run != 0) {
// Signal the interpreter to end current run if the delay before op invoke
// has not been reached.
// TODO(b/149795762): Add kTfLiteAbort to TfLiteStatus enum.
return static_cast<TfLiteStatus>(kTfLiteAbort);
return kTfLiteAbort;
}
data->cycles_until_run = data->cycles_max;
......
......@@ -30,9 +30,6 @@ extern const int kCircularBufferOutputTensor;
// Elements in the vectors are ordered alphabetically by parameter name.
extern const int kCircularBufferCyclesMaxIndex; // 'cycles_max'
// TODO(b/149795762): Add this to TfLiteStatus enum.
extern const TfLiteStatus kTfLiteAbort;
// These fields control the stride period of a strided streaming model. This op
// returns kTfLiteAbort until cycles_until_run-- is zero. At this time,
// cycles_until_run is reset to cycles_max.
......
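The comment above describes the strided-streaming contract: the circular buffer op keeps returning `kTfLiteAbort` until its stride period elapses, then resets `cycles_until_run` to `cycles_max`. A hypothetical caller-side check is sketched below; it assumes the status propagates out of `Invoke()` unchanged and that `kTfLiteAbort` is now declared in `micro_context.h`, neither of which is shown in this diff.
```
#include "tensorflow/lite/micro/micro_context.h"      // assumed home of kTfLiteAbort
#include "tensorflow/lite/micro/micro_interpreter.h"

// Hypothetical handling of the strided-streaming behaviour: treat
// kTfLiteAbort as "no new output this cycle" rather than as an error.
TfLiteStatus InvokeOnce(tflite::MicroInterpreter& interpreter) {
  TfLiteStatus status = interpreter.Invoke();
  if (status == kTfLiteAbort) {
    // The stride period has not elapsed; feed the next frame and try again
    // on a later cycle.
    return kTfLiteOk;
  }
  return status;
}
```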