Commit 771f9625, authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Move pseudo-constants from while loop arguments into the while loop body in `tfl_while_outline` to avoid a memory penalty at runtime

PiperOrigin-RevId: 549440221
Parent 85d3662a
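
Before this change, a `tfl.pseudo_qconst` result captured by the `tfl.while` regions was not recognized by `m_Constant()`, so the outlining pass (`WhileOutlinePass::OutlineWhile`) treated it as an external value and threaded it through the outlined cond/body functions as an extra argument, keeping the large quantized weights live as loop operands at runtime. The diff below adds a regression test for this case and extends the constant-sinking check. As a minimal sketch, the new check is equivalent to a predicate like the following (the helper name `IsSinkableConstant` is illustrative, not part of the change):

    #include "llvm/Support/Casting.h"
    #include "mlir/IR/Matchers.h"  // mlir::matchPattern, mlir::m_Constant
    #include "mlir/IR/Value.h"
    #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"  // mlir::TFL::QConstOp

    // Illustrative helper: a value captured from above can be sunk into the
    // outlined while region if it is a regular constant or the result of a
    // TFL quantized pseudo-constant (tfl.pseudo_qconst).
    static bool IsSinkableConstant(mlir::Value extern_value) {
      return mlir::matchPattern(extern_value, mlir::m_Constant()) ||
             llvm::dyn_cast_or_null<mlir::TFL::QConstOp>(
                 extern_value.getDefiningOp());
    }
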
@@ -190,3 +190,27 @@ func.func @whileDifferentResultShapes(%arg0: tensor<i32>) -> tensor<?xf32>
// CHECK: (tensor<i32>, tensor<1xf32>, tensor<i32>) -> (tensor<i32>, tensor<?xf32>, tensor<i32>)
func.return %0#1 : tensor<?xf32>
}
func.func @whileSinkConstant(%arg0: tensor<1x256xf32>) -> tensor<1x256xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "prefix", outputs = "Identity_1"}} {
%cst_0 = arith.constant dense<1> : tensor<256x256xi8>
%cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>, value = dense<1> : tensor<256x256xi8>} : () -> tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>
%cst_2 = arith.constant dense<0> : tensor<i32>
%0 = "tfl.batch_matmul"(%arg0, %cst_0) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256xi8>) -> tensor<1x256xf32>
%1 = "tfl.batch_matmul"(%0, %cst_1) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x256xf32>
%2:2 = "tfl.while"(%cst_2, %1) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<1x256xf32>):
%cst_3 = arith.constant dense<10> : tensor<i32>
%3 = tfl.less(%arg1, %cst_3) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tfl.yield"(%3) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<1x256xf32>):
// CHECK: %[[QCONST:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>, value = dense<1> : tensor<256x256xi8>} : () -> tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : tensor<256x256xi8>
%4 = "tfl.batch_matmul"(%arg2, %cst_0) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256xi8>) -> tensor<1x256xf32>
// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%arg1, %[[CONST]])
%5 = "tfl.batch_matmul"(%4, %cst_1) {adj_x = false, adj_y = false} : (tensor<1x256xf32>, tensor<256x256x!quant.uniform<i8:f32, 1.000000e+00>>) -> tensor<1x256xf32>
// CHECK-NEXT: %[[BMM_1:.*]] = "tfl.batch_matmul"(%[[BMM_0]], %[[QCONST]])
"tfl.yield"(%arg1, %5) : (tensor<i32>, tensor<1x256xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<1x256xf32>) -> (tensor<i32>, tensor<1x256xf32>)
return %2#1 : tensor<1x256xf32>
}
\ No newline at end of file
@@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@@ -210,9 +211,11 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) {
llvm::SetVector<Value> region_extern_values;
getUsedValuesDefinedAbove(*it.value(), region_extern_values);
// Sink down constants into the functions.
    // Sink down constants (including quantized constants) into the functions.
for (auto extern_value : region_extern_values) {
if (!matchPattern(extern_value, m_Constant())) {
if (!matchPattern(extern_value, m_Constant()) &&
!llvm::dyn_cast_or_null<TFL::QConstOp>(
extern_value.getDefiningOp())) {
extern_values.insert(extern_value);
continue;
}
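
When the check above identifies a sinkable constant, the rest of this loop (outside the hunk shown) materializes it inside the region instead of recording it as an external value. A rough sketch of that sinking step, assuming the variable names of the surrounding pass (`region` being the while cond/body region currently processed), not the literal diff context:

      // Sketch, assumed from the surrounding pass: clone the constant-producing
      // op (arith.constant or tfl.pseudo_qconst) to the top of the region and
      // rewire all in-region uses to the clone, so the value no longer has to
      // be passed through the outlined function's argument list.
      OpBuilder const_builder(&region.front(), region.front().begin());
      Value const_value =
          const_builder.clone(*extern_value.getDefiningOp())->getResult(0);
      replaceAllUsesInRegionWith(extern_value, const_value, region);
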