Commit 76079c00 authored by Yuanzhong Xu, committed by TensorFlower Gardener

[MLIR:TF/XLA] Lower SigmoidGrad op to HLO

PiperOrigin-RevId: 306537968
Change-Id: Ie37d58ee0671131c1c0906882705b98247496ddd
Parent d9666eb3
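
For reference, the lowering implements the standard sigmoid gradient identity: given the forward output y = sigmoid(x) and an incoming gradient dy, SigmoidGrad(y, dy) = dy * y * (1 - y). A minimal NumPy sketch (illustrative only, not part of this change) that checks this algebra against a finite-difference derivative:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def sigmoid_grad(y, dy):
        # Same algebra as the new HLO pattern: dy * y * (1 - y).
        return dy * y * (1.0 - y)

    x = np.linspace(-4.0, 4.0, 9)
    y = sigmoid(x)
    dy = np.ones_like(x)  # incoming gradient of all ones

    # Central finite-difference check of d(sigmoid)/dx.
    eps = 1e-6
    numeric = (sigmoid(x + eps) - sigmoid(x - eps)) / (2.0 * eps)
    print(np.allclose(sigmoid_grad(y, dy), numeric, atol=1e-6))  # True
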
@@ -2064,6 +2064,17 @@ func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
return %0 : tensor<2xf32>
}
// CHECK-LABEL: @sigmoid_grad
func @sigmoid_grad(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[MUL0:%.+]] = xla_hlo.multiply %arg1, %arg0 : tensor<2xf32>
// CHECK-DAG: [[ONE:%.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<2xf32>
// CHECK-DAG: [[SUB:%.+]] = xla_hlo.subtract [[ONE]], %arg0 : tensor<2xf32>
// CHECK-DAG: [[MUL1:%.+]] = xla_hlo.multiply [[MUL0]], [[SUB]] : tensor<2xf32>
// CHECK: return [[MUL1]]
%0 = "tf.SigmoidGrad"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// CHECK-LABEL: @sin
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK: "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
@@ -610,3 +610,14 @@ def : Pat<(srcDstOpPair[0]:$old $shape, $seed, $seed2),
(CastValueToI64 $old, $shape)),
[(IsShapedTensor $shape)]>;
}
//===----------------------------------------------------------------------===//
// Sigmoid grad op.
//===----------------------------------------------------------------------===//
def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_MulOp
              (HLO_MulOp $r, $l, (NullDenseIntElementsAttr)),
              (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l,
                  (NullDenseIntElementsAttr)),
              (NullDenseIntElementsAttr)),
          [(IEEEFloatTensor $l)]>;
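
A note on the pattern's structure, based on the visible rewrite: the inner HLO_MulOp multiplies the incoming gradient $r by the forward output $l, the HLO_SubOp subtracts $l from a splat constant of ones built with ConstantSplat<"1">, and the outer HLO_MulOp combines the two, giving dy * y * (1 - y). The NullDenseIntElementsAttr arguments supply no broadcast_dimensions attribute to the binary ops, which is fine here since both operands of tf.SigmoidGrad share a shape, and the IEEEFloatTensor constraint restricts the rewrite to floating-point tensors. The FileCheck test above verifies the emitted xla_hlo.multiply, xla_hlo.constant, and xla_hlo.subtract ops with CHECK-DAG directives.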