diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/fluid/operators/math/segment_pooling.cu index 0b615cefac4eed2b2d972d5ed4b0e3a728d55486..b49b5036ac42e2359a2840f48ab0a42ced6bc406 100644 --- a/paddle/fluid/operators/math/segment_pooling.cu +++ b/paddle/fluid/operators/math/segment_pooling.cu @@ -25,14 +25,12 @@ namespace operators { using Tensor = framework::Tensor; template -__global__ void SegmentMeanCustomKernel( - const Index* segment_ids, const T* input, T* output, T* summed_ids, - const Index input_length_size, const Index inner_dim_size, - const Index output_length_size, const Index total_stripe_count) { +__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids, + const Index input_length_size, + const Index total_stripe_count) { CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) { - const Index segment_offset = stripe_index % inner_dim_size; - const Index dim_index_base = - stripe_index / inner_dim_size * Index(DimTileSize); + const Index segment_offset = stripe_index; + const Index dim_index_base = stripe_index * Index(DimTileSize); const Index actual_height = min(Index(DimTileSize), input_length_size - dim_index_base); @@ -41,19 +39,20 @@ __global__ void SegmentMeanCustomKernel( if (dim_index_base > 0) { last_segment_id = segment_ids[dim_index_base - 1]; } - if (segment_offset == 0) { - T sum = T(0); - for (Index j = 0; j < actual_height; j++) { - Index current_segment_id = segment_ids[dim_index_base + j]; - // Note(ZHUI): following check may cause - // cudaErrorLaunchOutOfResources. - // PADDLE_ENFORCE(current_segment_id >= last_segment_id, - // "the segment ids should be sorted, but got " - // "segment_ids[%d]:%d > segment_ids[%d]:%d.", - // dim_index_base + j - 1, dim_index_base + j, - // last_segment_id, current_segment_id); - - if (j > 0 && current_segment_id > last_segment_id) { + T sum = T(0); + for (Index j = 0; j < actual_height; j++) { + Index current_segment_id = segment_ids[dim_index_base + j]; + PADDLE_ENFORCE(current_segment_id >= last_segment_id, + "the segment ids should be sorted, but got " + "segment_ids[%d]:%d > segment_ids[%d]:%d.", + dim_index_base + j - 1, dim_index_base + j, + last_segment_id, current_segment_id); + if (current_segment_id > last_segment_id) { + for (Index interval_id = last_segment_id + 1; + interval_id < current_segment_id; ++interval_id) { + *(summed_ids + interval_id) = 0; + } + if (j > 0) { if (last_segment_id == first_segment_id) { platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); } else { @@ -61,33 +60,60 @@ __global__ void SegmentMeanCustomKernel( } sum = T(0); } - sum += T(1); - last_segment_id = current_segment_id; } - platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + sum += T(1); + last_segment_id = current_segment_id; + } + platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + } +} + +template +__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input, + T* output, T* summed_ids, + const Index input_length_size, + const Index inner_dim_size, + const Index output_length_size, + const Index total_stripe_count) { + CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) { + const Index segment_offset = stripe_index % inner_dim_size; + const Index dim_index_base = + stripe_index / inner_dim_size * Index(DimTileSize); + const Index actual_height = + min(Index(DimTileSize), input_length_size - dim_index_base); + + Index first_segment_id = segment_ids[dim_index_base]; + Index last_segment_id = -1; + if (dim_index_base > 0) { + last_segment_id = segment_ids[dim_index_base - 1]; } - // ensure last_segment_id is the largest - last_segment_id = output_length_size; - __syncthreads(); T sum = T(0); for (Index j = 0; j < actual_height; j++) { Index current_segment_id = segment_ids[dim_index_base + j]; if (current_segment_id > last_segment_id) { - const Index output_index = - last_segment_id * inner_dim_size + segment_offset; - if (last_segment_id == first_segment_id) { - platform::CudaAtomicAdd(output + output_index, - sum / *(summed_ids + last_segment_id)); - } else { - *(output + output_index) = sum / *(summed_ids + last_segment_id); + // reset the interval value which do not have corresponding ids. + for (Index interval_id = last_segment_id + 1; + interval_id < current_segment_id; ++interval_id) { + *(output + interval_id * inner_dim_size + segment_offset) = T(0); + } + + if (j > 0) { + Index output_index = + last_segment_id * inner_dim_size + segment_offset; + + if (last_segment_id == first_segment_id) { + platform::CudaAtomicAdd(output + output_index, + sum / *(summed_ids + last_segment_id)); + } else { + *(output + output_index) = sum / *(summed_ids + last_segment_id); + } + sum = T(0); } - sum = T(0); } sum += input[(dim_index_base + j) * inner_dim_size + segment_offset]; last_segment_id = current_segment_id; } - const Index output_index = - last_segment_id * inner_dim_size + segment_offset; + Index output_index = last_segment_id * inner_dim_size + segment_offset; platform::CudaAtomicAdd(output + output_index, sum / *(summed_ids + last_segment_id)); } @@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input, // reset the interval value which do not have corresponding ids. for (Index interval_id = last_segment_id + 1; interval_id < current_segment_id; ++interval_id) { - *(output + interval_id * inner_dim_size + segment_offset) = 0; + *(output + interval_id * inner_dim_size + segment_offset) = T(0); } // don't update result when j=0 if (j > 0) { @@ -272,11 +298,25 @@ class SegmentPoolFunctor { framework::Tensor* output, framework::Tensor* summed_ids = nullptr, const std::string pooltype = "SUM") { + if (pooltype == "MEAN") { + // Sum the segment id num first + T DimTileSize = 8; + auto input_length_size = segment_ids.numel(); + auto total_stripe_count = + (input_length_size + DimTileSize - 1) / DimTileSize; + auto config = platform::GetGpuLaunchConfig1D(ctx, total_stripe_count); + SegmentSumIdsKernel< + T, IndexT, IndexT(8)><<>>( + segment_ids.data(), summed_ids->data(), input_length_size, + total_stripe_count); + } + auto h = ArrangeHelper(input.numel(), segment_ids.dims()[0], output->dims()[0]); auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); if (pooltype == "MEAN") { - SegmentMeanCustomKernel< + SegmentMeanKernel< T, IndexT, IndexT(8)><<>>( segment_ids.data(), input.data(), output->data(),