未验证 提交 4cef7ee3 编写于 作者: C cad-audio 提交者: GitHub

REF_CODE_REFACTOR: logicalAND and logicalOR (#300)

* REF_CODE_REFACTOR: logicalAND and logicalOR
Refactoring the reference code for logicalAND and logicalOR operators.

BUG=refactoring existing code.

* Fix formatting.

* fix bad merge (copyright year).
Co-authored-by: NAdvait Jain <advaitjain@google.com>
Co-authored-by: NAdvait Jain <advaitjain@users.noreply.github.com>
上级 487ae230
......@@ -150,6 +150,7 @@ cc_library(
"leaky_relu.cc",
"log_softmax.cc",
"logical.cc",
"logical_common.cc",
"logistic.cc",
"logistic_common.cc",
"maximum_minimum.cc",
......@@ -191,6 +192,7 @@ cc_library(
"depthwise_conv.h",
"ethosu.h",
"fully_connected.h",
"logical.h",
"logistic.h",
"micro_ops.h",
"pooling.h",
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/micro/kernels/logical.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
......@@ -19,60 +21,17 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace micro {
namespace logical {
namespace {
// Input/output tensor index.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool)) {
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
if (tflite::micro::HaveSameShapes(input1, input2)) {
reference_ops::BinaryFunction<bool, bool, bool>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
} else {
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
}
return kTfLiteOk;
}
bool LogicalOr(bool x, bool y) { return x || y; }
TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
return LogicalImpl(context, node, LogicalOr);
}
bool LogicalAnd(bool x, bool y) { return x && y; }
TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
return LogicalImpl(context, node, LogicalAnd);
}
} // namespace
} // namespace logical
TfLiteRegistration Register_LOGICAL_OR() {
// Init, Free, Prepare, Eval are satisfying the Interface required by
......@@ -80,7 +39,7 @@ TfLiteRegistration Register_LOGICAL_OR() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/logical::LogicalOrEval,
/*invoke=*/LogicalOrEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
......@@ -93,13 +52,11 @@ TfLiteRegistration Register_LOGICAL_AND() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/nullptr,
/*invoke=*/logical::LogicalAndEval,
/*invoke=*/LogicalAndEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace micro
} // namespace ops
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_MICRO_KERNELS_LOGICAL_H_
#define TENSORFLOW_LITE_MICRO_KERNELS_LOGICAL_H_
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
namespace tflite {
// Input/output tensor index.
extern const int kLogicalInputTensor1;
extern const int kLogicalInputTensor2;
extern const int kLogicalOutputTensor;
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool));
bool LogicalOr(bool x, bool y);
bool LogicalAnd(bool x, bool y);
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_LOGICAL_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/logical.h"
namespace tflite {
// Input/output tensor index.
const int kLogicalInputTensor1 = 0;
const int kLogicalInputTensor2 = 1;
const int kLogicalOutputTensor = 0;
TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
bool (*func)(bool, bool)) {
const TfLiteEvalTensor* input1 =
tflite::micro::GetEvalInput(context, node, kLogicalInputTensor1);
const TfLiteEvalTensor* input2 =
tflite::micro::GetEvalInput(context, node, kLogicalInputTensor2);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kLogicalOutputTensor);
if (tflite::micro::HaveSameShapes(input1, input2)) {
reference_ops::BinaryFunction<bool, bool, bool>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
} else {
reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
tflite::micro::GetTensorShape(input1),
tflite::micro::GetTensorData<bool>(input1),
tflite::micro::GetTensorShape(input2),
tflite::micro::GetTensorData<bool>(input2),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<bool>(output), func);
}
return kTfLiteOk;
}
bool LogicalOr(bool x, bool y) { return x || y; }
bool LogicalAnd(bool x, bool y) { return x && y; }
} // namespace tflite
......@@ -73,9 +73,8 @@ TF_LITE_MICRO_TEST(LogicalOr) {
const bool input2[] = {true, false, true, false};
const bool golden[] = {true, false, true, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
shape, input1, shape, input2, shape, golden,
output_data);
tflite::testing::TestLogicalOp(tflite::Register_LOGICAL_OR(), shape, input1,
shape, input2, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
......@@ -85,9 +84,9 @@ TF_LITE_MICRO_TEST(BroadcastLogicalOr) {
const bool input2[] = {false};
const bool golden[] = {true, false, false, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_OR(),
input1_shape, input1, input2_shape, input2,
input1_shape, golden, output_data);
tflite::testing::TestLogicalOp(tflite::Register_LOGICAL_OR(), input1_shape,
input1, input2_shape, input2, input1_shape,
golden, output_data);
}
TF_LITE_MICRO_TEST(LogicalAnd) {
......@@ -96,9 +95,8 @@ TF_LITE_MICRO_TEST(LogicalAnd) {
const bool input2[] = {true, false, true, false};
const bool golden[] = {true, false, false, false};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
shape, input1, shape, input2, shape, golden,
output_data);
tflite::testing::TestLogicalOp(tflite::Register_LOGICAL_AND(), shape, input1,
shape, input2, shape, golden, output_data);
}
TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
......@@ -108,9 +106,9 @@ TF_LITE_MICRO_TEST(BroadcastLogicalAnd) {
const bool input2[] = {true};
const bool golden[] = {true, false, false, true};
bool output_data[4];
tflite::testing::TestLogicalOp(tflite::ops::micro::Register_LOGICAL_AND(),
input1_shape, input1, input2_shape, input2,
input1_shape, golden, output_data);
tflite::testing::TestLogicalOp(tflite::Register_LOGICAL_AND(), input1_shape,
input1, input2_shape, input2, input1_shape,
golden, output_data);
}
TF_LITE_MICRO_TESTS_END
......@@ -51,6 +51,8 @@ TfLiteRegistration Register_IF();
TfLiteRegistration Register_L2_POOL_2D();
TfLiteRegistration Register_LEAKY_RELU();
TfLiteRegistration Register_LOG_SOFTMAX();
TfLiteRegistration Register_LOGICAL_AND();
TfLiteRegistration Register_LOGICAL_OR();
TfLiteRegistration Register_LOGISTIC();
TfLiteRegistration Register_MAX_POOL_2D();
TfLiteRegistration Register_QUANTIZE();
......@@ -87,9 +89,7 @@ TfLiteRegistration Register_HARD_SWISH();
TfLiteRegistration Register_LESS();
TfLiteRegistration Register_LESS_EQUAL();
TfLiteRegistration Register_LOG();
TfLiteRegistration Register_LOGICAL_AND();
TfLiteRegistration Register_LOGICAL_NOT();
TfLiteRegistration Register_LOGICAL_OR();
TfLiteRegistration Register_MAXIMUM();
TfLiteRegistration Register_MEAN();
TfLiteRegistration Register_MINIMUM();
......
......@@ -321,8 +321,7 @@ class MicroMutableOpResolver : public MicroOpResolver {
TfLiteStatus AddLogicalAnd() {
return AddBuiltin(BuiltinOperator_LOGICAL_AND,
tflite::ops::micro::Register_LOGICAL_AND(),
ParseLogicalAnd);
tflite::Register_LOGICAL_AND(), ParseLogicalAnd);
}
TfLiteStatus AddLogicalNot() {
......@@ -332,8 +331,7 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddLogicalOr() {
return AddBuiltin(BuiltinOperator_LOGICAL_OR,
tflite::ops::micro::Register_LOGICAL_OR(),
return AddBuiltin(BuiltinOperator_LOGICAL_OR, tflite::Register_LOGICAL_OR(),
ParseLogicalOr);
}
......
......@@ -387,6 +387,7 @@ tensorflow/lite/micro/kernels/l2norm.cc \
tensorflow/lite/micro/kernels/l2_pool_2d.cc \
tensorflow/lite/micro/kernels/leaky_relu.cc \
tensorflow/lite/micro/kernels/logical.cc \
tensorflow/lite/micro/kernels/logical_common.cc \
tensorflow/lite/micro/kernels/logistic.cc \
tensorflow/lite/micro/kernels/logistic_common.cc \
tensorflow/lite/micro/kernels/log_softmax.cc \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册