未验证 提交 1afe1ac9 编写于 作者: Z Zhong Hui 提交者: GitHub

[OPs] Bug fix: remove the illegal __syncthreads() usage from the segment mean kernel. (#32596)

* [OPs] Bug fix: remove the illegal __syncthreads() call from the segment mean kernel by counting the segment ids in a separate kernel launch.
上级 eca8dcc7
...@@ -25,14 +25,12 @@ namespace operators { ...@@ -25,14 +25,12 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T, typename Index, int DimTileSize> template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanCustomKernel( __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
const Index* segment_ids, const T* input, T* output, T* summed_ids, const Index input_length_size,
const Index input_length_size, const Index inner_dim_size, const Index total_stripe_count) {
const Index output_length_size, const Index total_stripe_count) {
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) { CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
const Index segment_offset = stripe_index % inner_dim_size; const Index segment_offset = stripe_index;
const Index dim_index_base = const Index dim_index_base = stripe_index * Index(DimTileSize);
stripe_index / inner_dim_size * Index(DimTileSize);
const Index actual_height = const Index actual_height =
min(Index(DimTileSize), input_length_size - dim_index_base); min(Index(DimTileSize), input_length_size - dim_index_base);
...@@ -41,19 +39,20 @@ __global__ void SegmentMeanCustomKernel( ...@@ -41,19 +39,20 @@ __global__ void SegmentMeanCustomKernel(
if (dim_index_base > 0) { if (dim_index_base > 0) {
last_segment_id = segment_ids[dim_index_base - 1]; last_segment_id = segment_ids[dim_index_base - 1];
} }
if (segment_offset == 0) { T sum = T(0);
T sum = T(0); for (Index j = 0; j < actual_height; j++) {
for (Index j = 0; j < actual_height; j++) { Index current_segment_id = segment_ids[dim_index_base + j];
Index current_segment_id = segment_ids[dim_index_base + j]; PADDLE_ENFORCE(current_segment_id >= last_segment_id,
// Note(ZHUI): following check may cause "the segment ids should be sorted, but got "
// cudaErrorLaunchOutOfResources. "segment_ids[%d]:%d > segment_ids[%d]:%d.",
// PADDLE_ENFORCE(current_segment_id >= last_segment_id, dim_index_base + j - 1, dim_index_base + j,
// "the segment ids should be sorted, but got " last_segment_id, current_segment_id);
// "segment_ids[%d]:%d > segment_ids[%d]:%d.", if (current_segment_id > last_segment_id) {
// dim_index_base + j - 1, dim_index_base + j, for (Index interval_id = last_segment_id + 1;
// last_segment_id, current_segment_id); interval_id < current_segment_id; ++interval_id) {
*(summed_ids + interval_id) = 0;
if (j > 0 && current_segment_id > last_segment_id) { }
if (j > 0) {
if (last_segment_id == first_segment_id) { if (last_segment_id == first_segment_id) {
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
} else { } else {
...@@ -61,33 +60,60 @@ __global__ void SegmentMeanCustomKernel( ...@@ -61,33 +60,60 @@ __global__ void SegmentMeanCustomKernel(
} }
sum = T(0); sum = T(0);
} }
sum += T(1);
last_segment_id = current_segment_id;
} }
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); sum += T(1);
last_segment_id = current_segment_id;
}
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
}
}
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
T* output, T* summed_ids,
const Index input_length_size,
const Index inner_dim_size,
const Index output_length_size,
const Index total_stripe_count) {
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
const Index segment_offset = stripe_index % inner_dim_size;
const Index dim_index_base =
stripe_index / inner_dim_size * Index(DimTileSize);
const Index actual_height =
min(Index(DimTileSize), input_length_size - dim_index_base);
Index first_segment_id = segment_ids[dim_index_base];
Index last_segment_id = -1;
if (dim_index_base > 0) {
last_segment_id = segment_ids[dim_index_base - 1];
} }
// ensure last_segment_id is the largest
last_segment_id = output_length_size;
__syncthreads();
T sum = T(0); T sum = T(0);
for (Index j = 0; j < actual_height; j++) { for (Index j = 0; j < actual_height; j++) {
Index current_segment_id = segment_ids[dim_index_base + j]; Index current_segment_id = segment_ids[dim_index_base + j];
if (current_segment_id > last_segment_id) { if (current_segment_id > last_segment_id) {
const Index output_index = // reset the interval value which do not have corresponding ids.
last_segment_id * inner_dim_size + segment_offset; for (Index interval_id = last_segment_id + 1;
if (last_segment_id == first_segment_id) { interval_id < current_segment_id; ++interval_id) {
platform::CudaAtomicAdd(output + output_index, *(output + interval_id * inner_dim_size + segment_offset) = T(0);
sum / *(summed_ids + last_segment_id)); }
} else {
*(output + output_index) = sum / *(summed_ids + last_segment_id); if (j > 0) {
Index output_index =
last_segment_id * inner_dim_size + segment_offset;
if (last_segment_id == first_segment_id) {
platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
} else {
*(output + output_index) = sum / *(summed_ids + last_segment_id);
}
sum = T(0);
} }
sum = T(0);
} }
sum += input[(dim_index_base + j) * inner_dim_size + segment_offset]; sum += input[(dim_index_base + j) * inner_dim_size + segment_offset];
last_segment_id = current_segment_id; last_segment_id = current_segment_id;
} }
const Index output_index = Index output_index = last_segment_id * inner_dim_size + segment_offset;
last_segment_id * inner_dim_size + segment_offset;
platform::CudaAtomicAdd(output + output_index, platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id)); sum / *(summed_ids + last_segment_id));
} }
...@@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input, ...@@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
// reset the interval value which do not have corresponding ids. // reset the interval value which do not have corresponding ids.
for (Index interval_id = last_segment_id + 1; for (Index interval_id = last_segment_id + 1;
interval_id < current_segment_id; ++interval_id) { interval_id < current_segment_id; ++interval_id) {
*(output + interval_id * inner_dim_size + segment_offset) = 0; *(output + interval_id * inner_dim_size + segment_offset) = T(0);
} }
// don't update result when j=0 // don't update result when j=0
if (j > 0) { if (j > 0) {
...@@ -272,11 +298,25 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> { ...@@ -272,11 +298,25 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
framework::Tensor* output, framework::Tensor* output,
framework::Tensor* summed_ids = nullptr, framework::Tensor* summed_ids = nullptr,
const std::string pooltype = "SUM") { const std::string pooltype = "SUM") {
if (pooltype == "MEAN") {
// Sum the segment id num first
T DimTileSize = 8;
auto input_length_size = segment_ids.numel();
auto total_stripe_count =
(input_length_size + DimTileSize - 1) / DimTileSize;
auto config = platform::GetGpuLaunchConfig1D(ctx, total_stripe_count);
SegmentSumIdsKernel<
T, IndexT, IndexT(8)><<<config.block_per_grid.x,
config.thread_per_block.x, 0, ctx.stream()>>>(
segment_ids.data<IndexT>(), summed_ids->data<T>(), input_length_size,
total_stripe_count);
}
auto h = ArrangeHelper<IndexT>(input.numel(), segment_ids.dims()[0], auto h = ArrangeHelper<IndexT>(input.numel(), segment_ids.dims()[0],
output->dims()[0]); output->dims()[0]);
auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
if (pooltype == "MEAN") { if (pooltype == "MEAN") {
SegmentMeanCustomKernel< SegmentMeanKernel<
T, IndexT, IndexT(8)><<<config.block_per_grid.x, T, IndexT, IndexT(8)><<<config.block_per_grid.x,
config.thread_per_block.x, 0, ctx.stream()>>>( config.thread_per_block.x, 0, ctx.stream()>>>(
segment_ids.data<IndexT>(), input.data<T>(), output->data<T>(), segment_ids.data<IndexT>(), input.data<T>(), output->data<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册