提交 fc6b8ba7 编写于 作者: A Antonio Sanchez 提交者: TensorFlower Gardener

Add missing S/UINT4 xla-to-TF type entries.

PiperOrigin-RevId: 564494782
上级 66edf039
......@@ -773,6 +773,19 @@ cc_library(
],
)
tf_cc_test(
name = "type_util_test",
srcs = ["type_util_test.cc"],
deps = [
":common",
"//tensorflow/core:framework_types_hdr",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "frontend_attributes_util",
srcs = ["frontend_attributes_util.cc"],
......
......@@ -106,10 +106,12 @@ StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
{xla::F32, DT_FLOAT},
{xla::F64, DT_DOUBLE},
{xla::C64, DT_COMPLEX64},
{xla::S4, DT_INT4},
{xla::S8, DT_INT8},
{xla::S16, DT_INT16},
{xla::S32, DT_INT32},
{xla::S64, DT_INT64},
{xla::U4, DT_UINT4},
{xla::U8, DT_UINT8},
{xla::U16, DT_UINT16},
{xla::U32, DT_UINT32},
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include <array>
#include "absl/status/statusor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Conversion utilities should support any primitive type,
// excluding string, resource, variant, invalid.
bool DataTypeSupportsXlaConversion(DataType dt) {
switch (dt) {
case DataType::DT_STRING:
case DataType::DT_RESOURCE:
case DataType::DT_VARIANT:
case DataType::DT_INVALID:
return false;
default:
// All other types should be supported.
break;
}
return !IsRefType(dt);
}
TEST(DataTypeToPrimitiveTypeTest, AllDataTypesSupported) {
for (int i = tensorflow::DataType_MIN; i < tensorflow::DataType_MAX; ++i) {
if (tensorflow::DataType_IsValid(i)) {
DataType dt = static_cast<DataType>(i);
if (DataTypeSupportsXlaConversion(dt)) {
xla::PrimitiveType out_type;
EXPECT_TRUE(DataTypeToPrimitiveType(dt, &out_type).ok());
}
}
}
}
TEST(EncodePrimitiveTypeAsDataType, AllPrimitiveTypesSupported) {
for (int i = tensorflow::DataType_MIN; i < tensorflow::DataType_MAX; ++i) {
DataType dt = static_cast<DataType>(i);
xla::PrimitiveType xla_type;
// If conversion to primitive type works, then the reverse mapping should
// also work.
if (DataTypeToPrimitiveType(dt, &xla_type).ok()) {
absl::StatusOr<DataType> data_type_or =
EncodePrimitiveTypeAsDataType(xla_type);
EXPECT_TRUE(data_type_or.ok());
// Non-quantized inputs should map directly back to the original type.
if (!DataTypeIsQuantized(dt)) {
EXPECT_EQ(*data_type_or, dt);
}
}
}
}
TEST(EncodePrimitiveTypeAsDataType, QuantizedTypesMapToUnquantized) {
static std::array<DataType, 5> quantized_inputs = {
DT_QINT8, DT_QINT16, DT_QINT32, DT_QUINT8, DT_QUINT16};
static std::array<DataType, 5> expected_outputs = {
DT_INT8, DT_INT16, DT_INT32, DT_UINT8, DT_UINT16};
for (int i = 0; i < quantized_inputs.size(); ++i) {
xla::PrimitiveType xla_type;
EXPECT_TRUE(DataTypeToPrimitiveType(quantized_inputs[i], &xla_type).ok());
absl::StatusOr<DataType> data_type_or =
EncodePrimitiveTypeAsDataType(xla_type);
EXPECT_TRUE(data_type_or.ok());
EXPECT_EQ(*data_type_or, expected_outputs[i]);
}
}
} // namespace
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册