From 4a9d21de4987ced5aaf58a318ac598abff853b48 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Thu, 24 Sep 2020 10:11:06 +0800 Subject: [PATCH] Add GPU Kernels of Segment Ops, support, sum, max, min, mean Add GPU Kernels of Segment Ops, support, sum, max, min, mean --- .../fluid/operators/math/segment_pooling.cu | 365 ++++++++++++++++++ paddle/fluid/operators/segment_pool_op.cu | 28 ++ paddle/fluid/operators/segment_pool_op.h | 40 ++ paddle/fluid/platform/cuda_primitives.h | 107 +++++ 4 files changed, 540 insertions(+) create mode 100644 paddle/fluid/operators/math/segment_pooling.cu create mode 100644 paddle/fluid/operators/segment_pool_op.cu diff --git a/paddle/fluid/operators/math/segment_pooling.cu b/paddle/fluid/operators/math/segment_pooling.cu new file mode 100644 index 00000000000..bb2b6db100b --- /dev/null +++ b/paddle/fluid/operators/math/segment_pooling.cu @@ -0,0 +1,365 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/segment_pooling.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_param_config.h" +#include "paddle/fluid/platform/macros.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +__global__ void SegmentMeanCustomKernel( + const Index* segment_ids, const T* input, T* output, T* summed_ids, + const Index input_length_size, const Index inner_dim_size, + const Index output_length_size, const Index total_stripe_count) { + CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) { + const Index segment_offset = stripe_index % inner_dim_size; + const Index dim_index_base = + stripe_index / inner_dim_size * Index(DimTileSize); + const Index actual_height = + min(Index(DimTileSize), input_length_size - dim_index_base); + + Index first_segment_id = segment_ids[dim_index_base]; + Index last_segment_id = -1; + if (dim_index_base > 0) { + last_segment_id = segment_ids[dim_index_base - 1]; + } + if (segment_offset == 0) { + T sum = T(0); + for (Index j = 0; j < actual_height; j++) { + Index current_segment_id = segment_ids[dim_index_base + j]; + // Note(ZHUI): following check may cause + // cudaErrorLaunchOutOfResources. + // PADDLE_ENFORCE(current_segment_id >= last_segment_id, + // "the segment ids should be sorted, but got " + // "segment_ids[%d]:%d > segment_ids[%d]:%d.", + // dim_index_base + j - 1, dim_index_base + j, + // last_segment_id, current_segment_id); + + if (j > 0 && current_segment_id > last_segment_id) { + if (last_segment_id == first_segment_id) { + platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + } else { + *(summed_ids + last_segment_id) = sum; + } + sum = T(0); + } + sum += T(1); + last_segment_id = current_segment_id; + } + platform::CudaAtomicAdd(summed_ids + last_segment_id, sum); + } + // ensure last_segment_id is the largest + last_segment_id = output_length_size; + __syncthreads(); + T sum = T(0); + for (Index j = 0; j < actual_height; j++) { + Index current_segment_id = segment_ids[dim_index_base + j]; + if (current_segment_id > last_segment_id) { + const Index output_index = + last_segment_id * inner_dim_size + segment_offset; + if (last_segment_id == first_segment_id) { + platform::CudaAtomicAdd(output + output_index, + sum / *(summed_ids + last_segment_id)); + } else { + *(output + output_index) = sum / *(summed_ids + last_segment_id); + } + sum = T(0); + } + sum += input[(dim_index_base + j) * inner_dim_size + segment_offset]; + last_segment_id = current_segment_id; + } + const Index output_index = + last_segment_id * inner_dim_size + segment_offset; + platform::CudaAtomicAdd(output + output_index, + sum / *(summed_ids + last_segment_id)); + } +} + +template +__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input, + T* output, Helper h, Pool pool) { + CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) { + Index segment_offset, dim_index_base, actual_height; + Index inner_dim_size = h.inner_dim_size; + h.calculate(stripe_index, segment_offset, dim_index_base, actual_height); + + T minmax = pool.initial(); + Index first_segment_id = segment_ids[dim_index_base]; + // -1 is for the start value when interval_id = 0 + Index last_segment_id = -1; + if (dim_index_base > 0) { + last_segment_id = segment_ids[dim_index_base - 1]; + } + + for (Index j = 0; j < actual_height; j++) { + Index current_segment_id = segment_ids[dim_index_base + j]; + // ensure the segment_ids is sorted. + PADDLE_ENFORCE(current_segment_id >= last_segment_id, + "The segment ids should be sorted, but got " + "segment_ids[%d]:%d > segment_ids[%d]:%d.", + dim_index_base + j - 1, dim_index_base + j, + last_segment_id, current_segment_id); + + if (current_segment_id > last_segment_id) { + // reset the interval value which do not have corresponding ids. + for (Index interval_id = last_segment_id + 1; + interval_id < current_segment_id; ++interval_id) { + *(output + interval_id * inner_dim_size + segment_offset) = 0; + } + // don't update result when j=0 + if (j > 0) { + const Index output_index = + last_segment_id * inner_dim_size + segment_offset; + if (last_segment_id == first_segment_id) { + pool.atomic(output + output_index, minmax); + } else { + *(output + output_index) = minmax; + } + minmax = pool.initial(); + } + } + pool.compute( + input[(dim_index_base + j) * inner_dim_size + segment_offset], + &minmax); + last_segment_id = current_segment_id; + } + const Index output_index = + last_segment_id * inner_dim_size + segment_offset; + pool.atomic(output + output_index, minmax); + } +} + +template +__global__ void SegmentIndexGradKernel(const Index* segment_ids, const T* input, + const T* output, const T* out_grad, + T* in_grad, Helper h) { + CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) { + Index segment_offset, dim_index_base, actual_height; + h.calculate(stripe_index, segment_offset, dim_index_base, actual_height); + + for (Index j = 0; j < actual_height; j++) { + Index current_segment_id = segment_ids[dim_index_base + j]; + Index input_index = + (dim_index_base + j) * h.inner_dim_size + segment_offset; + Index output_index = + current_segment_id * h.inner_dim_size + segment_offset; + if (input[input_index] == output[output_index]) { + in_grad[input_index] = out_grad[output_index]; + } + } + } +} + +template +class MaxPool { + public: + DEVICE inline T initial() { return static_cast(-FLT_MAX); } + DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; } + DEVICE inline T atomic(T* address, const T val) { + return platform::CudaAtomicMax(address, val); + } +}; + +template +class MinPool { + public: + DEVICE inline T initial() { return static_cast(FLT_MAX); } + DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; } + DEVICE inline T atomic(T* address, const T val) { + return platform::CudaAtomicMin(address, val); + } +}; + +template +class SumPool { + public: + DEVICE inline T initial() { return static_cast(0); } + DEVICE inline void compute(const T& x, T* y) { *y = *y + x; } + DEVICE inline T atomic(T* address, const T val) { + return platform::CudaAtomicAdd(address, val); + } +}; + +template +class ArrangeHelper { + public: + const T input_total_size; + const T input_length_size; + const T output_length_size; + T inner_dim_size; + T total_stripe_count; + const T DimTileSize = 8; + + ArrangeHelper(T a, T b, T c) + : input_total_size(a), input_length_size(b), output_length_size(c) { + T input_outer_dim_num_stripe = + (input_length_size + DimTileSize - 1) / DimTileSize; + inner_dim_size = input_total_size / input_length_size; + total_stripe_count = inner_dim_size * input_outer_dim_num_stripe; + } + + DEVICE inline void calculate(T stripe_index, T& segment_offset, + T& dim_index_base, T& actual_height) { + segment_offset = stripe_index % inner_dim_size; + dim_index_base = stripe_index / inner_dim_size * DimTileSize; + actual_height = min(DimTileSize, input_length_size - dim_index_base); + } +}; + +template +void SegmentPoolCUDAGradFunctor(const platform::CUDADeviceContext& ctx, + const framework::Tensor& input, + const framework::Tensor& segment_ids, + const framework::Tensor& output, + const framework::Tensor& out_grad, + framework::Tensor* in_grad, + const std::string pooltype = "SUM") { + auto h = ArrangeHelper(input.numel(), segment_ids.dims()[0], + output.dims()[0]); + auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); + if (pooltype == "MAX" || pooltype == "MIN") { + SegmentIndexGradKernel><<< + config.block_per_grid.x, config.thread_per_block.x, 0, ctx.stream()>>>( + segment_ids.data(), input.data(), output.data(), + out_grad.data(), in_grad->data(), h); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported segment pooling grad operation, Only MAX, MIN " + "available, but got %s.", + pooltype)); + } +} + +template +__global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) { + for (int i = blockIdx.x; i < len; i += gridDim.x) { + __shared__ T y_i; + auto base = i * dim; + if (threadIdx.x == 0) { + y_i = y[i]; + } + __syncthreads(); + for (int j = threadIdx.x; j < dim; j += blockDim.x) { + x[base + j] /= y_i; + } + } +} + +template +class SegmentPoolFunctor { + public: + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& input, + const framework::Tensor& segment_ids, + framework::Tensor* output, + framework::Tensor* summed_ids = nullptr, + const std::string pooltype = "SUM") { + auto h = ArrangeHelper(input.numel(), segment_ids.dims()[0], + output->dims()[0]); + auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count); + if (pooltype == "MEAN") { + SegmentMeanCustomKernel< + T, IndexT, IndexT(8)><<>>( + segment_ids.data(), input.data(), output->data(), + summed_ids->data(), h.input_length_size, h.inner_dim_size, + h.output_length_size, h.total_stripe_count); + } else if (pooltype == "SUM") { + SumPool pool; + SegmentOpsKernel< + T, IndexT, ArrangeHelper, + SumPool><<>>(segment_ids.data(), + input.data(), output->data(), h, + pool); + } else if (pooltype == "MAX") { + MaxPool pool; + SegmentOpsKernel< + T, IndexT, ArrangeHelper, + MaxPool><<>>(segment_ids.data(), + input.data(), output->data(), h, + pool); + } else if (pooltype == "MIN") { + MinPool pool; + SegmentOpsKernel< + T, IndexT, ArrangeHelper, + MinPool><<>>(segment_ids.data(), + input.data(), output->data(), h, + pool); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN " + "available, but got %s.", + pooltype)); + } + } +}; + +template +class SegmentPoolGradFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& output, + const framework::Tensor& out_grad, + const framework::Tensor& segments, framework::Tensor* in_grad, + const framework::Tensor* summed_ids = nullptr, + const std::string pooltype = "SUM") { + if (pooltype == "MAX" || pooltype == "MIN") { + SegmentPoolCUDAGradFunctor(context, input, segments, output, + out_grad, in_grad, pooltype); + } else if (pooltype == "MEAN") { + framework::Tensor mean_grad; + mean_grad.mutable_data(input.dims(), context.GetPlace()); + framework::TensorCopy(out_grad, context.GetPlace(), context, &mean_grad); + int len = output.dims()[0]; + int dim = output.numel() / len; + auto config = platform::GetGpuLaunchConfig1D(context, len); + SimpleDiv<<>>(mean_grad.data(), + summed_ids->data(), len, dim); + GPUGather(context, mean_grad, segments, in_grad); + } else if (pooltype == "SUM") { + GPUGather(context, out_grad, segments, in_grad); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN " + "available, but got %s.", + pooltype)); + } + } +}; + +using CUDA = paddle::platform::CUDADeviceContext; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; +template class SegmentPoolGradFunctor; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/segment_pool_op.cu b/paddle/fluid/operators/segment_pool_op.cu new file mode 100644 index 00000000000..dc92d7fcc3a --- /dev/null +++ b/paddle/fluid/operators/segment_pool_op.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/gather.cu.h" +#include "paddle/fluid/operators/segment_pool_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_param_config.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + segment_pool, + ops::SegmentPoolKernel, + ops::SegmentPoolKernel); +REGISTER_OP_CUDA_KERNEL( + segment_pool_grad, + ops::SegmentPoolGradKernel, + ops::SegmentPoolGradKernel); diff --git a/paddle/fluid/operators/segment_pool_op.h b/paddle/fluid/operators/segment_pool_op.h index a505946b9f5..23b0c31608d 100644 --- a/paddle/fluid/operators/segment_pool_op.h +++ b/paddle/fluid/operators/segment_pool_op.h @@ -63,6 +63,46 @@ void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) { auto& dev_ctx = context.template device_context(); set_zero(dev_ctx, output, static_cast(0)); } +#ifdef PADDLE_WITH_CUDA + if (!cpu_place) { + Tensor length; + length.mutable_data(framework::make_ddim({1}), + platform::CPUPlace()); + IndexT* length_data = length.data(); + const IndexT* segment_ids = segment->data(); + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT), + cudaMemcpyDeviceToHost)); + + IndexT length_host = length_data[0]; + length_host++; + PADDLE_ENFORCE_GT( + length_host, 0, + platform::errors::InvalidArgument( + "Segment ids must be >= 0, but got last id %d", length_data[0])); + auto dims = input->dims(); + dims[0] = static_cast(length_host); + output->Resize({dims}); + output->mutable_data(context.GetPlace()); + T init_value = 0; + if (pooltype == "MAX") { + init_value = static_cast(-FLT_MAX); + } else if (pooltype == "MIN") { + init_value = static_cast(FLT_MAX); + } + math::SetConstant setconst; + auto& dev_ctx = context.template device_context(); + setconst(dev_ctx, output, static_cast(init_value)); + // the gpu kernel of mean pool record the counts of segment_ids + if (pooltype == "MEAN") { + summed_ids = context.Output("SummedIds"); + summed_ids->Resize({dims[0], 1}); + summed_ids->mutable_data(context.GetPlace()); + setconst(dev_ctx, summed_ids, static_cast(1e-12)); + } + } +#endif SegmentPoolFunctor pool; diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 67ea64833d3..f7c77071b12 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -128,5 +128,112 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { } #endif + +// For atomicMax +USE_CUDA_ATOMIC(Max, int); +USE_CUDA_ATOMIC(Max, unsigned int); +// CUDA API uses unsigned long long int, we cannot use uint64_t here. +// It because unsigned long long int is not necessarily uint64_t +USE_CUDA_ATOMIC(Max, unsigned long long int); // NOLINT + +CUDA_ATOMIC_WRAPPER(Max, int64_t) { + // Here, we check long long int must be int64_t. + static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT + "long long should be int64"); + return CudaAtomicMax( + reinterpret_cast(address), // NOLINT + static_cast(val)); // NOLINT +} + +CUDA_ATOMIC_WRAPPER(Max, float) { + if (*address >= val) { + return; + } + + int *const address_as_i = (int *)address; + int old = *address_as_i, assumed; + + do { + assumed = old; + if (__int_as_float(assumed) >= val) { + break; + } + + old = atomicCAS(address_as_i, assumed, __float_as_int(val)); + } while (assumed != old); +} + +CUDA_ATOMIC_WRAPPER(Max, double) { + if (*address >= val) { + return; + } + + unsigned long long int *const address_as_ull = + (unsigned long long int *)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + if (__longlong_as_double(assumed) >= val) { + break; + } + + old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val)); + } while (assumed != old); +} + +// For atomicMin +USE_CUDA_ATOMIC(Min, int); +USE_CUDA_ATOMIC(Min, unsigned int); +// CUDA API uses unsigned long long int, we cannot use uint64_t here. +// It because unsigned long long int is not necessarily uint64_t +USE_CUDA_ATOMIC(Min, unsigned long long int); // NOLINT + +CUDA_ATOMIC_WRAPPER(Min, int64_t) { + // Here, we check long long int must be int64_t. + static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT + "long long should be int64"); + return CudaAtomicMin( + reinterpret_cast(address), // NOLINT + static_cast(val)); // NOLINT +} + +CUDA_ATOMIC_WRAPPER(Min, float) { + if (*address <= val) { + return; + } + + int *const address_as_i = (int *)address; + int old = *address_as_i, assumed; + + do { + assumed = old; + if (__int_as_float(assumed) <= val) { + break; + } + + old = atomicCAS(address_as_i, assumed, __float_as_int(val)); + } while (assumed != old); +} + +CUDA_ATOMIC_WRAPPER(Min, double) { + if (*address <= val) { + return; + } + + unsigned long long int *const address_as_ull = + (unsigned long long int *)address; + unsigned long long int old = *address_as_ull, assumed; + + do { + assumed = old; + if (__longlong_as_double(assumed) <= val) { + break; + } + + old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val)); + } while (assumed != old); +} + } // namespace platform } // namespace paddle -- GitLab