提交 b4587874 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[xla:gml_st] Relax a check that refuses to tile parallel dimensions when op...

[xla:gml_st] Relax a check that refuses to tile parallel dimensions when op contains a reduction dimension

The user is responsible for making sure to not ask to generate parallel loops on reduction dimension.

PiperOrigin-RevId: 481366310
上级 be8c8442
......@@ -254,12 +254,6 @@ FailureOr<TilingResult> tile(const TilingOptions &options,
op, "missing tile size computation function");
}
// Implement adding accumulator to the gml_st.parallel terminator.
if (options.distribute && llvm::count(op.getLoopIteratorTypes(),
utils::IteratorType::reduction) > 0) {
return failure();
}
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
size_t numLoops = iterationDomain.size();
......@@ -274,6 +268,7 @@ FailureOr<TilingResult> tile(const TilingOptions &options,
OpBuilder::InsertionGuard guard(rewriter);
tileSizeVector = options.tileSizeComputationFn(rewriter, op);
}
if (tileSizeVector.size() < iterationDomain.size()) {
auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
......
......@@ -235,14 +235,6 @@ func.func @reduce_row(%lhs: tensor<?x?xf32>,
// CHECK-FOR: return %[[FOR_0]]
// CHECK-PARALLEL-LABEL: @reduce_row
// CHECK-PARALLEL-SAME: %[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>
// CHECK-PARALLEL-NOT: gml_st.parallel
// CHECK-PARALLEL: %[[RES:.*]] = linalg.generic
// CHECK-PARALLEL-NOT: gml_st.parallel
// CHECK-PARALLEL: return %[[RES]]
// -----
func.func @thlo_reduction(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册