提交 89cbc088 编写于 作者: R Renjie Liu 提交者: TensorFlower Gardener

Add real/imag custom ops. The ops will be migrated to builtin ops soon.

PiperOrigin-RevId: 327742515
Change-Id: I0699f469c98270bf895cb1b8826fcc5a2c6fdd46
上级 c37ef0c1
......@@ -697,16 +697,16 @@ cc_test(
cc_library(
name = "custom_ops",
srcs = ["rfft2d.cc"],
srcs = [
"complex_support.cc",
"rfft2d.cc",
],
hdrs = ["custom_ops_register.h"],
copts = tflite_copts(),
deps = [
":kernel_util",
":op_macros",
"//tensorflow/lite:context",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels/hashtable:hashtable_op_kernels",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:types",
"//third_party/fft2d:fft2d_headers",
......@@ -2288,4 +2288,19 @@ cc_test(
],
)
cc_test(
name = "complex_support_test",
srcs = ["complex_support_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]})
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
// TODO(b/165735381): Promote this op to builtin-op when we can add new builtin
// ops.
namespace tflite {
namespace ops {
namespace custom {
namespace complex {
static const int kInputTensor = 0;
static const int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
input->type == kTfLiteComplex128);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (input->type == kTfLiteComplex64) {
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
} else {
TF_LITE_ENSURE(context, output->type = kTfLiteFloat64);
}
TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
return context->ResizeTensor(context, output, output_shape);
}
template <typename T, typename ExtractF>
void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
TfLiteTensor* output) {
const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
T* output_data = GetTensorData<T>(output);
const int input_size = NumElements(input);
for (int i = 0; i < input_size; ++i) {
*output_data++ = extract_func(*input_data++);
}
}
TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteComplex64: {
ExtractData<float>(
input,
static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
output);
break;
}
case kTfLiteComplex128: {
ExtractData<double>(input,
static_cast<double (*)(const std::complex<double>&)>(
std::real<double>),
output);
break;
}
default: {
TF_LITE_KERNEL_LOG(context,
"Unsupported input type, Real op only supports "
"complex input, but got: ",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteComplex64: {
ExtractData<float>(
input,
static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
output);
break;
}
case kTfLiteComplex128: {
ExtractData<double>(input,
static_cast<double (*)(const std::complex<double>&)>(
std::imag<double>),
output);
break;
}
default: {
TF_LITE_KERNEL_LOG(context,
"Unsupported input type, Imag op only supports "
"complex input, but got: ",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
return kTfLiteOk;
}
} // namespace complex
TfLiteRegistration* Register_REAL() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
complex::Prepare, complex::EvalReal};
return &r;
}
TfLiteRegistration* Register_IMAG() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
complex::Prepare, complex::EvalImag};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <complex>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace ops {
namespace custom {
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
namespace {
template <typename T>
class RealOpModel : public SingleOpModel {
public:
RealOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Real", custom_option, Register_REAL);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
private:
int input_;
int output_;
};
TEST(RealOpTest, SimpleFloatTest) {
RealOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT32, {}});
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 0},
{-6, -1},
{9, 0},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75, -6, 9, -10, -3, -6, 0, 22.1f})));
}
TEST(RealOpTest, SimpleDoubleTest) {
RealOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT64, {}});
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 0},
{-6, -1},
{9, 0},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{75, -6, 9, -10, -3, -6, 0, 22.1f})));
}
template <typename T>
class ImagOpModel : public SingleOpModel {
public:
ImagOpModel(const TensorData& input, const TensorData& output) {
input_ = AddInput(input);
output_ = AddOutput(output);
const std::vector<uint8_t> custom_option;
SetCustomOp("Imag", custom_option, Register_IMAG);
BuildInterpreter({GetShape(input_)});
}
int input() { return input_; }
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
private:
int input_;
int output_;
};
TEST(ImagOpTest, SimpleFloatTest) {
ImagOpModel<float> m({TensorType_COMPLEX64, {2, 4}},
{TensorType_FLOAT32, {}});
m.PopulateTensor<std::complex<float>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
TEST(ImagOpTest, SimpleDoubleTest) {
ImagOpModel<double> m({TensorType_COMPLEX128, {2, 4}},
{TensorType_FLOAT64, {}});
m.PopulateTensor<std::complex<double>>(m.input(), {{75, 7},
{-6, -1},
{9, 3.5},
{-10, 5},
{-3, 2},
{-6, 11},
{0, 0},
{22.1, 33.3}});
m.Invoke();
EXPECT_THAT(m.GetOutput(), testing::ElementsAreArray(ArrayFloatNear(
{7, -1, 3.5f, 5, 2, 11, 0, 33.3f})));
}
} // namespace
} // namespace custom
} // namespace ops
} // namespace tflite
int main(int argc, char** argv) {
::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -26,6 +26,8 @@ TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
}
} // namespace ops
} // namespace tflite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册