提交 e54cae40 编写于 作者: S Son Tuan Vu 提交者: TensorFlower Gardener

[XLA:GPU] Limit unroll factor for column reductions

Vectorized column reductions might exceed shmem budget. Limit the unroll factors to avoid this.

PiperOrigin-RevId: 565170403
上级 adcfd3f6
......@@ -840,6 +840,8 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo(
reduction_is_race_free);
int vector_size = vectorize ? 2 : 1;
// TODO(b/283542954): Autotune num_partial_results? This can make a big
// difference, e.g. by affecting register spilling.
int num_partial_results = 1;
if (!reduction_dimensions.is_row_reduction && vectorize) {
int smallest_input_dtype_bits = SmallestInputDtypeBits();
......@@ -861,26 +863,12 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo(
} else {
num_partial_results = 2;
}
reduction_tiling[kDimX] *= num_partial_results;
}
// TODO(b/283542954): Autotune num_partial_results? This can make a big
// difference, e.g. by affecting register spilling.
// Row reductions use one shmem block per partial result, so we have to make
// sure we fit in budget. Column reductions only ever use one shmem block.
// (Indeed I *think* "num_partial_results" is a misnomer for column
// reductions; I think it's the number of *complete*, i.e. not partial,
// results per warp.)
// TODO(vuson): something is wrong here:
// The loop was originally applied to both row and column reductions, we
// would need to verify that we could indeed exceed the memory usage for
// column reductions, in which case the outer if needs to be removed.
if (reduction_dimensions.is_row_reduction) {
while (num_partial_results != 1 &&
shmem_usage * num_partial_results > shmem_budget) {
num_partial_results /= 2;
}
reduction_tiling[kDimX] *= num_partial_results;
}
VLOG(3) << "Each thread will produce " << num_partial_results << " output(s)";
......
......@@ -375,3 +375,38 @@ ENTRY reduce.1 {
f32[131072,1024] parameter0
), kind=kLoop, calls=fusion_vectorized
}
// -----
// CHECK: define void @vectorized_col_reduction_exceeding_shmem_budget(
// CHECK-COUNT-12: call void @add
// CHECK-NOT: call void @add
// We are trying to have a column reduction that:
// - triggers vectorization (thus large number of elements 1048576)
// - has a small "smallest input size" (1 for pred)
// - exceeds the shmem budget because `num_partial_results` is 8
HloModule m
add {
a = f64[] parameter(0)
b = f64[] parameter(1)
ROOT out = f64[] add(a, b)
}
fused_computation {
p1 = f64[1048576,1048576]{1,0} parameter(0)
p2 = f64[1048576,1048576]{1,0} parameter(1)
s = pred[1048576,1048576]{1,0} parameter(2)
p = f64[1048576,1048576]{1,0} select(s, p1, p2)
z = f64[] constant(0)
ROOT out = f64[1048576]{0} reduce(p, z), to_apply=add, dimensions={0}
}
ENTRY e {
p1 = f64[1048576,1048576]{1,0} parameter(0)
p2 = f64[1048576,1048576]{1,0} parameter(1)
s = pred[1048576,1048576]{1,0} parameter(2)
ROOT vectorized_col_reduction_exceeding_shmem_budget = f64[1048576]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册