提交 3e799d06 编写于 作者: P Prakalp Srivastava 提交者: TensorFlower Gardener

Handle uint8/16/32/64 as special types in TF dialect.

If the return type of a function is an unsigned tensor type, it was previously modeled as a standard (signed) integer type, losing the information that the returned value is unsigned.

PiperOrigin-RevId: 257623093
上级 3d848eb0
......@@ -167,6 +167,8 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
return tflite::TensorType_STRING;
case mlir::TF::TensorFlowTypes::COMPLEX64:
return tflite::TensorType_COMPLEX64;
case mlir::TF::TensorFlowTypes::UINT8:
return tflite::TensorType_UINT8;
case mlir::StandardTypes::Integer: {
const auto& itype = type.cast<mlir::IntegerType>();
switch (itype.getWidth()) {
......
......@@ -50,6 +50,13 @@ def TFL_Str : Type<CPred<"$_self.isa<mlir::TF::StringType>()">,
"TFLite string type">,
BuildableType<"getType<mlir::TF::StringType>()">;
//===----------------------------------------------------------------------===//
// TFLite dialect uint8 type - uses the TF uint8 type as implementation.
// MLIR's standard IntegerType is signless, so unsigned 8-bit values need a
// dedicated opaque type (mlir::TF::Uint8Type) to preserve unsigned-ness.
// The exporter maps this type to tflite::TensorType_UINT8.
//===----------------------------------------------------------------------===//
def TFL_Uint8 : Type<CPred<"$_self.isa<mlir::TF::Uint8Type>()">,
"TFLite uint8 type">,
BuildableType<"getType<mlir::TF::Uint8Type>()">;
//===----------------------------------------------------------------------===//
// Activation function enum definitions.
//===----------------------------------------------------------------------===//
......@@ -438,13 +445,13 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
}];
let arguments = (
ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>>:$values,
ins Variadic<TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>>:$values,
I32Attr:$axis,
TFL_AFAttr:$fused_activation_function
);
let results = (outs
TensorOf<[F32, I64, I32, I16, I8, TFL_QI8]>:$output
TensorOf<[F32, I64, I32, I16, I8, TFL_QI8, TFL_Uint8]>:$output
);
let hasOptions = 1;
......@@ -559,9 +566,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [Broadcastable, NoSideEffect]> {
}];
let arguments = (
// TODO(haoliang): missing Uint8
ins TensorOf<[F32, I32, I64, I8]>:$lhs,
TensorOf<[F32, I32, I64, I8]>:$rhs);
ins TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$lhs,
TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$rhs);
let results = (outs TFL_BoolTensor:$output);
......@@ -671,10 +677,9 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable,
}];
let arguments = (
// TODO: missing Uint8
ins
TensorOf<[I1, F32, I32, I64, I8]>:$x,
TensorOf<[I1, F32, I32, I64, I8]>:$y
TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$x,
TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$y
);
let results = (outs TFL_BoolTensor:$output);
......@@ -1072,8 +1077,7 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> {
}];
let arguments = (ins
// TODO: missing uint8
TensorOf<[F32, I8, I32, I64]>:$input,
TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input,
TensorOf<[I32, I64]>:$axis,
BoolAttr:$keep_dims
);
......
......@@ -817,7 +817,7 @@ func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2
// -----
// Verifies that tfl.concatenation rejects a quantized operand whose storage
// type is not 8 bits wide (i9 here). Only ONE expected-error directive may
// precede the op: the stale pre-uint8 message was left behind alongside the
// updated one (which includes "or TFLite uint8 type" in the allowed-type
// list), and duplicate `expected-error @+1` directives make the test fail.
// Keep only the current diagnostic text.
func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
  // expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type or TFLite uint8 type values}}
  %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
  return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
}
......
......@@ -19,6 +19,10 @@ limitations under the License.
#ifdef HANDLE_TF_TYPE
// class, enumerant, name
HANDLE_TF_TYPE(Uint8, UINT8, "uint8")
HANDLE_TF_TYPE(Uint16, UINT16, "uint16")
HANDLE_TF_TYPE(Uint32, UINT32, "uint32")
HANDLE_TF_TYPE(Uint64, UINT64, "uint64")
HANDLE_TF_TYPE(Qint8, QINT8, "qint8")
HANDLE_TF_TYPE(Qint16, QINT16, "qint16")
HANDLE_TF_TYPE(Qint32, QINT32, "qint32")
......
# RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s

# Verifies that a library function whose output_arg has dtype DT_UINT8 is
# imported with the TF dialect's dedicated unsigned type (tensor<!tf.uint8>)
# on both the function signature and its return, instead of being collapsed
# into a plain 8-bit integer type.

# Top-level call into the uint8-returning library function below.
node {
name: "PartitionedCall"
op: "PartitionedCall"
attr {
key: "Tin"
value {
list {
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_UINT8
}
}
}
attr {
key: "_gradient_op_type"
value {
s: "PartitionedCall-15"
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: "\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001"
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_uint_const_14"
}
}
}
}
# Library function: returns a scalar uint8 constant (value 5) via Identity.
library {
function {
signature {
name: "__inference_uint_const_14"
output_arg {
name: "identity"
type: DT_UINT8
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_UINT8
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_UINT8
tensor_shape {
}
int_val: 5
}
}
}
}
node_def {
name: "Identity"
op: "Identity"
input: "Const:output:0"
attr {
key: "T"
value {
type: DT_UINT8
}
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
}
}
versions {
producer: 29
min_consumer: 12
}
# The imported callee must carry tensor<!tf.uint8>, not a signless i8.
# CHECK: func @main
# CHECK: "_tf.PartitionedCall"()
# CHECK-SAME: Tout = ["tfdtype$DT_UINT8"]
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
# CHECK: func @[[FUNCTION]]() -> tensor<!tf.uint8>
# CHECK: return {{%[0-9]*#[0-9]*}} : tensor<!tf.uint8>
......@@ -46,19 +46,15 @@ Status ConvertDataType(const DataType& dtype, Builder builder, Type* type) {
*type = builder.getIntegerType(1);
return Status::OK();
case DT_INT8:
case DT_UINT8:
*type = builder.getIntegerType(8);
return Status::OK();
case DT_INT16:
case DT_UINT16:
*type = builder.getIntegerType(16);
return Status::OK();
case DT_INT32:
case DT_UINT32:
*type = builder.getIntegerType(32);
return Status::OK();
case DT_INT64:
case DT_UINT64:
*type = builder.getIntegerType(64);
return Status::OK();
case DT_BFLOAT16:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册