Commit 1a09f4c5 authored by Rahul Joshi, committed by TensorFlower Gardener

[MLIR] Add canonicalization to fold IfRegion operations with constant condition

- Fold an IfRegion with a constant condition by inlining the then or else region
  in place of the IfRegion op, as sketched below.

PiperOrigin-RevId: 327875351
Change-Id: Ie04e32cc7ea93ae93817ad845eb80568d7e25b35
Parent 0ea8288a
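For illustration only (not part of this commit's diff): a minimal sketch of the new canonicalization, mirroring the test cases added below. The function name @fold_example and the value names are made up for this sketch.

// Before canonicalization: the condition is a constant true.
func @fold_example(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
  %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
  %0 = "tf.IfRegion"(%true) ({
    %a = "tf.Mul"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%a) : (tensor<f32>) -> ()
  }, {
    %b = "tf.Sub"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%b) : (tensor<f32>) -> ()
  }) {is_stateless = true} : (tensor<i1>) -> tensor<f32>
  return %0 : tensor<f32>
}

// After canonicalization: the then region has been inlined and the IfRegion
// (including the dead else branch) is gone.
func @fold_example(%x: tensor<f32>, %y: tensor<f32>) -> tensor<f32> {
  %0 = "tf.Mul"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}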
@@ -368,6 +368,8 @@ else_branch: A region that computes the outputs of the op if cond = false.
  let verifier = [{
    return Verify(*this);
  }];

  let hasCanonicalizer = 1;
}

def TF_LegacyCallOp : TF_Op<"LegacyCall",
......
@@ -1935,6 +1935,7 @@ static LogicalResult Verify(IfOp op) {
// IfOp canonicalization.
//===----------------------------------------------------------------------===//

namespace {
class FoldConstantIfOp : public OpRewritePattern<TF::IfOp> {
 public:
  explicit FoldConstantIfOp(MLIRContext *context)
@@ -1966,7 +1967,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
  auto rewrite = [&](auto op_type) {
    auto empty = rewriter.getStringAttr("");
    auto call_op = rewriter.create<typename decltype(op_type)::CallOp>(
-       op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
+       op.getLoc(), op.getResultTypes(), op.input(), func,
        /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
    CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op);
    rewriter.replaceOp(op, call_op.getResults());
@@ -1979,6 +1980,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite(
  return success();
}
}  // anonymous namespace

void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                       MLIRContext *context) {
@@ -1997,6 +1999,61 @@ static LogicalResult Verify(IfRegionOp op) {
  return success();
}

namespace {
class FoldConstantIfRegionOp : public OpRewritePattern<TF::IfRegionOp> {
 public:
  explicit FoldConstantIfRegionOp(MLIRContext *context)
      : OpRewritePattern<TF::IfRegionOp>(context) {}
  LogicalResult matchAndRewrite(TF::IfRegionOp op,
                                PatternRewriter &rewriter) const override;
};

LogicalResult FoldConstantIfRegionOp::matchAndRewrite(
    TF::IfRegionOp op, PatternRewriter &rewriter) const {
  // Extract the constant cond value.
  DenseIntElementsAttr cond_attr;
  if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure();

  // IfRegion condition should always be a scalar. Select the region to fold to.
  bool cond = cond_attr.getSplatValue<BoolAttr>().getValue();
  Region &region = cond ? op.then_branch() : op.else_branch();

  // If the IfRegion is stateless but the region being inlined itself is not
  // stateless, then inlining the region could cause a loss of information.
  // However, it's probably better to fold the IfRegion instead of having the
  // dead branch stay.

  // Inline the region in place of the IfRegion op, and forward the yield
  // inputs to the IfRegion op results. This is possible only if the yield
  // types match the result types.
  auto yield = cast<YieldOp>(region.front().getTerminator());
  auto updated_results = llvm::to_vector<4>(yield.getOperands());

  // If the yield types do not match the IfRegion result types, add appropriate
  // casts.
  rewriter.setInsertionPoint(yield);
  for (auto it : llvm::zip(op.getResultTypes(), updated_results)) {
    auto &updated_result = std::get<1>(it);
    Type result_type = std::get<0>(it);
    if (result_type != updated_result.getType()) {
      updated_result =
          rewriter.create<TF::CastOp>(op.getLoc(), result_type, updated_result,
                                      /*Truncate=*/rewriter.getBoolAttr(false));
    }
  }

  // Inline the region into the block containing the IfRegion.
  rewriter.mergeBlockBefore(&region.front(), op);
  rewriter.eraseOp(yield);
  rewriter.replaceOp(op, updated_results);
  return success();
}
}  // anonymous namespace

void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<FoldConstantIfRegionOp>(context);
}

//===----------------------------------------------------------------------===//
// InvertOp
//===----------------------------------------------------------------------===//
......
@@ -902,6 +902,51 @@ func @foldIf(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tens
  return %4 : tensor<f32>
}

// CHECK-LABEL: foldIfRegion
func @foldIfRegion(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
  %false = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
  %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>

  // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1)
  %0 = "tf.IfRegion"(%true) ({
    %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%true_value) : (tensor<f32>) -> ()
  }, {
    %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%false_value) : (tensor<f32>) -> ()
  }) { is_stateless = true}: (tensor<i1>) -> tensor<f32>

  // CHECK: [[Val1:%.*]] = "tf.Sub"(%arg0, %arg1)
  %1 = "tf.IfRegion"(%false) ({
    %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%true_value) : (tensor<f32>) -> ()
  }, {
    %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    "tf.Yield"(%false_value) : (tensor<f32>) -> ()
  }) { is_stateless = true}: (tensor<i1>) -> tensor<f32>

  // CHECK: return [[Val0]], [[Val1]]
  return %0, %1 : tensor<f32>, tensor<f32>
}

// CHECK-LABEL: foldIfRegionMismatchedTypes
func @foldIfRegionMismatchedTypes(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<i1>) -> tensor<1xf32> {
  %false = "tf.Const"() {value = dense<false> : tensor<i1>} : () -> tensor<i1>
  %true = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>

  // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1)
  // CHECK-NEXT: [[Cast:%.*]] = "tf.Cast"([[Val0]])
  // CHECK-NEXT: return [[Cast]]
  %0 = "tf.IfRegion"(%true) ({
    %true_value = "tf.Mul"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    "tf.Yield"(%true_value) : (tensor<?xf32>) -> ()
  }, {
    %false_value = "tf.Sub"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    "tf.Yield"(%false_value) : (tensor<?xf32>) -> ()
  }) { is_stateless = true}: (tensor<i1>) -> tensor<1xf32>
  return %0 : tensor<1xf32>
}

// CHECK-LABEL: foldCase
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
  %2 = constant dense<1> : tensor<i32>
......