diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 87efb900cd59e6adeb051e0e458f2b86c1b510c9..43eb4de2c1d4ff61f295c05af9dfd5f26c2ba365 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -55,6 +55,12 @@ function(op_library TARGET) set(pybind_flag 1) endif() + if ("${TARGET}" STREQUAL "pool_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(pool2d);\n") + endif() + # activation_op contains several operators if ("${TARGET}" STREQUAL "activation_op") set(pybind_flag 1) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index b39d4f0ac27bf0a8378344f852a602c5ecf4cf6a..a0ceb029e3abee2fe591325ffa3100168c3aa8e3 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,12 +1,10 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc - im2col.cu DEPS cblas device_context operator) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) else() - cc_library(math_function SRCS math_function.cc im2col.cc - DEPS cblas device_context operator) + cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b706529d8f1ed0d673904b81047a5614bd4cf23 --- /dev/null +++ b/paddle/operators/math/pooling.cc @@ -0,0 +1,463 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class Pool2dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + T ele = pool_process.initial(); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + pool_process.compute(ele, input_data[h * input_width + w]); + } + } + int pool_size = (hend - hstart) * (wend - wstart); + pool_process.finalize(ele, (static_cast(pool_size))); + output_data[ph * output_width + pw] = ele; + } + } + input_data += input_stride; + output_data += output_stride; + } + } + } +}; + +template +class Pool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_grad_process) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + float scale = 1.0 / pool_size; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + pool_grad_process.compute( + input_data[h * input_width + w], + output_data[ph * output_width + pw], + output_grad_data[ph * output_width + pw], + input_grad_data[h * input_width + w], + static_cast(scale)); + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } +}; + +template +class MaxPool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + const int input_stride = input_height * input_width; + const int output_stride = output_height * output_width; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + bool stop = false; + for (int h = hstart; h < hend && !stop; ++h) { + for (int w = wstart; w < wend && !stop; ++w) { + int input_idx = h * input_width + w; + int output_idx = ph * output_width + pw; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + stop = true; + } + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } +}; + +template class MaxPool2dGradFunctor; +// template class MaxPool2dGradFunctor; + +template class Pool2dFunctor, float>; +template class Pool2dFunctor, float>; +template class Pool2dGradFunctor< + platform::CPUPlace, paddle::operators::math::MaxPoolGrad, float>; +template class Pool2dGradFunctor< + platform::CPUPlace, paddle::operators::math::AvgPoolGrad, float>; +template class Pool2dFunctor, double>; +template class Pool2dFunctor, double>; +template class Pool2dGradFunctor< + platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; +template class Pool2dGradFunctor< + platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +template +class Pool3dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + int output_idx = (pd * output_height + ph) * output_width + pw; + T ele = pool_process.initial(); + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + pool_process.compute( + ele, + input_data[(d * input_height + h) * input_width + w]); + } + } + } + int pool_size = + (dend - dstart) * (hend - hstart) * (wend - wstart); + pool_process.finalize(ele, static_cast(pool_size)); + output_data[output_idx] = ele; + } + } + } + input_data += input_stride; + output_data += output_stride; + } + } + } +}; + +template +class Pool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_grad_process) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + + int pool_size = + (dend - dstart) * (hend - hstart) * (wend - wstart); + float scale = 1.0 / pool_size; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + int output_idx = + (pd * output_height + ph) * output_width + pw; + pool_grad_process.compute( + input_data[input_idx], output_data[output_idx], + output_grad_data[output_idx], + input_grad_data[input_idx], static_cast(scale)); + } + } + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } +}; + +template +class MaxPool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + const int input_stride = input_depth * input_height * input_width; + const int output_stride = output_depth * output_height * output_width; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + for (int c = 0; c < output_channels; ++c) { + for (int pd = 0; pd < output_depth; ++pd) { + int dstart = pd * stride_depth - padding_depth; + int dend = std::min(dstart + ksize_depth, input_depth); + dstart = std::max(dstart, 0); + for (int ph = 0; ph < output_height; ++ph) { + int hstart = ph * stride_height - padding_height; + int hend = std::min(hstart + ksize_height, input_height); + hstart = std::max(hstart, 0); + for (int pw = 0; pw < output_width; ++pw) { + int wstart = pw * stride_width - padding_width; + int wend = std::min(wstart + ksize_width, input_width); + wstart = std::max(wstart, 0); + bool stop = false; + for (int d = dstart; d < dend && !stop; ++d) { + for (int h = hstart; h < hend && !stop; ++h) { + for (int w = wstart; w < wend && !stop; ++w) { + int input_idx = (d * input_height + h) * input_width + w; + int output_idx = + (pd * output_height + ph) * output_width + pw; + + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += + output_grad_data[output_idx]; + stop = true; + } + } + } + } + } + } + } + input_data += input_stride; + output_data += output_stride; + input_grad_data += input_stride; + output_grad_data += output_stride; + } + } + } +}; + +template class MaxPool3dGradFunctor; +// template class MaxPool3dGradFunctor; + +template class Pool3dFunctor, float>; +template class Pool3dFunctor, float>; +template class Pool3dGradFunctor< + platform::CPUPlace, paddle::operators::math::MaxPoolGrad, float>; +template class Pool3dGradFunctor< + platform::CPUPlace, paddle::operators::math::AvgPoolGrad, float>; +template class Pool3dFunctor, double>; +template class Pool3dFunctor, double>; +template class Pool3dGradFunctor< + platform::CPUPlace, paddle::operators::math::MaxPoolGrad, double>; +template class Pool3dGradFunctor< + platform::CPUPlace, paddle::operators::math::AvgPoolGrad, double>; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu new file mode 100644 index 0000000000000000000000000000000000000000..8aeccd1f8e8855c51ad85016f0cb239b4c9c8fb0 --- /dev/null +++ b/paddle/operators/math/pooling.cu @@ -0,0 +1,635 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/pooling.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void KernelPool2D(const int nthreads, const T* input_data, + T* output_data, const int channels, + const int input_height, const int input_width, + const int output_height, const int output_width, + const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, + const int padding_height, const int padding_width, + PoolProcess pool_process) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + T ele = pool_process.initial(); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + pool_process.compute(ele, input_data[h * input_width + w]); + } + } + int pool_size = (hend - hstart) * (wend - wstart); + pool_process.finalize(ele, (static_cast(pool_size))); + output_data[index] = ele; + } +} + +template +__global__ void KernelPool2DGrad( + const int nthreads, const T* input_data, const T* output_data, + const T* output_grad, T* input_grad, const int channels, + const int input_height, const int input_width, const int output_height, + const int output_width, const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, const int padding_height, + const int padding_width, PoolProcess pool_process) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int offsetW = index % input_width + padding_width; + int offsetH = (index / input_width) % input_height + padding_height; + int offsetC = (index / input_width / input_height) % channels; + int batch_idx = index / input_width / input_height / channels; + + int phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + int pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + int phend = min(offsetH / stride_height + 1, output_height); + int pwend = min(offsetW / stride_width + 1, output_width); + T gradient = 0; + T input = input_data[index]; + int output_idx = + (batch_idx * channels + offsetC) * output_height * output_width; + output_data += output_idx; + output_grad += output_idx; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + int output_sub_idx = ph * output_width + pw; + pool_process.compute(input, output_data[output_sub_idx], + output_grad[output_sub_idx], gradient, + static_cast(1.0 / pool_size)); + } + } + input_grad[index] = gradient; + } +} + +template +__global__ void KernelMaxPool2DGrad( + const int nthreads, const T* input_data, const T* output_data, + const T* output_grad, T* input_grad, const int channels, + const int input_height, const int input_width, const int output_height, + const int output_width, const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int c = (index / output_width / output_height) % channels; + int batch_idx = index / output_width / output_height / channels; + + int hstart = ph * stride_height - padding_height; + int hend = min(hstart + ksize_height, input_height); + hstart = max(hstart, 0); + + int wstart = pw * stride_width - padding_width; + int wend = min(wstart + ksize_width, input_width); + wstart = max(wstart, 0); + + input_data += (batch_idx * channels + c) * input_height * input_width; + input_grad += (batch_idx * channels + c) * input_height * input_width; + + T ele = output_data[index]; + int maxIndex = -1; + bool stop = false; + for (int h = hstart; h < hend && !stop; ++h) { + for (int w = wstart; w < wend && !stop; ++w) { + if (ele == input_data[h * input_width + w]) { + maxIndex = h * input_width + w; + stop = true; + } + } + } + + if (maxIndex != -1) { + // atomic add + atomicAdd(input_grad + maxIndex, output_grad[index]); + } + } +} + +template +class Pool2dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool2D< + PoolProcess, + T><<(context) + .stream()>>>(nthreads, input_data, output_data, input_channels, + input_height, input_width, output_height, + output_width, ksize_height, ksize_width, + stride_height, stride_width, padding_height, + padding_width, pool_process); + } +}; + +template +class Pool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * input_channels * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool2DGrad< + PoolProcess, + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, output_grad_data, input_grad_data, + input_channels, input_height, input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, stride_width, padding_height, + padding_width, pool_process); + } +}; + +template +class MaxPool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const int ksize_height = ksize[0]; + const int ksize_width = ksize[1]; + const int stride_height = strides[0]; + const int stride_width = strides[1]; + const int padding_height = paddings[0]; + const int padding_width = paddings[1]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool2DGrad< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, output_grad_data, input_grad_data, + input_channels, input_height, input_width, output_height, output_width, + ksize_height, ksize_width, stride_height, stride_width, padding_height, + padding_width); + } +}; + +template class MaxPool2dGradFunctor; +// template class MaxPool2dGradFunctor; // The +// 64-bit floating-point version of atomicAdd() is only supported by devices of +// compute capability 6.x and higher. + +template class Pool2dFunctor, float>; +template class Pool2dFunctor, float>; +template class Pool2dGradFunctor< + platform::GPUPlace, paddle::operators::math::MaxPoolGrad, float>; +template class Pool2dGradFunctor< + platform::GPUPlace, paddle::operators::math::AvgPoolGrad, float>; +template class Pool2dFunctor, double>; +template class Pool2dFunctor, double>; +template class Pool2dGradFunctor< + platform::GPUPlace, paddle::operators::math::MaxPoolGrad, double>; +template class Pool2dGradFunctor< + platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +template +__global__ void KernelPool3D( + const int nthreads, const T* input_data, T* output_data, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + PoolProcess pool_process) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + T ele = pool_process.initial(); + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + for (int d = dstart; d < dend; ++d) { + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + pool_process.compute( + ele, input_data[(d * input_height + h) * input_width + w]); + } + } + } + int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + pool_process.finalize(ele, static_cast(pool_size)); + output_data[index] = ele; + } +} + +template +__global__ void KernelPool3DGrad( + const int nthreads, const T* input_data, const T* output_data, + const T* output_grad, T* input_grad, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + PoolProcess pool_process) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int offsetW = index % input_width + padding_width; + int offsetH = (index / input_width) % input_height + padding_height; + int offsetD = + (index / input_width / input_height) % input_depth + padding_depth; + int offsetC = (index / input_width / input_height / input_depth) % channels; + int batch_idx = index / input_width / input_height / input_depth / channels; + + int pdstart = (offsetD < ksize_depth) + ? 0 + : (offsetD - ksize_depth) / stride_depth + 1; + int phstart = (offsetH < ksize_height) + ? 0 + : (offsetH - ksize_height) / stride_height + 1; + int pwstart = (offsetW < ksize_width) + ? 0 + : (offsetW - ksize_width) / stride_width + 1; + int pdend = min((offsetD) / stride_depth + 1, output_depth); + int phend = min((offsetH) / stride_height + 1, output_height); + int pwend = min((offsetW) / stride_width + 1, output_width); + + T gradient = 0; + T input = input_data[index]; + int output_idx = (batch_idx * channels + offsetC) * output_depth * + output_height * output_width; + output_data += output_idx; + output_grad += output_idx; + + for (int pd = pdstart; pd < pdend; ++pd) { + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + int output_sub_idx = (pd * output_height + ph) * output_width + pw; + pool_process.compute(input, output_data[output_sub_idx], + output_grad[output_sub_idx], gradient, + static_cast(1.0 / pool_size)); + } + } + } + input_grad[index] = gradient; + } +} + +template +__global__ void KernelMaxPool3DGrad( + const int nthreads, const T* input_data, const T* output_data, + const T* output_grad, T* input_grad, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, + const int padding_width) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int pw = index % output_width; + int ph = (index / output_width) % output_height; + int pd = (index / output_width / output_height) % output_depth; + int c = (index / output_width / output_height / output_depth) % channels; + int batch_idx = + index / output_width / output_height / output_depth / channels; + int dstart = pd * stride_depth - padding_depth; + int hstart = ph * stride_height - padding_height; + int wstart = pw * stride_width - padding_width; + int dend = min(dstart + ksize_depth, input_depth); + int hend = min(hstart + ksize_height, input_height); + int wend = min(wstart + ksize_width, input_width); + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + T ele = output_data[index]; + bool stop = false; + int maxIdx = -1; + input_data += + (batch_idx * channels + c) * input_depth * input_height * input_width; + input_grad += + (batch_idx * channels + c) * input_depth * input_height * input_width; + + for (int d = dstart; d < dend && !stop; ++d) { + for (int h = hstart; h < hend && !stop; ++h) { + for (int w = wstart; w < wend && !stop; ++w) { + if (ele == input_data[(d * input_height + h) * input_width + w]) { + stop = true; + maxIdx = (d * input_height + h) * input_width + w; + } + } + } + } + if (maxIdx != -1) { + // atomic add + atomicAdd(input_grad + maxIdx, output_grad[index]); + } + } +} + +template +class Pool3dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool3D< + PoolProcess, + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, input_channels, input_depth, + input_height, input_width, output_depth, output_height, output_width, + ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + pool_process); + } +}; + +template +class Pool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = + batch_size * input_channels * input_depth * input_height * input_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelPool3DGrad< + PoolProcess, + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, output_grad_data, input_grad_data, + input_channels, input_depth, input_height, input_width, output_depth, + output_height, output_width, ksize_depth, ksize_height, ksize_width, + stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width, pool_process); + } +}; + +template +class MaxPool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_depth = input.dims()[2]; + const int input_height = input.dims()[3]; + const int input_width = input.dims()[4]; + const int output_channels = output.dims()[1]; + const int output_depth = output.dims()[2]; + const int output_height = output.dims()[3]; + const int output_width = output.dims()[4]; + const int ksize_depth = ksize[0]; + const int ksize_height = ksize[1]; + const int ksize_width = ksize[2]; + const int stride_depth = strides[0]; + const int stride_height = strides[1]; + const int stride_width = strides[2]; + const int padding_depth = paddings[0]; + const int padding_height = paddings[1]; + const int padding_width = paddings[2]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_depth * output_height * + output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxPool3DGrad< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, output_grad_data, input_grad_data, + input_channels, input_depth, input_height, input_width, output_depth, + output_height, output_width, ksize_depth, ksize_height, ksize_width, + stride_depth, stride_height, stride_width, padding_depth, + padding_height, padding_width); + } +}; + +template class MaxPool3dGradFunctor; +// template class MaxPool3dGradFunctor; // The +// 64-bit floating-point version of atomicAdd() is only supported by devices of +// compute capability 6.x and higher. + +template class Pool3dFunctor, float>; +template class Pool3dFunctor, float>; +template class Pool3dGradFunctor< + platform::GPUPlace, paddle::operators::math::MaxPoolGrad, float>; +template class Pool3dGradFunctor< + platform::GPUPlace, paddle::operators::math::AvgPoolGrad, float>; +template class Pool3dFunctor, double>; +template class Pool3dFunctor, double>; +template class Pool3dGradFunctor< + platform::GPUPlace, paddle::operators::math::MaxPoolGrad, double>; +template class Pool3dGradFunctor< + platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..d214c689235ad4233d3e4e1c2aa0fdc993bf20c6 --- /dev/null +++ b/paddle/operators/math/pooling.h @@ -0,0 +1,122 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { +////////////////////// +#define FLT_MAX __FLT_MAX__ // + +template +class MaxPool { + public: + DEVICE inline T initial() { return static_cast(-FLT_MAX); } + DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } + DEVICE inline void finalize(T& y, const T& poo_size) {} +}; + +template +class AvgPool { + public: + DEVICE inline T initial() { return static_cast(0); } + DEVICE inline void compute(T& y, const T& x) { y += x; } + DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; } +}; +template +class MaxPoolGrad { + public: + DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, + T scale) { + dx += dy * (x == y); + } +}; + +template +class AvgPoolGrad { + public: + DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, + T scale) { + dx += (scale * dy); + } +}; + +template +class Pool2dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_compute); +}; + +template +class Pool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute); +}; + +template +class MaxPool2dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +template +class Pool3dFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + std::vector& ksize, std::vector& strides, + std::vector& paddings, PoolProcess pool_compute); +}; + +template +class Pool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute); +}; + +template +class MaxPool3dGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, std::vector& ksize, + std::vector& strides, std::vector& paddings); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c29f51f05613832c838400eb114465c81290ea58 --- /dev/null +++ b/paddle/operators/pool_op.cc @@ -0,0 +1,195 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_op.h" + +namespace paddle { +namespace operators { + +int OutputSizePool(int input_size, int filter_size, int padding, int stride) { + int output_size = (input_size - filter_size + 2 * padding) / stride + 1; + return output_size; +} + +class PoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Out(Output) of Pooling should not be null."); + + auto in_x_dims = ctx->GetInputDim("X"); + + std::string pooling_type = ctx->Attrs().Get("poolingType"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + + PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg", + "pooling_type should be 'max' or 'avg'"); + PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, + "Pooling intput should be 4-D or 5-D"); + + if (ctx->Attrs().Get("globalPooling")) { + ksize.resize(static_cast(in_x_dims.size()) - 2); + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x_dims[i + 2]); + } + + PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U, + "Input size and Pooling size should be consistent."); + PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, + "Pooling size should be 2 elements. or 3 elements."); + PADDLE_ENFORCE_EQ(ksize.size(), strides.size(), + "strides size and pooling size should be the same."); + PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(), + "paddings size and pooling size should be the same."); + + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back( + OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class PoolOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "X(Input) of Pooling should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input@Grad of Pooling should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCHW."); + + AddAttr("poolingType", + "PoolingType of pooling operator." + "Str constant equal to 'max' or 'avg'.") + .InEnum({"max", "avg"}); + AddAttr>( + "ksize", + "Pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Add checker) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "Strides(height, width) of pooling operator." + "Default {1,1}") + .SetDefault({1, 1}); // TODO(Add checker) + AddAttr>("paddings", + "Paddings(height, width) of pooling operator." + "Default {0,0}.") + .SetDefault({0, 0}); // TODO(Add checker) + AddComment(R"DOC( +The pooling2d operation calculates the output based on +the input, poolingType and ksize, strides, paddings parameters. +)DOC"); + } +}; + +class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the " + "number of channels, D, H and W is the depth, height and width of " + "feature."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCDHW."); + + AddAttr("poolingType", + "PoolingType of pooling operator." + "str constant equal to 'max' or 'avg'.") + .InEnum({"max", "avg"}); + AddAttr>( + "ksize", + "Pooling size(depth, height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " + "specified."); // TODO(Add checker) + AddAttr( + "globalPooling", + "Whether to use the globalPooling." + "Bool constant equal to false or true." + "Default false." + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "Strides(depth, height, width) of pooling operator." + "Default {1,1,1}.") + .SetDefault({1, 1, 1}); // TODO(Add checker) + AddAttr>( + "paddings", + "Paddings(depth, height, width) of pooling operator." + "Default {0,0,0}.") + .SetDefault({0, 0, 0}); // TODO(Add checker) + AddComment(R"DOC( +The pooling3d operation calculates the output based on +the input, poolingType and ksize, strides, paddings parameters. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool2d, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool2d_grad, + ops::PoolGradKernel) + +REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool3d, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool3d_grad, + ops::PoolGradKernel); diff --git a/paddle/operators/pool_op.cu b/paddle/operators/pool_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..0e3b80868f7b9d1697d619889160856d65ad59a3 --- /dev/null +++ b/paddle/operators/pool_op.cu @@ -0,0 +1,27 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(pool2d, + ops::PoolKernel); +REGISTER_OP_GPU_KERNEL(pool2d_grad, + ops::PoolGradKernel); + +REGISTER_OP_GPU_KERNEL(pool3d, + ops::PoolKernel); +REGISTER_OP_GPU_KERNEL(pool3d_grad, + ops::PoolGradKernel); diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c2bc358def42959f2cc8f61cb00436fae1b7514b --- /dev/null +++ b/paddle/operators/pool_op.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/pooling.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class PoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + + std::string pooling_type = context.Attr("poolingType"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + } + + switch (ksize.size()) { + case 2: { + if (pooling_type == "max") { + paddle::operators::math::Pool2dFunctor< + Place, paddle::operators::math::MaxPool, T> + pool2d_forward; + paddle::operators::math::MaxPool pool_process; + pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, + paddings, pool_process); + + } else if (pooling_type == "avg") { + paddle::operators::math::Pool2dFunctor< + Place, paddle::operators::math::AvgPool, T> + pool2d_forward; + paddle::operators::math::AvgPool pool_process; + pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, + paddings, pool_process); + } + } break; + case 3: { + if (pooling_type == "max") { + paddle::operators::math::Pool3dFunctor< + Place, paddle::operators::math::MaxPool, T> + pool3d_forward; + paddle::operators::math::MaxPool pool_process; + pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, + paddings, pool_process); + } else if (pooling_type == "avg") { + paddle::operators::math::Pool3dFunctor< + Place, paddle::operators::math::AvgPool, T> + pool3d_forward; + paddle::operators::math::AvgPool pool_process; + pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, + paddings, pool_process); + } + } break; + } + } +}; + +template +class PoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + const Tensor* out = context.Input("Out"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + std::string pooling_type = context.Attr("poolingType"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + if (context.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + auto temp = framework::EigenVector::Flatten(*in_x_grad); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + + switch (ksize.size()) { + case 2: { + if (pooling_type == "max") { + paddle::operators::math::MaxPool2dGradFunctor + pool2d_backward; + pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, + *out_grad, ksize, strides, paddings); + } else if (pooling_type == "avg") { + paddle::operators::math::Pool2dGradFunctor< + Place, paddle::operators::math::AvgPoolGrad, T> + pool2d_backward; + paddle::operators::math::AvgPoolGrad pool_process; + pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, + *out_grad, ksize, strides, paddings, pool_process); + } + } break; + case 3: { + if (pooling_type == "max") { + paddle::operators::math::MaxPool3dGradFunctor + pool3d_backward; + pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, + *out_grad, ksize, strides, paddings); + } else if (pooling_type == "avg") { + paddle::operators::math::Pool3dGradFunctor< + Place, paddle::operators::math::AvgPoolGrad, T> + pool3d_backward; + paddle::operators::math::AvgPoolGrad pool_process; + pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, + *out_grad, ksize, strides, paddings, pool_process); + } + } break; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/platform/hostdevice.h b/paddle/platform/hostdevice.h index e7de86b7b2f75d206e730ec409bbee5d0a08942e..eb2df291cceef553d6422e6166e1fef2c63e2a47 100644 --- a/paddle/platform/hostdevice.h +++ b/paddle/platform/hostdevice.h @@ -2,8 +2,10 @@ #ifdef __CUDACC__ #define HOSTDEVICE __host__ __device__ +#define DEVICE __device__ #define HOST __host__ #else #define HOSTDEVICE +#define DEVICE #define HOST #endif diff --git a/python/paddle/v2/framework/tests/test_pool2d_op.py b/python/paddle/v2/framework/tests/test_pool2d_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2941fda81b23998072810d8c6f6597a6f3db7e30 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool2d_op.py @@ -0,0 +1,144 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + return out + + +def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, H, W = x.shape + if global_pool == 1: + ksize = [H, W] + H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + out = np.zeros((N, C, H_out, W_out)) + for i in xrange(H_out): + for j in xrange(W_out): + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + c_start = np.max((j * strides[1] - paddings[1], 0)) + c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, r_start:r_end, c_start:c_end] + + out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / ( + (r_end - r_start) * (c_end - c_start)) + return out + + +class TestPool2d_Op(OpTest): + def setUp(self): + self.initTestCase() + input = np.random.random(self.shape).astype("float32") + output = self.pool2D_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + self.inputs = {'X': input} + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, + } + + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + if self.pool_type != "max": + self.check_grad(set(['X']), 'Out', max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = True + self.op_type = "pool2d" + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + +class TestCase1(TestPool2d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool2d" + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + +class TestCase2(TestPool2d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool2d" + self.pool_type = "avg" + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +class TestCase3(TestPool2d_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "pool2d" + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + +class TestCase4(TestPool2d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool2d" + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + +class TestCase5(TestPool2d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool2d" + self.pool_type = "max" + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool3d_op.py b/python/paddle/v2/framework/tests/test_pool3d_op.py new file mode 100644 index 0000000000000000000000000000000000000000..8792b492e3da6541f71185be82b8bfc4f52d821d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_pool3d_op.py @@ -0,0 +1,152 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[0] - paddings[0], 0)) + h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[1] - paddings[1], 0)) + w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + return out + + +def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): + + N, C, D, H, W = x.shape + if global_pool == 1: + ksize = [D, H, W] + D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1 + H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + out = np.zeros((N, C, D_out, H_out, W_out)) + for k in xrange(D_out): + d_start = np.max((k * strides[0] - paddings[0], 0)) + d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D)) + for i in xrange(H_out): + h_start = np.max((i * strides[0] - paddings[0], 0)) + h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H)) + for j in xrange(W_out): + w_start = np.max((j * strides[1] - paddings[1], 0)) + w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W)) + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end] + + out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / ( + (d_end - d_start) * (h_end - h_start) * (w_end - w_start)) + return out + + +class TestPool3d_Op(OpTest): + def setUp(self): + self.initTestCase() + input = np.random.random(self.shape).astype("float32") + output = self.pool3D_forward_naive(input, self.ksize, self.strides, + self.paddings, self.global_pool) + self.inputs = {'X': input} + + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, + } + + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + if self.pool_type != "max": + self.check_grad(set(['X']), 'Out', max_relative_error=0.07) + + def initTestCase(self): + self.global_pool = True + self.op_type = "pool3d" + self.pool_type = "avg" + self.pool3D_forward_naive = avg_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [0, 0, 0] + + +class TestCase1(TestPool3d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool3d" + self.pool_type = "avg" + self.pool3D_forward_naive = avg_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [0, 0, 0] + + +class TestCase2(TestPool3d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool3d" + self.pool_type = "avg" + self.pool3D_forward_naive = avg_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +class TestCase3(TestPool3d_Op): + def initTestCase(self): + self.global_pool = True + self.op_type = "pool3d" + self.pool_type = "max" + self.pool3D_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 5, 5, 5] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [0, 0, 0] + + +class TestCase4(TestPool3d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool3d" + self.pool_type = "max" + self.pool3D_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [0, 0, 0] + + +class TestCase5(TestPool3d_Op): + def initTestCase(self): + self.global_pool = False + self.op_type = "pool3d" + self.pool_type = "max" + self.pool3D_forward_naive = max_pool3D_forward_naive + self.shape = [2, 3, 7, 7, 7] + self.ksize = [3, 3, 3] + self.strides = [1, 1, 1] + self.paddings = [1, 1, 1] + + +if __name__ == '__main__': + unittest.main()