提交 cb83c846 编写于 作者: P Pankaj Kanwar 提交者: TensorFlower Gardener

Canonicalize FusedBatchNorm op to FusedBatchNormV3

PiperOrigin-RevId: 337394368
Change-Id: I7bd7f0513815e1b27584a1b6cba7c447a9d9c9a2
上级 4ad4b488
......@@ -62,40 +62,6 @@ func @Conv2dNCHW(%arg0: tensor<256x3x32x32xf32>, %arg1: tensor<3x3x3x16xf32>) ->
// LAYOUT: "tfl.conv_2d"
}
func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Unsupported training
%1:5 = "tf.FusedBatchNorm"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
// Use other output
%2:5 = "tf.FusedBatchNorm"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32>
// CHECK-LABEL: fusedBatchNorm
// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03>
// variance + epsilon
// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]])
// rsqrt(variance + epsilon)
// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]])
// scale * rsqrt(variance + epsilon)
// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]])
// x * scale * rsqrt(variance + epsilon)
// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]])
// mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]])
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]])
// x * scale * rsqrt(variance + epsilon) +
// offset - mean * scale * rsqrt(variance + epsilon)
// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]])
// CHECK: %[[BATCHNORM1_a:[^,]+]], {{.*}} = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
// CHECK: "tf.FusedBatchNormV3"(%[[BATCHNORM1_a]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]])
}
func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) {
^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>):
// OK
......
......@@ -740,31 +740,6 @@ struct ConvertTFBroadcastTo : public RewritePattern {
}
};
struct ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
explicit ConvertFusedBatchNorm(MLIRContext *context)
: OpRewritePattern<TF::FusedBatchNormOp>(context) {}
LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
PatternRewriter &rewriter) const override {
auto new_result_types =
llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
// reserve_space_3
new_result_types.push_back(
UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
OperationState new_state(tf_fused_batch_norm_op.getLoc(),
TF::FusedBatchNormV3Op::getOperationName(),
tf_fused_batch_norm_op.getOperands(),
new_result_types,
tf_fused_batch_norm_op.getAttrs());
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
rewriter.replaceOp(tf_fused_batch_norm_op,
tf_fused_batch_norm_op_v3->getResults().drop_back());
return success();
}
};
// The below pattern is equivalent to the DRR rule below
// The checks are dependent on generated values, so we can't add
// the checks on intermediate values, ideally we should find equivalent
......@@ -1202,7 +1177,6 @@ void PrepareTFPass::runOnFunction() {
patterns.insert<ConvertTFDilatedConvOp<TF::Conv2DOp>, FusedBatchNormV3Pat,
ConvertTFDilatedConvOp<TF::DepthwiseConv2dNativeOp>>(ctx);
patterns.insert<ConvertFusedBatchNorm>(ctx);
TFL::populateWithGenerated(ctx, patterns);
// TODO(karimnosseir): Split to separate pass probably after
// deciding on long term plan for this optimization.
......
......@@ -3942,6 +3942,8 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let verifier = [{
return Verify(*this);
}];
......
......@@ -2336,6 +2336,41 @@ void NonMaxSuppressionV3Op::getCanonicalizationPatterns(
results.insert<NMSV3ToNMSV4Op>(context);
}
//===----------------------------------------------------------------------===//
// FusedBatchNormOp
//===----------------------------------------------------------------------===//
namespace {
class ConvertFusedBatchNorm : public OpRewritePattern<TF::FusedBatchNormOp> {
using OpRewritePattern<FusedBatchNormOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TF::FusedBatchNormOp tf_fused_batch_norm_op,
PatternRewriter &rewriter) const override {
auto new_result_types =
llvm::to_vector<6>(tf_fused_batch_norm_op.getResultTypes());
// reserve_space_3
new_result_types.push_back(
UnrankedTensorType::get(FloatType::getF32(rewriter.getContext())));
OperationState new_state(tf_fused_batch_norm_op.getLoc(),
TF::FusedBatchNormV3Op::getOperationName(),
tf_fused_batch_norm_op.getOperands(),
new_result_types,
tf_fused_batch_norm_op.getAttrs());
Operation *tf_fused_batch_norm_op_v3 = rewriter.createOperation(new_state);
rewriter.replaceOp(tf_fused_batch_norm_op,
tf_fused_batch_norm_op_v3->getResults().drop_back());
return success();
}
};
} // namespace.
void FusedBatchNormOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvertFusedBatchNorm>(context);
}
//===----------------------------------------------------------------------===//
// UnpackOp
//===----------------------------------------------------------------------===//
......
......@@ -1284,3 +1284,10 @@ func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tens
%0 = "tf.NonMaxSuppressionV3"(%arg0, %arg1, %max_size, %arg2, %arg3): (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>)
return %0 : tensor<2xi32>
}
// CHECK-LABEL: testFusedBatchNormToBatchNormV3
func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
// CHECK: "tf.FusedBatchNormV3"
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
return %0#0 : tensor<8x8x8x8xf32>
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册