/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX __FLT_MAX__

// ---------------------------------------------------------------------------
// Forward range functors.
//
// Each functor is invoked by every thread of one block and pools the rows
// [start, end) of `input` into a single output row. Row i of `input` occupies
// input[item_dim * i .. item_dim * i + item_dim - 1]. Threads walk the columns
// starting at threadIdx.x in steps of blockDim.x, so adjacent threads touch
// adjacent addresses. `output` (and `index`, where used) already point at the
// row belonging to the current sequence.
// ---------------------------------------------------------------------------

// Column-wise maximum over rows [start, end). Also records, per column, the
// row number holding the maximum (-1 if no row compared greater than the
// initial value). `index` must not be null.
// NOTE(review): the initial value is -FLT_MAX even when T is double, so
// double inputs below -FLT_MAX are never selected -- confirm this is intended.
template <typename T>
struct MaxPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      T max_val = static_cast<T>(-FLT_MAX);
      int max_index = -1;
      for (size_t i = start; i < end; ++i) {
        const T cur = input[item_dim * i + tid];
        if (max_val < cur) {
          max_val = cur;
          max_index = static_cast<int>(i);
        }
      }
      output[tid] = max_val;
      index[tid] = max_index;
    }
  }
};

// Column-wise arithmetic mean over rows [start, end). `index` is unused.
template <typename T>
struct AvgPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      T val = static_cast<T>(0);
      for (size_t i = start; i < end; ++i) {
        val += input[item_dim * i + tid];
      }
      // end, start is lod, so end - start != 0
      output[tid] = val / static_cast<T>(end - start);
    }
  }
};

// Column-wise sum over rows [start, end). `index` is unused.
template <typename T>
struct SumPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      T val = static_cast<T>(0);
      for (size_t i = start; i < end; ++i) {
        val += input[item_dim * i + tid];
      }
      output[tid] = val;
    }
  }
};

// Column-wise sum over rows [start, end) divided by sqrt(end - start).
// `index` is unused.
template <typename T>
struct SqrtPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      T val = static_cast<T>(0);
      for (size_t i = start; i < end; ++i) {
        val += input[item_dim * i + tid];
      }
      // end, start is lod, so end - start != 0
      // Cast to T before sqrt so the square root is taken in T's precision,
      // the same way SqrtPoolGradFunctor does it.
      output[tid] = val / sqrt(static_cast<T>(end - start));
    }
  }
};

// Copies row end - 1 to the output. Requires end > 0. `index` is unused.
template <typename T>
struct LastPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      output[tid] = input[item_dim * (end - 1) + tid];
    }
  }
};

// Copies row `start` to the output. `index` is unused.
template <typename T>
struct FirstPoolFunctor {
  HOSTDEVICE void operator()(const T* input, const size_t start,
                             const size_t end, const size_t item_dim, T* output,
                             int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      output[tid] = input[item_dim * start + tid];
    }
  }
};

// Forward kernel: block `bid` pools the rows [lod[bid], lod[bid + 1]) of
// `input` into output row `bid` using `op`.
//
// Launch layout: 1-D grid indexed by blockIdx.x; blocks whose index has no
// matching (lod[bid], lod[bid + 1]) pair return immediately. Any 1-D block
// size works because the functors stride over columns by blockDim.x.
// `index` may be null; when it is not, `op` receives the slice starting at
// index[bid * item_dim].
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op, const T* input,
                                     const size_t* lod, const size_t lod_size,
                                     const size_t item_dim, T* output,
                                     int* index) {
  const size_t bid = blockIdx.x;
  // lod holds lod_size offsets, i.e. lod_size - 1 ranges. Written as
  // bid + 1 >= lod_size so that lod_size == 0 cannot wrap around.
  if (bid + 1 >= lod_size) return;

  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(input, start, end, item_dim, &output[bid * item_dim], index_offset);
}

// Host-side dispatcher for forward sequence pooling on CUDA.
//
// Reads the level-0 LoD of `input`, derives item_dim from `output`
// (numel / dims()[0]) and launches sequence_pool_kernel with lod.size()
// blocks of 1024 threads on the stream of `context`.
//
// pooltype must be one of "MAX", "AVERAGE", "SUM", "SQRT", "LAST", "FIRST";
// anything else raises via PADDLE_THROW. "MAX" additionally requires a
// non-null `index` tensor, which receives the arg-max row per column.
template <typename T>
class SequencePoolFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype, const framework::LoDTensor& input,
                  framework::Tensor* output,
                  framework::Tensor* index = nullptr) {
    auto lod = input.lod()[0];
    const size_t item_dim = output->numel() / output->dims()[0];
    dim3 threads(1024, 1);
    dim3 grid(lod.size(), 1);
    if (pooltype == "MAX") {
      // `index` defaults to nullptr; fail loudly instead of dereferencing it.
      if (index == nullptr) {
        PADDLE_THROW("MAX pooling requires a non-null index tensor");
      }
      sequence_pool_kernel<
          T, MaxPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          MaxPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), index->data<int>());
    } else if (pooltype == "AVERAGE") {
      sequence_pool_kernel<
          T, AvgPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          AvgPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "SUM") {
      sequence_pool_kernel<
          T, SumPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          SumPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "SQRT") {
      sequence_pool_kernel<
          T, SqrtPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          SqrtPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "LAST") {
      sequence_pool_kernel<
          T, LastPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          LastPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "FIRST") {
      sequence_pool_kernel<
          T, FirstPoolFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          FirstPoolFunctor<T>(), input.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          output->mutable_data<T>(context.GetPlace()), nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }
};

// ---------------------------------------------------------------------------
// Backward range functors.
//
// Each functor is invoked by every thread of one block. `out_grad` points at
// the gradient row of the current sequence; `in_grad` is the base of the whole
// input-gradient buffer and is written at rows [start, end). Threads stride
// over columns exactly like the forward functors.
// ---------------------------------------------------------------------------

// Routes the output gradient to the row recorded in `index` for each column
// and writes zero to every other row of the range. `index` must not be null.
template <typename T>
struct MaxPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        if (static_cast<int>(i) == index[tid]) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

// Writes out_grad / (end - start) to every row of the range.
template <typename T>
struct AvgPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        // Explicit cast keeps the division in T, mirroring AvgPoolFunctor.
        in_grad[item_dim * i + tid] =
            out_grad[tid] / static_cast<T>(end - start);
      }
    }
  }
};

// Copies out_grad to every row of the range.
template <typename T>
struct SumPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        in_grad[item_dim * i + tid] = out_grad[tid];
      }
    }
  }
};

// Writes out_grad / sqrt(end - start) to every row of the range.
template <typename T>
struct SqrtPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        in_grad[item_dim * i + tid] =
            out_grad[tid] / (sqrt(static_cast<T>(end - start)));
      }
    }
  }
};

// Copies out_grad to row end - 1 and writes zero to the other rows.
template <typename T>
struct LastPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        if (i == end - 1) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

// Copies out_grad to row `start` and writes zero to the other rows.
template <typename T>
struct FirstPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad, const size_t start,
                             const size_t end, const size_t item_dim,
                             T* in_grad, const int* index) {
    for (size_t tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (size_t i = start; i < end; ++i) {
        if (i == start) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

// Backward kernel: block `bid` scatters gradient row `bid` of `out_grad` into
// the rows [lod[bid], lod[bid + 1]) of `in_grad` using `op`.
//
// Launch layout matches sequence_pool_kernel: 1-D grid indexed by blockIdx.x,
// surplus blocks return immediately. `index` may be null; when it is not,
// `op` receives the slice starting at index[bid * item_dim].
template <typename T, typename Range_OP>
__global__ void sequence_pool_grad_kernel(Range_OP op, const T* out_grad,
                                          const size_t* lod,
                                          const size_t lod_size,
                                          const size_t item_dim, T* in_grad,
                                          const int* index) {
  const size_t bid = blockIdx.x;
  // Same guard as the forward kernel; safe for lod_size == 0.
  if (bid + 1 >= lod_size) return;

  const size_t start = lod[bid];
  const size_t end = lod[bid + 1];
  const int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
}

// Host-side dispatcher for the backward pass of sequence pooling on CUDA.
//
// Reads the level-0 LoD of `in_grad`, derives item_dim from `in_grad`
// (numel / dims()[0]) and launches sequence_pool_grad_kernel with lod.size()
// blocks of 1024 threads on the stream of `context`.
//
// pooltype accepts the same six strings as SequencePoolFunctor; anything else
// raises via PADDLE_THROW. "MAX" additionally requires the non-null `index`
// tensor produced by the forward pass.
template <typename T>
class SequencePoolGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const std::string pooltype, const framework::Tensor& out_grad,
                  framework::LoDTensor* in_grad,
                  /* max pool has index */
                  const framework::Tensor* index = nullptr) {
    auto lod = in_grad->lod()[0];
    const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
    dim3 threads(1024, 1);
    dim3 grid(lod.size(), 1);
    if (pooltype == "MAX") {
      // `index` defaults to nullptr; fail loudly instead of dereferencing it.
      if (index == nullptr) {
        PADDLE_THROW("MAX pooling requires a non-null index tensor");
      }
      sequence_pool_grad_kernel<
          T, MaxPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          MaxPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), index->data<int>());
    } else if (pooltype == "AVERAGE") {
      sequence_pool_grad_kernel<
          T, AvgPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          AvgPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "SUM") {
      sequence_pool_grad_kernel<
          T, SumPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          SumPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "SQRT") {
      sequence_pool_grad_kernel<
          T, SqrtPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          SqrtPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "LAST") {
      sequence_pool_grad_kernel<
          T, LastPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          LastPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
    } else if (pooltype == "FIRST") {
      sequence_pool_grad_kernel<
          T, FirstPoolGradFunctor<T>><<<grid, threads, 0, context.stream()>>>(
          FirstPoolGradFunctor<T>(), out_grad.data<T>(),
          lod.CUDAData(context.GetPlace()), lod.size(), item_dim,
          in_grad->mutable_data<T>(context.GetPlace()), nullptr);
    } else {
      PADDLE_THROW("unsupported pooling pooltype");
    }
  }
};

// sequence pooling
template class SequencePoolFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolFunctor<platform::CUDADeviceContext, double>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, float>;
template class SequencePoolGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle