From 771f9625343759d5953ae0048e7e92169f42d9b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jul 2023 15:19:05 -0700 Subject: [PATCH] Move pseudo-constant from while loop argument to while loop body in `tfl_while_outline` to avoid memory penalty during runtime PiperOrigin-RevId: 549440221 --- .../mlir/lite/tests/tfl_while_outline.mlir | 24 +++++++++++++++++++ .../lite/transforms/while_loop_outline.cc | 7 ++++-- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir index 0fd1482d77d..71f717d4c78 100644 --- a/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir +++ b/tensorflow/compiler/mlir/lite/tests/tfl_while_outline.mlir @@ -190,3 +190,27 @@ func.func @whileDifferentResultShapes(%arg0: tensor) -> tensor // CHECK: (tensor, tensor<1xf32>, tensor) -> (tensor, tensor, tensor) func.return %0#1 : tensor } + +func.func @whileSinkConstant(%arg0: tensor<1x256xf32>) -> tensor<1x256xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "prefix", outputs = "Identity_1"}} { + %cst_0 = arith.constant dense<1> : tensor<256x256xi8> + %cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<256x256x!quant.uniform>, value = dense<1> : tensor<256x256xi8>} : () -> tensor<256x256x!quant.uniform> + %cst_2 = arith.constant dense<0> : tensor + %0 = "tfl.batch_matmul"(%arg0, %cst_0) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256xi8>) -> tensor<1x256xf32> + %1 = "tfl.batch_matmul"(%0, %cst_1) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256x!quant.uniform>) -> tensor<1x256xf32> + %2:2 = "tfl.while"(%cst_2, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor<1x256xf32>): + %cst_3 = arith.constant dense<10> : tensor + %3 = tfl.less(%arg1, %cst_3) : (tensor, tensor) -> tensor + "tfl.yield"(%3) : (tensor) -> () + }, { + ^bb0(%arg1: tensor, %arg2: tensor<1x256xf32>): + // CHECK: %[[QCONST:.*]] = 
"tfl.pseudo_qconst"() {qtype = tensor<256x256x!quant.uniform>, value = dense<1> : tensor<256x256xi8>} : () -> tensor<256x256x!quant.uniform> + // CHECK: %[[CONST:.*]] = arith.constant dense<1> : tensor<256x256xi8> + %4 = "tfl.batch_matmul"(%arg2, %cst_0) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256xi8>) -> tensor<1x256xf32> + // CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%arg1, %[[CONST]]) + %5 = "tfl.batch_matmul"(%4, %cst_1) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256x!quant.uniform>) -> tensor<1x256xf32> + // CHECK-NEXT: %[[BMM_1:.*]] = "tfl.batch_matmul"(%[[BMM_0]], %[[QCONST]]) + "tfl.yield"(%arg1, %5) : (tensor, tensor<1x256xf32>) -> () + }) {is_stateless = false} : (tensor, tensor<1x256xf32>) -> (tensor, tensor<1x256xf32>) + return %2#1 : tensor<1x256xf32> + } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 560bcfd3543..eda15e78f9c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -210,9 +211,11 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { llvm::SetVector region_extern_values; getUsedValuesDefinedAbove(*it.value(), region_extern_values); - // Sink down constants into the functions. + // Sink down constants (including quantized constant) into the functions. 
for (auto extern_value : region_extern_values) { - if (!matchPattern(extern_value, m_Constant())) { + if (!matchPattern(extern_value, m_Constant()) && + !llvm::dyn_cast_or_null<TFL::QConstOp>( + extern_value.getDefiningOp())) { extern_values.insert(extern_value); continue; } -- GitLab