From 10f850876ab5ff9c48187992c3de80eeb8c1ca07 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 25 Aug 2020 01:11:37 -0700
Subject: [PATCH] Convert a soup of ops representing softmax into tfl.softmax

Converting the individual ops into tfl.softmax will improve performance by
enabling fusion, as well as improve accuracy on backends where the
intermediate tensors have reduced precision.

PiperOrigin-RevId: 328287102
Change-Id: Ie2e37804c20854ba57ae7b1d11a4620ea17bd39f
---
 .../compiler/mlir/lite/tests/optimize.mlir    | 26 ++++++++++++++
 .../mlir/lite/transforms/optimize_patterns.td | 34 +++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index edbcef3d321..10ff03a46f8 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -1175,3 +1175,29 @@ func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> {
 // CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32>
 // CHECK: return %[[RESULT]] : tensor<1x1xf32>
 }
+
+func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %cst = constant dense<1> : tensor<1xi32>
+  %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  %2 = "tfl.exp"(%1) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %3 = "tfl.sum"(%2, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %4 = "tfl.div"(%2, %3) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  return %4 : tensor<8x128xf32>
+
+// CHECK-LABEL: SoftMaxWithNormalization
+// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
+// CHECK: return %[[RESULT]] : tensor<8x128xf32>
+}
+
+func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
+  %cst = constant dense<1> : tensor<1xi32>
+  %0 = "tfl.exp"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
+  %1 = "tfl.sum"(%0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
+  %2 = "tfl.div"(%0, %1) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
+  return %2 : tensor<8x128xf32>
+
+// CHECK-LABEL: SoftMaxWithoutNormalization
+// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
+// CHECK: return %[[RESULT]] : tensor<8x128xf32>
+}
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 559d22dcf47..e53024b1e1e 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -552,3 +552,37 @@ foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp,
      (HasOneUse $reduce)]>;
 }
+
+def IsSame : Constraint<CPred<"$0 == $1">>;
+def HasTwoUse : Constraint<CPred<
+  "std::distance($0.use_begin(), $0.use_end()) == 2">>;
+def AxesIsLastDimension : Constraint<CPred<
+  "$0.cast<DenseElementsAttr>().getNumElements() == 1 && "
+  "$0.cast<DenseElementsAttr>().getValue<APInt>({0}) == "
+  "$1.getType().cast<ShapedType>().getRank() - 1">>;
+
+// Convert exp(x)/sum(exp(x)) into softmax.
+def OptimizeToSoftmax : Pat<
+  (TFL_DivOp (TFL_ExpOp:$exp $input),
+             (TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
+                         ConstBoolAttrTrue), TFL_AF_None),
+  (TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
+  [(IsSame $exp, $sum_input),
+   (AxesIsLastDimension $axes, $sum_input),
+   (HasTwoUse $exp),
+   (HasOneUse $sum)]>;
+
+// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
+// with the max normalization.
+def FoldNormalizationIntoSoftmax : Pat<
+  (TFL_SoftmaxOp
+    (TFL_SubOp:$sub $input,
+      (TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
+                        ConstBoolAttrTrue),
+      TFL_AF_None),
+    $beta),
+  (TFL_SoftmaxOp $input, $beta),
+  [(IsSame $input, $max_input),
+   (AxesIsLastDimension $axes, $max_input),
+   (HasOneUse $sub),
+   (HasOneUse $max)]>;
--
GitLab
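
Note (not part of the patch): a minimal NumPy sketch of the equivalence these patterns rely on. It checks that the op soup exercised by the new tests (reduce_max, sub, exp, sum, div along the last axis) produces the same values as a plain softmax with beta = 1.0; the array shape and tolerance below are arbitrary choices for illustration, not taken from the change itself.

    import numpy as np

    x = np.random.randn(8, 128).astype(np.float32)

    # Op soup from the SoftMaxWithNormalization test:
    # tfl.reduce_max -> tfl.sub -> tfl.exp -> tfl.sum -> tfl.div
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    soup = e / e.sum(axis=-1, keepdims=True)

    # Plain softmax(x) = exp(x) / sum(exp(x)), the semantics of tfl.softmax
    # with beta = 1.0 (the max subtraction only improves numerical stability).
    ref = np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)

    print(np.allclose(soup, ref, atol=1e-6))  # expected: True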