From fb1ed49e98a71cfa55de32ba94089ea6f325600e Mon Sep 17 00:00:00 2001
From: Hanhan Wang
Date: Fri, 21 Aug 2020 23:26:35 -0700
Subject: [PATCH] Enhance lowering of the reshape op to Linalg.

Handle reshapes that are neither a pure expansion nor a pure collapse by
rewriting them into two reshape ops: first collapse the operand into a rank-1
intermediate, then expand that intermediate to the result shape.
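
For example (a sketch mirroring the new hlo-legalize-to-linalg.mlir test below;
the SSA names and the exact printed form of linalg.tensor_reshape are
illustrative rather than verbatim compiler output), a reshape that mixes
expansion and collapse such as

  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>

is now lowered to a collapse into a rank-1 intermediate followed by an
expansion to the result shape, roughly:

  %flat = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
      : tensor<1x49x16xf32> into tensor<784xf32>
  %0 = linalg.tensor_reshape %flat [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
      : tensor<784xf32> into tensor<1x784x1x1xf32>

On the LHLO path the same decomposition is emitted with linalg.reshape on
memrefs, followed by a linalg.copy into the output buffer.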

PiperOrigin-RevId: 327926863
Change-Id: I2b9f406d505ab69d9e25e892f75f38aa03467e1e
---
 .../mhlo/transforms/legalize_to_linalg.cc  | 43 ++++++++++++++++++-
 .../hlo/tests/hlo-legalize-to-linalg.mlir  | 12 ++++++
 .../hlo/tests/lhlo-legalize-to-linalg.mlir | 14 ++++++
 3 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index f47f2c2fbdc..033021c36ac 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
 
+#include <numeric>
+
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
@@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
     unsigned currSrcDim = 0, currDstDim = 0;
     SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
         dstShape.size());
+    bool isExpandingOrCollapsing = true;
     while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
       int64_t dstSize = dstShape[currDstDim];
       int64_t srcSize = srcShape[currSrcDim];
@@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
           }
         }
       } else {
-        return failure();
+        isExpandingOrCollapsing = false;
+        break;
       }
       currDstDim++;
     }
-    if (currSrcDim != srcShape.size()) return failure();
+    if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false;
+
+    if (!isExpandingOrCollapsing) {
+      auto getIdentityExprs = [&rewriter](int n) {
+        SmallVector<AffineExpr, 4> exprs;
+        for (int i = 0; i < n; ++i)
+          exprs.push_back(rewriter.getAffineDimExpr(i));
+        return exprs;
+      };
+      Location loc = reshapeOp.getLoc();
+      int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1,
+                                           std::multiplies<int64_t>());
+      auto elemType = operandType.getElementType();
+      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
+          getIdentityExprs(dstShape.size())};
+      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
+          getIdentityExprs(srcShape.size())};
+
+      if (isLHLO) {
+        auto collapsedType = MemRefType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::ReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
+            loc, resultType, collapsedOp, expandingMap);
+        rewriter.replaceOpWithNewOp<linalg::CopyOp>(
+            reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
+            /*outputPermutation =*/nullptr);
+      } else {
+        auto collapsedType = RankedTensorType::get({totalElems}, elemType);
+        Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
+            loc, collapsedType, args[0], collapsingMap);
+        rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
+            reshapeOp, resultType, collapsedOp, expandingMap);
+      }
+      return success();
+    }
 
     if (isLHLO) {
       Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
index 46725e0bd09..aecf612962a 100644
--- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir
@@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> {
+  %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32>
+  return %0 : tensor<1x784x1x1xf32>
+}
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+
+// -----
+
 // CHECK-LABEL: func @minf
 func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %0 = "mhlo.minimum"(%lhs, %rhs)
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
index 768d8da22bd..f174b005a8d 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir
@@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
 
 // -----
 
+// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @reshape_3D_4D
+func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) {
+  "lmhlo.reshape"(%arg0, %arg1)
+    : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> ()
+  return
+}
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]]
+// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]]
+// CHECK: linalg.copy
+
+// -----
+
 // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK-LABEL: func @reverse
--
GitLab