diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index bcb8db861f6da86a40cd175f20c3a5818cb8e234..f28e215f1c82089d6f0713bd9588b3e752620f5d 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -1292,12 +1292,28 @@ cc_test( ], ) +cc_library( + name = "cast_test_common", + testonly = 1, + hdrs = [ + "cast_test_common.h", + ], + deps = [ + ":test_util", + "//tensorflow/lite:string", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/types:span", + "@flatbuffers", + ], +) + cc_test( name = "cast_test", size = "small", srcs = ["cast_test.cc"], tags = ["tflite_nnapi"], deps = [ + ":cast_test_common", ":test_main", ":test_util", "//tensorflow/lite/core/c:c_api_types", diff --git a/tensorflow/lite/kernels/cast_test.cc b/tensorflow/lite/kernels/cast_test.cc index 9eaa4d9571129b05a76489faea5d5ffa7f6f4b3a..e2971016619532e1753ff987e435a98f4309ec7e 100644 --- a/tensorflow/lite/kernels/cast_test.cc +++ b/tensorflow/lite/kernels/cast_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/types/span.h" #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/kernels/cast_test_common.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -32,28 +33,6 @@ namespace { using ::testing::ElementsAreArray; -class CastOpModel : public SingleOpModel { - public: - CastOpModel(const TensorData& input, const TensorData& output) { - input_ = AddInput(input); - output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions, - CreateCastOptions(builder_).Union()); - BuildInterpreter({GetShape(input_)}); - } - - void Set4BitInput(absl::Span f) { - PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size()); - } - - int input() const { return input_; } - int output() const { return output_; } - - protected: - int input_; - int output_; -}; - TEST(CastOpModel, CastInt4ToFloat) { CastOpModel m({TensorType_INT4, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); m.Set4BitInput({1, 2, 3, 4, 5, 6}); diff --git a/tensorflow/lite/kernels/cast_test_common.h b/tensorflow/lite/kernels/cast_test_common.h new file mode 100644 index 0000000000000000000000000000000000000000..123cce213228f1677d14f38f39e077e45557ea4a --- /dev/null +++ b/tensorflow/lite/kernels/cast_test_common.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_KERNELS_CAST_TEST_COMMON_H_ +#define TENSORFLOW_LITE_KERNELS_CAST_TEST_COMMON_H_ + +#include + +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +using ::testing::ElementsAreArray; + +class CastOpModel : public SingleOpModel { + public: + CastOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_CAST, BuiltinOptions_CastOptions, + CreateCastOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + void Set4BitInput(absl::Span f) { + PopulateTensor4bit(input_, 0, f.data(), f.data() + f.size()); + } + + int input() const { return input_; } + int output() const { return output_; } + + protected: + int input_; + int output_; +}; + +} // namespace tflite +#endif // TENSORFLOW_LITE_KERNELS_CAST_TEST_COMMON_H_