Commit 10f85087 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Convert a soup of ops representing softmax into tfl.softmax

Converting the individual ops into tfl.softmax will improve performance
by enabling fusion, as well as improve accuracy on backends where the
intermediate tensors have reduced precision.

PiperOrigin-RevId: 328287102
Change-Id: Ie2e37804c20854ba57ae7b1d11a4620ea17bd39f
Parent 5b3cd9ce
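For context, here is a minimal NumPy sketch (an editor's illustration, not part of this change; the function names are made up) showing that the "soup of ops" matched by the new patterns, reduce_max, sub, exp, sum and div, computes the same values as a single fused tfl.softmax:

# Illustration only: the decomposed op sequence versus the fused op, in NumPy.
import numpy as np

def op_soup_softmax(x):
    # tfl.reduce_max with keep_dims = true
    m = np.max(x, axis=-1, keepdims=True)
    # tfl.sub followed by tfl.exp
    e = np.exp(x - m)
    # tfl.sum (keep_dims = true) followed by tfl.div
    return e / np.sum(e, axis=-1, keepdims=True)

def fused_softmax(x, beta=1.0):
    # Reference behaviour of tfl.softmax with beta = 1.0
    z = beta * x
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

x = np.random.randn(8, 128).astype(np.float32)
assert np.allclose(op_soup_softmax(x), fused_softmax(x), atol=1e-6)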
@@ -1175,3 +1175,29 @@ func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> {
// CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32>
// CHECK: return %[[RESULT]] : tensor<1x1xf32>
}

func @SoftMaxWithNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
%cst = constant dense<1> : tensor<1xi32>
%0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
%1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
%2 = "tfl.exp"(%1) : (tensor<8x128xf32>) -> tensor<8x128xf32>
%3 = "tfl.sum"(%2, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
%4 = "tfl.div"(%2, %3) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
return %4 : tensor<8x128xf32>
// CHECK-LABEL: SoftMaxWithNormalization
// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
// CHECK: return %[[RESULT]] : tensor<8x128xf32>
}

func @SoftMaxWithoutNormalization(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> {
%cst = constant dense<1> : tensor<1xi32>
%0 = "tfl.exp"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32>
%1 = "tfl.sum"(%0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32>
%2 = "tfl.div"(%0, %1) {fused_activation_function = "NONE"} : (tensor<8x128xf32>, tensor<8x1xf32>) -> tensor<8x128xf32>
return %2 : tensor<8x128xf32>
// CHECK-LABEL: SoftMaxWithoutNormalization
// CHECK: %[[RESULT:.*]] = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<8x128xf32>) -> tensor<8x128xf32>
// CHECK: return %[[RESULT]] : tensor<8x128xf32>
}
@@ -552,3 +552,37 @@ foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp,
(HasOneUse $reduce)]>;
}
def IsSame : Constraint<CPred<"$0 == $1">>;

// Checks that a value has exactly two uses.
def HasTwoUse : Constraint<CPred<
  "std::distance($0.use_begin(), $0.use_end()) == 2">>;

// Checks that the reduction axes attribute holds a single axis and that this
// axis is the last dimension of the operand.
def AxesIsLastDimension : Constraint<CPred<
  "$0.cast<DenseIntElementsAttr>().getNumElements() == 1 && "
  "$0.cast<DenseIntElementsAttr>().getValue<APInt>({0}) == "
  "$1.getType().cast<ShapedType>().getRank() - 1">>;

// Convert exp(x)/sum(exp(x)) into softmax.
def OptimizeToSoftmax : Pat<
(TFL_DivOp (TFL_ExpOp:$exp $input),
(TFL_SumOp:$sum $sum_input, (ConstantOp I32ElementsAttr: $axes),
ConstBoolAttrTrue), TFL_AF_None),
(TFL_SoftmaxOp $input, ConstF32Attr<"1.0">),
[(IsSame $exp, $sum_input),
(AxesIsLastDimension $axes, $sum_input),
(HasTwoUse $exp),
(HasOneUse $sum)]>;

// Convert softmax(x-max(x)) into softmax(x) as the softmax op already deals
// with the max normalization.
def FoldNormalizationIntoSoftmax : Pat<
(TFL_SoftmaxOp
(TFL_SubOp:$sub $input,
(TFL_ReduceMaxOp:$max $max_input, (ConstantOp I32ElementsAttr: $axes),
ConstBoolAttrTrue),
TFL_AF_None),
$beta),
(TFL_SoftmaxOp $input, $beta),
[(IsSame $input, $max_input),
(AxesIsLastDimension $axes, $max_input),
(HasOneUse $sub),
(HasOneUse $max)]>;
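
As a quick sanity check of the FoldNormalizationIntoSoftmax rewrite (again an editor's illustration, not part of the patch): subtracting the row-wise max cancels out in the softmax ratio, so softmax(x - max(x)) equals softmax(x).

# Illustration only: max-subtraction does not change the softmax result.
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

x = np.random.randn(8, 128).astype(np.float32)
# The tfl.reduce_max + tfl.sub prefix that the pattern removes.
normalized = x - np.max(x, axis=-1, keepdims=True)
# exp(max) divides out of numerator and denominator, so the outputs match.
assert np.allclose(softmax(normalized), softmax(x), atol=1e-6)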