diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
index d4972b76b77743c09ebe2dfd1057789d7423b731..8ea517acd32c4eec496c668fc65e59b6a41d3e68 100644
--- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc
@@ -840,6 +840,8 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo(
                             reduction_is_race_free);
   int vector_size = vectorize ? 2 : 1;
 
+  // TODO(b/283542954): Autotune num_partial_results? This can make a big
+  // difference, e.g. by affecting register spilling.
   int num_partial_results = 1;
   if (!reduction_dimensions.is_row_reduction && vectorize) {
     int smallest_input_dtype_bits = SmallestInputDtypeBits();
@@ -861,26 +863,12 @@ ReductionCodegenInfo HloFusionAnalysis::ComputeReductionCodegenInfo(
     } else {
       num_partial_results = 2;
     }
-    reduction_tiling[kDimX] *= num_partial_results;
-  }
 
-  // TODO(b/283542954): Autotune num_partial_results? This can make a big
-  // difference, e.g. by affecting register spilling.
-
-  // Row reductions use one shmem block per partial result, so we have to make
-  // sure we fit in budget. Column reductions only ever use one shmem block.
-  // (Indeed I *think* "num_partial_results" is a misnomer for column
-  // reductions; I think it's the number of *complete*, i.e. not partial,
-  // results per warp.)
-  // TODO(vuson): something is wrong here:
-  // The loop was originally applied to both row and column reductions, we
-  // would need to verify that we could indeed exceed the memory usage for
-  // column reductions, in which case the outer if needs to be removed.
-  if (reduction_dimensions.is_row_reduction) {
     while (num_partial_results != 1 &&
            shmem_usage * num_partial_results > shmem_budget) {
       num_partial_results /= 2;
     }
+    reduction_tiling[kDimX] *= num_partial_results;
   }
 
   VLOG(3) << "Each thread will produce " << num_partial_results << " output(s)";
diff --git a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo b/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo
index b8d2d9f0993ebbc820318a82451601b44f652400..0d04f86da9cee4dadff47e9677fb8e5e299b358f 100644
--- a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo
+++ b/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo
@@ -375,3 +375,38 @@ ENTRY reduce.1 {
     f32[131072,1024] parameter0
   ), kind=kLoop, calls=fusion_vectorized
 }
+
+// -----
+
+// CHECK: define void @vectorized_col_reduction_exceeding_shmem_budget(
+// CHECK-COUNT-12: call void @add
+// CHECK-NOT: call void @add
+
+// We are trying to have a column reduction that:
+// - triggers vectorization (thus large number of elements 1048576)
+// - has a small "smallest input size" (1 for pred)
+// - exceeds the shmem budget because `num_partial_results` is 8
+
+HloModule m
+
+add {
+  a = f64[] parameter(0)
+  b = f64[] parameter(1)
+  ROOT out = f64[] add(a, b)
+}
+
+fused_computation {
+  p1 = f64[1048576,1048576]{1,0} parameter(0)
+  p2 = f64[1048576,1048576]{1,0} parameter(1)
+  s = pred[1048576,1048576]{1,0} parameter(2)
+  p = f64[1048576,1048576]{1,0} select(s, p1, p2)
+  z = f64[] constant(0)
+  ROOT out = f64[1048576]{0} reduce(p, z), to_apply=add, dimensions={0}
+}
+
+ENTRY e {
+  p1 = f64[1048576,1048576]{1,0} parameter(0)
+  p2 = f64[1048576,1048576]{1,0} parameter(1)
+  s = pred[1048576,1048576]{1,0} parameter(2)
+  ROOT vectorized_col_reduction_exceeding_shmem_budget = f64[1048576]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation
+}