Unverified commit 55037d2d authored by S suleshahid and committed by GitHub

Adds IRFFT Op to Signal Library (#2137)

Adds the inverse RFFT (IRFFT) as part of the Signal library ops.
Tested via the existing FFT op tests.
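A minimal sketch (not part of this change) of how an application might register the new op through the resolver entry added below; the helper name and resolver capacity are illustrative:

#include "signal/micro/kernels/irfft.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Registers the custom "SignalIrfft" op. The default registration dispatches
// on the 'T' attribute (float/int16/int32); a type-specific variant such as
// Register_IRFFT_INT16() can be passed instead to reduce code size.
TfLiteStatus RegisterSignalIrfft(tflite::MicroMutableOpResolver<1>& resolver) {
  return resolver.AddIrfft();
}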

BUG=[287346710](http://b/287346710)
Parent ed11500a
......@@ -64,6 +64,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddGreaterEqual();
AddHardSwish();
AddIf();
AddIrfft();
AddL2Normalization();
AddL2Pool2D();
AddLeakyRelu();
......
......@@ -80,7 +80,9 @@ def _fft_auto_scale_wrapper(fft_auto_scale_fn, default_name):
rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
irfft = _fft_wrapper(gen_fft_ops.signal_irfft, "signal_irfft")
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
"signal_fft_auto_scale")
tf.no_gradient("signal_rfft")
tf.no_gradient("signal_irfft")
tf.no_gradient("signal_fft_auto_scale")
......@@ -251,6 +251,61 @@ class RfftOpTest(tf.test.TestCase):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(fft_ops.rfft(fft_input, 127))
def testIrfftTest(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random(fft_length).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, fft_length).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[0], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[0], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 6500)
elif dtype == np.int32:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 7875)
else:
self.assertArrayNear(fft_input, ifft_output, 5e-7)
fft_length = 2 * fft_length
def testIrfftLargeOuterDimension(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random([2, 5, fft_length
]).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, [2, 5, fft_length]).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[-1], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[-1], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
elif dtype == np.int32:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
else:
self.assertAllClose(fft_input, ifft_output, rtol=5e-7, atol=5e-7)
fft_length = 2 * fft_length
def testAutoScale(self):
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')
......
......@@ -16,12 +16,14 @@ cc_library(
"filter_bank_spectral_subtraction.cc",
"filter_bank_square_root.cc",
"framer.cc",
"irfft.cc",
"overlap_add.cc",
"rfft.cc",
"stacker.cc",
"window.cc",
],
hdrs = [
"irfft.h",
"rfft.h",
],
copts = micro_copts(),
......@@ -36,6 +38,7 @@ cc_library(
"//signal/src:filter_bank_log",
"//signal/src:filter_bank_spectral_subtraction",
"//signal/src:filter_bank_square_root",
"//signal/src:irfft",
"//signal/src:overlap_add",
"//signal/src:rfft",
"//signal/src:window",
......
......@@ -303,6 +303,164 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Float) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const float input[] = {256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const float golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_float,
g_gen_data_size_fft_length_64_float, output, 1e-7));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int16) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int16_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int16_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int16,
g_gen_data_size_fft_length_64_int16, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int32) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int32_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength64Int32OuterDims4) {
constexpr int kOutputLen = 64;
constexpr int kOuterDim = 2;
int input_shape[] = {3, kOuterDim, kOuterDim, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
const int32_t golden[] = {
256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOuterDim * kOuterDim * kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Float) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, tflite::kIrfftFloatLength512Input,
output_shape, tflite::kIrfftFloatLength512Golden,
*registration, g_gen_data_fft_length_512_float,
g_gen_data_size_fft_length_512_float, output, 1e-7));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Int16) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int16_t>(
input_shape, tflite::kIrfftInt16Length512Input,
output_shape, tflite::kIrfftInt16Length512Golden,
*registration, g_gen_data_fft_length_512_int16,
g_gen_data_size_fft_length_512_int16, output, 0));
}
TF_LITE_MICRO_TEST(IrfftTestLength512Int32) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int32_t>(
input_shape, tflite::kIrfftInt32Length512Input,
output_shape, tflite::kIrfftInt32Length512Golden,
*registration, g_gen_data_fft_length_512_int32,
g_gen_data_size_fft_length_512_int32, output, 0));
}
TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "signal/src/irfft.h"
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/flatbuffer_utils.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
namespace tflite {
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
// Indices into the init flexbuffer's vector.
// The parameter's name is in the comment that follows.
// Elements in the vectors are ordered alphabetically by parameter name.
// 'T' is added implicitly by the TensorFlow framework when the type is resolved
// during graph construction.
// constexpr int kTypeIndex = 0; // 'T' (unused)
constexpr int kFftLengthIndex = 1; // 'fft_length'
struct TfLiteAudioFrontendIrfftParams {
int32_t fft_length;
int32_t input_size;
int32_t input_length;
int32_t output_length;
TfLiteType fft_type;
int8_t* state;
};
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t)>
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
auto* params = static_cast<TfLiteAudioFrontendIrfftParams*>(
context->AllocatePersistentBuffer(
context, sizeof(TfLiteAudioFrontendIrfftParams)));
if (params == nullptr) {
return nullptr;
}
tflite::FlexbufferWrapper fbw(reinterpret_cast<const uint8_t*>(buffer),
length);
params->fft_length = fbw.ElementAsInt32(kFftLengthIndex);
params->fft_type = typeToTfLiteType<T>();
size_t state_size = (*get_needed_memory_func)(params->fft_length);
params->state = reinterpret_cast<int8_t*>(
context->AllocatePersistentBuffer(context, state_size * sizeof(int8_t)));
if (params->state == nullptr) {
return nullptr;
}
(*init_func)(params->fft_length, params->state, state_size);
return params;
}
template <TfLiteType TfLiteTypeEnum>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input =
micro_context->AllocateTempInputTensor(node, kInputTensor);
TF_LITE_ENSURE(context, input != nullptr);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, TfLiteTypeEnum);
TF_LITE_ENSURE_TYPES_EQ(context, output->type, TfLiteTypeEnum);
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
RuntimeShape input_shape = GetTensorShape(input);
RuntimeShape output_shape = GetTensorShape(output);
// Divide by 2 because input is complex.
params->input_length =
input_shape.Dims(input_shape.DimensionsCount() - 1) / 2;
params->input_size = input_shape.FlatSize() / 2;
params->output_length = output_shape.Dims(output_shape.DimensionsCount() - 1);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
template <typename T, void (*apply_func)(void*, const Complex<T>* input, T*)>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
const Complex<T>* input_data =
tflite::micro::GetTensorData<Complex<T>>(input);
T* output_data = tflite::micro::GetTensorData<T>(output);
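// The input holds one (fft_length / 2 + 1)-element complex spectrum per outer
// frame; run the IRFFT independently on each frame.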
for (int input_idx = 0, output_idx = 0; input_idx < params->input_size;
input_idx += params->input_length, output_idx += params->output_length) {
(*apply_func)(params->state, &input_data[input_idx],
&output_data[output_idx]);
}
return kTfLiteOk;
}
void* InitAll(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
auto tensor_type = static_cast<tflite::TensorType>(m["T"].AsInt32());
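// Dispatch to the type-specific Init based on the 'T' attribute that the
// TensorFlow framework stores in the op's flexbuffer during conversion.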
switch (tensor_type) {
case TensorType_INT16: {
return Init<int16_t, tflm_signal::IrfftInt16GetNeededMemory,
tflm_signal::IrfftInt16Init>(context, buffer, length);
}
case TensorType_INT32: {
return Init<int32_t, tflm_signal::IrfftInt32GetNeededMemory,
tflm_signal::IrfftInt32Init>(context, buffer, length);
}
case TensorType_FLOAT32: {
return Init<float, tflm_signal::IrfftFloatGetNeededMemory,
tflm_signal::IrfftFloatInit>(context, buffer, length);
}
default:
return nullptr;
}
}
TfLiteStatus PrepareAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Prepare<kTfLiteInt16>(context, node);
}
case kTfLiteInt32: {
return Prepare<kTfLiteInt32>(context, node);
}
case kTfLiteFloat32: {
return Prepare<kTfLiteFloat32>(context, node);
}
default:
return kTfLiteError;
}
}
TfLiteStatus EvalAll(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteAudioFrontendIrfftParams*>(node->user_data);
switch (params->fft_type) {
case kTfLiteInt16: {
return Eval<int16_t, tflm_signal::IrfftInt16Apply>(context, node);
}
case kTfLiteInt32: {
return Eval<int32_t, tflm_signal::IrfftInt32Apply>(context, node);
}
case kTfLiteFloat32: {
return Eval<float, tflm_signal::IrfftFloatApply>(context, node);
}
default:
return kTfLiteError;
}
}
} // namespace
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflm_signal {
TFLMRegistration* Register_IRFFT() {
static TFLMRegistration r =
tflite::micro::RegisterOp(InitAll, PrepareAll, EvalAll);
return &r;
}
TFLMRegistration* Register_IRFFT_FLOAT() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<float, IrfftFloatGetNeededMemory, IrfftFloatInit>,
Prepare<kTfLiteFloat32>, Eval<float, IrfftFloatApply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT16() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int16_t, IrfftInt16GetNeededMemory, IrfftInt16Init>,
Prepare<kTfLiteInt16>, Eval<int16_t, IrfftInt16Apply>);
return &r;
}
TFLMRegistration* Register_IRFFT_INT32() {
static TFLMRegistration r = tflite::micro::RegisterOp(
Init<int32_t, IrfftInt32GetNeededMemory, IrfftInt32Init>,
Prepare<kTfLiteInt32>, Eval<int32_t, IrfftInt32Apply>);
return &r;
}
} // namespace tflm_signal
} // namespace tflite
\ No newline at end of file
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_MICRO_KERNELS_IRFFT_H_
#define SIGNAL_MICRO_KERNELS_IRFFT_H_
#include "tensorflow/lite/micro/micro_common.h"
namespace tflite {
namespace tflm_signal {
TFLMRegistration* Register_IRFFT();
TFLMRegistration* Register_IRFFT_FLOAT();
TFLMRegistration* Register_IRFFT_INT16();
TFLMRegistration* Register_IRFFT_INT32();
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_MICRO_KERNELS_IRFFT_H_
......@@ -18,6 +18,20 @@ cc_library(
],
)
cc_library(
name = "irfft",
srcs = [
"irfft_float.cc",
"irfft_int16.cc",
"irfft_int32.cc",
],
hdrs = ["irfft.h"],
deps = [
":complex",
"//signal/src/kiss_fft_wrappers",
],
)
cc_library(
name = "max_abs",
srcs = ["max_abs.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef SIGNAL_SRC_IRFFT_H_
#define SIGNAL_SRC_IRFFT_H_
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
// IRFFT (Inverse Real Fast Fourier Transform)
// IFFT for real valued time domain outputs.
// 16-bit Integer input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftInt16GetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftInt16GetNeededMemory(fft_length). Fails if it isn't.
void* IrfftInt16Init(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must hold (`fft_length` / 2) + 1 complex elements (see IrfftInt16Init)
// * `output` must hold `fft_length` elements
void IrfftInt16Apply(void* state, const Complex<int16_t>* input,
int16_t* output);
// 32-bit Integer input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftInt32GetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftInt32GetNeededMemory(fft_length). Fails if it isn't.
void* IrfftInt32Init(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must hold (`fft_length` / 2) + 1 complex elements (see IrfftInt32Init)
// * `output` must hold `fft_length` elements
void IrfftInt32Apply(void* state, const Complex<int32_t>* input,
int32_t* output);
// Floating point input/output
// Returns the size of the memory that an IRFFT of `fft_length` needs
size_t IrfftFloatGetNeededMemory(int32_t fft_length);
// Initialize the state of an IRFFT of `fft_length`
// `state` points to an opaque state of size `state_size`, which
// must be greater than or equal to the value returned by
// IrfftFloatGetNeededMemory(fft_length). Fails if it isn't.
void* IrfftFloatInit(int32_t fft_length, void* state, size_t state_size);
// Applies IRFFT to `input` and writes the result to `output`
// * `input` must hold (`fft_length` / 2) + 1 complex elements (see IrfftFloatInit)
// * `output` must hold `fft_length` elements
void IrfftFloatApply(void* state, const Complex<float>* input, float* output);
} // namespace tflm_signal
} // namespace tflite
#endif // SIGNAL_SRC_IRFFT_H_
\ No newline at end of file
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_float.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
struct IrfftFloatState {
int32_t fft_length;
kiss_fft_float::kiss_fftr_cfg cfg;
};
size_t IrfftFloatGetNeededMemory(int32_t fft_length) {
size_t cfg_size = 0;
kiss_fft_float::kiss_fftr_alloc(fft_length, 1, nullptr, &cfg_size);
return sizeof(IrfftFloatState) + cfg_size;
}
void* IrfftFloatInit(int32_t fft_length, void* state, size_t state_size) {
IrfftFloatState* irfft_float_state = static_cast<IrfftFloatState*>(state);
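// The kiss_fftr config is laid out immediately after the state struct, inside
// the same buffer sized by IrfftFloatGetNeededMemory().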
irfft_float_state->cfg =
reinterpret_cast<kiss_fft_float::kiss_fftr_cfg>(irfft_float_state + 1);
irfft_float_state->fft_length = fft_length;
size_t cfg_size = state_size - sizeof(IrfftFloatState);
return kiss_fft_float::kiss_fftr_alloc(fft_length, 1, irfft_float_state->cfg,
&cfg_size);
}
void IrfftFloatApply(void* state, const Complex<float>* input, float* output) {
IrfftFloatState* irfft_float_state = static_cast<IrfftFloatState*>(state);
kiss_fft_float::kiss_fftri(
static_cast<kiss_fft_float::kiss_fftr_cfg>(irfft_float_state->cfg),
reinterpret_cast<const kiss_fft_float::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
// KissFFT scales the IRFFT output by the FFT length.
// KissFFT's nfft is the complex FFT length, which is half the real FFT's
// length. Compensate.
const int fft_length = irfft_float_state->fft_length;
for (int i = 0; i < fft_length; i++) {
output[i] /= fft_length;
}
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int16.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
size_t IrfftInt16GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 1, nullptr, &state_size);
return state_size;
}
void* IrfftInt16Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed16::kiss_fftr_alloc(fft_length, 1, state, &state_size);
}
void IrfftInt16Apply(void* state, const Complex<int16_t>* input,
int16_t* output) {
kiss_fft_fixed16::kiss_fftri(
static_cast<kiss_fft_fixed16::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_fixed16::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
}
} // namespace tflm_signal
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>
#include "signal/src/complex.h"
#include "signal/src/irfft.h"
#include "signal/src/kiss_fft_wrappers/kiss_fft_int32.h"
// TODO(b/286250473): remove namespace once de-duped libraries
namespace tflite {
namespace tflm_signal {
size_t IrfftInt32GetNeededMemory(int32_t fft_length) {
size_t state_size = 0;
kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 1, nullptr, &state_size);
return state_size;
}
void* IrfftInt32Init(int32_t fft_length, void* state, size_t state_size) {
return kiss_fft_fixed32::kiss_fftr_alloc(fft_length, 1, state, &state_size);
}
void IrfftInt32Apply(void* state, const Complex<int32_t>* input,
int32_t* output) {
kiss_fft_fixed32::kiss_fftri(
static_cast<kiss_fft_fixed32::kiss_fftr_cfg>(state),
reinterpret_cast<const kiss_fft_fixed32::kiss_fft_cpx*>(input),
reinterpret_cast<kiss_fft_scalar*>(output));
}
} // namespace tflm_signal
} // namespace tflite
......@@ -19,6 +19,7 @@ tflm_signal_kernel_library(
srcs = ["fft_kernels.cc"],
deps = [
"//signal/src:fft_auto_scale",
"//signal/src:irfft",
"//signal/src:rfft",
"@tensorflow_cc_deps//:cc_library",
],
......
......@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "signal/src/fft_auto_scale.h"
#include "signal/src/irfft.h"
#include "signal/src/rfft.h"
#include "tensorflow/core/framework/op_kernel.h"
......@@ -82,6 +83,58 @@ class RfftOp : public tensorflow::OpKernel {
Tensor state_tensor_;
};
// get_needed_memory_func(), init_func() and apply_func()
// are type-specific implementations of the IRFFT functions.
// See irfft.h, included above, for documentation.
template <typename T, size_t (*get_needed_memory_func)(int32_t),
void* (*init_func)(int32_t, void*, size_t),
void (*apply_func)(void*, const Complex<T>* input, T*)>
class IrfftOp : public tensorflow::OpKernel {
public:
explicit IrfftOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("fft_length", &fft_length_));
// Subband array size is the number of subbands * 2 because each coefficient
// is complex.
subband_array_size_ = ((fft_length_ / 2) + 1) * 2;
size_t state_size = (*get_needed_memory_func)(fft_length_);
OP_REQUIRES_OK(context, context->allocate_temp(
DT_INT8, TensorShape({(int32_t)state_size}),
&state_handle_));
state_ = state_handle_.flat<int8_t>().data();
(*init_func)(fft_length_, state_, state_size);
}
void Compute(tensorflow::OpKernelContext* context) override {
const tensorflow::Tensor& input_tensor = context->input(0);
const T* input = input_tensor.flat<T>().data();
TensorShape output_shape = input_tensor.shape();
output_shape.set_dim(output_shape.dims() - 1, fft_length_);
// Create an output tensor
tensorflow::Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
T* output = output_tensor->flat<T>().data();
int outer_dims = input_tensor.flat_inner_dims<T, 2>().dimensions().at(0);
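// Each outer frame holds subband_array_size_ interleaved real/imaginary
// values; the IRFFT expands it to fft_length_ time-domain samples.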
for (int i = 0; i < outer_dims; i++) {
(*apply_func)(
state_,
reinterpret_cast<const Complex<T>*>(&input[i * subband_array_size_]),
&output[i * fft_length_]);
}
}
private:
int fft_length_;
int subband_array_size_;
int8_t* state_;
Tensor state_handle_;
};
class FftAutoScaleOp : public tensorflow::OpKernel {
public:
explicit FftAutoScaleOp(tensorflow::OpKernelConstruction* context)
......@@ -125,5 +178,27 @@ REGISTER_KERNEL_BUILDER(
RfftOp<int32_t, DT_INT32, ::tflm_signal::RfftInt32GetNeededMemory,
::tflm_signal::RfftInt32Init, ::tflm_signal::RfftInt32Apply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<float>("T"),
IrfftOp<float, tflite::tflm_signal::IrfftFloatGetNeededMemory,
tflite::tflm_signal::IrfftFloatInit,
tflite::tflm_signal::IrfftFloatApply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int16>("T"),
IrfftOp<int16_t, tflite::tflm_signal::IrfftInt16GetNeededMemory,
tflite::tflm_signal::IrfftInt16Init,
tflite::tflm_signal::IrfftInt16Apply>);
REGISTER_KERNEL_BUILDER(
Name("SignalIrfft")
.Device(tensorflow::DEVICE_CPU)
.TypeConstraint<int32>("T"),
IrfftOp<int32_t, tflite::tflm_signal::IrfftInt32GetNeededMemory,
tflite::tflm_signal::IrfftInt32Init,
tflite::tflm_signal::IrfftInt32Apply>);
} // namespace signal
} // namespace tensorflow
\ No newline at end of file
......@@ -33,6 +33,16 @@ Status RfftShape(InferenceContext* c) {
return OkStatus();
}
Status IrfftShape(InferenceContext* c) {
ShapeHandle out;
int fft_length;
TF_RETURN_IF_ERROR(c->GetAttr<int>("fft_length", &fft_length));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &out));
TF_RETURN_IF_ERROR(c->ReplaceDim(out, -1, c->MakeDim(fft_length), &out));
c->set_output(0, out);
return OkStatus();
}
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalRfft")
.Attr("T: {float, int16, int32}")
......@@ -62,6 +72,31 @@ fft_length: The length of the FFT operation. An input signal that's shorter
will be zero padded to fft_length.
)doc");
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalIrfft")
.Attr("T: {float, int16, int32}")
.Attr("fft_length: int >= 2")
.Input("input: T")
.Output("output: T")
.SetShapeFn(IrfftShape)
.Doc(R"doc(
Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
signal over the inner-most dimension of input.
The inner-most dimension of input is assumed to be the result of RFFT:
the fft_length / 2 + 1 unique components of the DFT of a real-valued signal.
fft_length must be provided.
input: A tensor containing ((fft_length / 2) + 1) complex spectral
components along its innermost dimension.
Since there's no TF integer complex type, the array is represented using
((fft_length / 2) + 1) * 2 real elements.
output: A tensor containing fft_length time domain elements along its innermost
dimension.
fft_length: The length of the IFFT operation.
)doc");
// TODO(b/286250473): change back name after name clash resolved
REGISTER_OP("SignalFftAutoScale")
.Input("input: int16")
......
This diff is collapsed.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -31,6 +31,15 @@ extern const int32_t kRfftInt32Length512Golden[];
extern const float kRfftFloatLength512Input[];
extern const float kRfftFloatLength512Golden[];
extern const int16_t kIrfftInt16Length512Input[];
extern const int16_t kIrfftInt16Length512Golden[];
extern const int32_t kIrfftInt32Length512Input[];
extern const int32_t kIrfftInt32Length512Golden[];
extern const float kIrfftFloatLength512Input[];
extern const float kIrfftFloatLength512Golden[];
extern const int16_t kFftAutoScaleLength512Input[];
extern const int16_t kFftAutoScaleLength512Golden[];
......
......@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_MICRO_OPS_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_MICRO_OPS_H_
#include "signal/micro/kernels/irfft.h"
#include "signal/micro/kernels/rfft.h"
#include "tensorflow/lite/c/common.h"
......
......@@ -343,6 +343,12 @@ class MicroMutableOpResolver : public MicroOpResolver {
return AddBuiltin(BuiltinOperator_IF, tflite::Register_IF(), ParseIf);
}
TfLiteStatus AddIrfft(const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT()) {
// TODO(b/286250473): change back name and remove namespace
return AddCustom("SignalIrfft", registration);
}
TfLiteStatus AddL2Normalization() {
return AddBuiltin(BuiltinOperator_L2_NORMALIZATION,
Register_L2_NORMALIZATION(), ParseL2Normalization);
......
......@@ -320,6 +320,7 @@ $(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank_log.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank_square_root.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/filter_bank_spectral_subtraction.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/framer.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/irfft.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/rfft.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/stacker.cc \
$(TENSORFLOW_ROOT)signal/micro/kernels/overlap_add.cc \
......@@ -331,6 +332,9 @@ $(TENSORFLOW_ROOT)signal/src/filter_bank.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_log.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_square_root.cc \
$(TENSORFLOW_ROOT)signal/src/filter_bank_spectral_subtraction.cc \
$(TENSORFLOW_ROOT)signal/src/irfft_float.cc \
$(TENSORFLOW_ROOT)signal/src/irfft_int16.cc \
$(TENSORFLOW_ROOT)signal/src/irfft_int32.cc \
$(TENSORFLOW_ROOT)signal/src/log.cc \
$(TENSORFLOW_ROOT)signal/src/max_abs.cc \
$(TENSORFLOW_ROOT)signal/src/msb_32.cc \
......