未验证 提交 54ab656c 编写于 作者: Z Zhong Hui 提交者: GitHub

[OPs] Bug fix, fix the segment mean for illegal syncthreads usage. (#32596) (#32610)

* [OPs] Bug fix, fix the segment mean for illegal syncthreads usage.
上级 15158927
......@@ -25,14 +25,12 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanCustomKernel(
const Index* segment_ids, const T* input, T* output, T* summed_ids,
const Index input_length_size, const Index inner_dim_size,
const Index output_length_size, const Index total_stripe_count) {
__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
const Index input_length_size,
const Index total_stripe_count) {
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
const Index segment_offset = stripe_index % inner_dim_size;
const Index dim_index_base =
stripe_index / inner_dim_size * Index(DimTileSize);
const Index segment_offset = stripe_index;
const Index dim_index_base = stripe_index * Index(DimTileSize);
const Index actual_height =
min(Index(DimTileSize), input_length_size - dim_index_base);
......@@ -41,19 +39,20 @@ __global__ void SegmentMeanCustomKernel(
if (dim_index_base > 0) {
last_segment_id = segment_ids[dim_index_base - 1];
}
if (segment_offset == 0) {
T sum = T(0);
for (Index j = 0; j < actual_height; j++) {
Index current_segment_id = segment_ids[dim_index_base + j];
// Note(ZHUI): following check may cause
// cudaErrorLaunchOutOfResources.
// PADDLE_ENFORCE(current_segment_id >= last_segment_id,
// "the segment ids should be sorted, but got "
// "segment_ids[%d]:%d > segment_ids[%d]:%d.",
// dim_index_base + j - 1, dim_index_base + j,
// last_segment_id, current_segment_id);
if (j > 0 && current_segment_id > last_segment_id) {
T sum = T(0);
for (Index j = 0; j < actual_height; j++) {
Index current_segment_id = segment_ids[dim_index_base + j];
PADDLE_ENFORCE(current_segment_id >= last_segment_id,
"the segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d.",
dim_index_base + j - 1, dim_index_base + j,
last_segment_id, current_segment_id);
if (current_segment_id > last_segment_id) {
for (Index interval_id = last_segment_id + 1;
interval_id < current_segment_id; ++interval_id) {
*(summed_ids + interval_id) = 0;
}
if (j > 0) {
if (last_segment_id == first_segment_id) {
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
} else {
......@@ -61,33 +60,60 @@ __global__ void SegmentMeanCustomKernel(
}
sum = T(0);
}
sum += T(1);
last_segment_id = current_segment_id;
}
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
sum += T(1);
last_segment_id = current_segment_id;
}
platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
}
}
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
T* output, T* summed_ids,
const Index input_length_size,
const Index inner_dim_size,
const Index output_length_size,
const Index total_stripe_count) {
CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
const Index segment_offset = stripe_index % inner_dim_size;
const Index dim_index_base =
stripe_index / inner_dim_size * Index(DimTileSize);
const Index actual_height =
min(Index(DimTileSize), input_length_size - dim_index_base);
Index first_segment_id = segment_ids[dim_index_base];
Index last_segment_id = -1;
if (dim_index_base > 0) {
last_segment_id = segment_ids[dim_index_base - 1];
}
// ensure last_segment_id is the largest
last_segment_id = output_length_size;
__syncthreads();
T sum = T(0);
for (Index j = 0; j < actual_height; j++) {
Index current_segment_id = segment_ids[dim_index_base + j];
if (current_segment_id > last_segment_id) {
const Index output_index =
last_segment_id * inner_dim_size + segment_offset;
if (last_segment_id == first_segment_id) {
platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
} else {
*(output + output_index) = sum / *(summed_ids + last_segment_id);
// reset the interval value which do not have corresponding ids.
for (Index interval_id = last_segment_id + 1;
interval_id < current_segment_id; ++interval_id) {
*(output + interval_id * inner_dim_size + segment_offset) = T(0);
}
if (j > 0) {
Index output_index =
last_segment_id * inner_dim_size + segment_offset;
if (last_segment_id == first_segment_id) {
platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
} else {
*(output + output_index) = sum / *(summed_ids + last_segment_id);
}
sum = T(0);
}
sum = T(0);
}
sum += input[(dim_index_base + j) * inner_dim_size + segment_offset];
last_segment_id = current_segment_id;
}
const Index output_index =
last_segment_id * inner_dim_size + segment_offset;
Index output_index = last_segment_id * inner_dim_size + segment_offset;
platform::CudaAtomicAdd(output + output_index,
sum / *(summed_ids + last_segment_id));
}
......@@ -122,7 +148,7 @@ __global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
// reset the interval value which do not have corresponding ids.
for (Index interval_id = last_segment_id + 1;
interval_id < current_segment_id; ++interval_id) {
*(output + interval_id * inner_dim_size + segment_offset) = 0;
*(output + interval_id * inner_dim_size + segment_offset) = T(0);
}
// don't update result when j=0
if (j > 0) {
......@@ -272,11 +298,25 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
framework::Tensor* output,
framework::Tensor* summed_ids = nullptr,
const std::string pooltype = "SUM") {
if (pooltype == "MEAN") {
// Sum the segment id num first
T DimTileSize = 8;
auto input_length_size = segment_ids.numel();
auto total_stripe_count =
(input_length_size + DimTileSize - 1) / DimTileSize;
auto config = platform::GetGpuLaunchConfig1D(ctx, total_stripe_count);
SegmentSumIdsKernel<
T, IndexT, IndexT(8)><<<config.block_per_grid.x,
config.thread_per_block.x, 0, ctx.stream()>>>(
segment_ids.data<IndexT>(), summed_ids->data<T>(), input_length_size,
total_stripe_count);
}
auto h = ArrangeHelper<IndexT>(input.numel(), segment_ids.dims()[0],
output->dims()[0]);
auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
if (pooltype == "MEAN") {
SegmentMeanCustomKernel<
SegmentMeanKernel<
T, IndexT, IndexT(8)><<<config.block_per_grid.x,
config.thread_per_block.x, 0, ctx.stream()>>>(
segment_ids.data<IndexT>(), input.data<T>(), output->data<T>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册