Commit 9133bcb2 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Export all known shape information to TFLite flatbuffers

Currently, the TFLite flatbuffer exporter only exports the shapes for
inputs and constants, which is all that is required by TFLite. However,
exporting more information will improve the ability to round-trip
between TFLite and MLIR and could enable the TFLite runtime to infer
more information statically.

PiperOrigin-RevId: 258452655
Parent ec0b5d1c
......@@ -507,9 +507,10 @@ Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
auto type = value->getType().cast<TensorType>();
// TFLite requires tensor shape only for the inputs and constants.
// However, we output all known shapes for better round-tripping
std::vector<int32_t> shape;
if (auto* inst = value->getDefiningOp()) {
if (IsConstOrInput(inst)) {
if (type.hasStaticShape()) {
auto shape_ref = type.getShape();
auto is_out_of_range = [](int64_t dim) {
return dim > std::numeric_limits<int32_t>::max();
......
......@@ -61,7 +61,7 @@ versions {
# CHECK-EMPTY:
# CHECK-NEXT: }
# CHECK-NEXT: }, {
# CHECK-NEXT: shape: [ ],
# CHECK-NEXT: shape: [ 4 ],
# CHECK-NEXT: type: INT32,
# CHECK-NEXT: buffer: 3,
# CHECK-NEXT: name: "Add",
......
......@@ -29,21 +29,21 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "MyCustomOp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
......
......@@ -25,21 +25,21 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
......
......@@ -16,7 +16,7 @@ func @main(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 3, 2 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tf.AddV2",
// CHECK-NEXT: quantization: {
......
......@@ -28,21 +28,21 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
......
......@@ -28,7 +28,7 @@
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "tfl.less",
......@@ -36,7 +36,7 @@
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tf.If",
// CHECK-NEXT: quantization: {
......
......@@ -35,7 +35,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "logical_or",
......@@ -43,7 +43,7 @@ func @main(tensor<4xi1>) -> tensor<4xi1> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: BOOL,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "logical_and",
......
......@@ -31,35 +31,35 @@ func @main(tensor<4xf32>) -> tensor<4xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "squared_difference",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "mul",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "div",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "exp",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "neg",
// CHECK-NEXT: quantization: {
......
......@@ -16,7 +16,7 @@ func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 1, 1, 16 ],
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "avgpool",
// CHECK-NEXT: quantization: {
......
......@@ -23,7 +23,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 224, 224, 3 ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.quantize",
......@@ -50,7 +50,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 112, 112, 32 ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "tfl.conv_2d",
......@@ -59,7 +59,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 1001 ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 6,
// CHECK-NEXT: name: "tfl.reshape",
......@@ -68,7 +68,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 1001 ],
// CHECK-NEXT: type: UINT8,
// CHECK-NEXT: buffer: 7,
// CHECK-NEXT: name: "tfl.softmax",
......@@ -77,7 +77,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: zero_point: [ 0 ]
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1, 1001 ],
// CHECK-NEXT: buffer: 8,
// CHECK-NEXT: name: "tfl.dequantize",
// CHECK-NEXT: quantization: {
......
......@@ -17,7 +17,7 @@ func @main(tensor<3x2xi32>) -> tensor<6xi32> {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 6 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "tfl.reshape",
......
......@@ -28,7 +28,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 3, 2 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "sub",
......@@ -44,7 +44,7 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32>
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 3, 2 ],
// CHECK-NEXT: type: INT32,
// CHECK-NEXT: buffer: 5,
// CHECK-NEXT: name: "SameNameAsOutput",
......
......@@ -37,7 +37,7 @@
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ ],
// CHECK-NEXT: shape: [ 1 ],
// CHECK-NEXT: buffer: 4,
// CHECK-NEXT: name: "tf.While:1",
// CHECK-NEXT: quantization: {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register