From 18a59822cf591723afb753dfaff75ae08b2935fe Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 17 Dec 2021 13:07:39 +0800 Subject: [PATCH] add launch bound to limit the registers usage for volta architecture (#38113) From --ptxas-options=-v, SegmentOpsKernel uses 66 registers in a block. There are two ways to resolve this problem: Reduce the threads per block launch configuration add __launch_bound__ to give information to nvcc compiler for reducing registers usage this PR chooses __launch_bound__ solution because changing gpu_launch_config may affect other ops. --- paddle/fluid/operators/math/segment_pooling.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/fluid/operators/math/segment_pooling.cu index 67cf3162460..0cbfaa4c5df 100644 --- a/paddle/fluid/operators/math/segment_pooling.cu +++ b/paddle/fluid/operators/math/segment_pooling.cu @@ -120,8 +120,9 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, } template -__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input, - T* output, Helper h, Pool pool) { +__global__ void __launch_bounds__(1024, 1) + SegmentOpsKernel(const Index* segment_ids, const T* input, T* output, + Helper h, Pool pool) { CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) { Index segment_offset, dim_index_base, actual_height; Index inner_dim_size = h.inner_dim_size; -- GitLab