diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index e56895c63a426b782f7b46091bc86c367d49899d..0feb969c623b800ab80142374c4d9e3d56833d22 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -62,6 +62,12 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
   endif()
 
+  if ("${TARGET}" STREQUAL "pool_with_index_op")
+    set(pybind_flag 1)
+    # Adding one of the operators registered in this file to pybind is enough.
+    file(APPEND ${pybind_file} "USE_OP(maxPool2dWithIndex);\n")
+  endif()
+
   # pybind USE_NO_KERNEL_OP
   file(READ ${TARGET}.cc TARGET_CONTENT)
   string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 91ae3d49f1df51d9524547f7765285bff9dbb5c5..811deb4c2c96fe0bddfeac7ae0dc773816bdc2d0 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,12 +1,12 @@
 if(WITH_GPU)
     nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
-    im2col.cu DEPS cblas device_context operator)
+    im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
     nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
                DEPS operator)
 else()
-    cc_library(math_function SRCS math_function.cc im2col.cc
+    cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc
        DEPS cblas device_context operator)
     cc_library(softmax_function SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0e4d9007a63cd342668d5e6fa622deea06f90d9c
--- /dev/null
+++ b/paddle/operators/math/pooling.cc
@@ -0,0 +1,255 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/pooling.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const int input_stride = input_height * input_width;
+    const int output_stride = output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; i++) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int ph = 0; ph < output_height; ++ph) {
+          int hstart = ph * stride_height - padding_height;
+          int hend = std::min(hstart + ksize_height, input_height);
+          hstart = std::max(hstart, 0);
+          for (int pw = 0; pw < output_width; ++pw) {
+            int wstart = pw * stride_width - padding_width;
+            int wend = std::min(wstart + ksize_width, input_width);
+            wstart = std::max(wstart, 0);
+
+            T ele = static_cast<T>(-FLT_MAX);
+            int index = -1;
+            for (int h = hstart; h < hend; ++h) {
+              for (int w = wstart; w < wend; ++w) {
+                if (ele < input_data[h * input_width + w]) {
+                  ele = input_data[h * input_width + w];
+                  index = h * input_width + w;
+                }
+              }
+            }
+            // The mask stores the flat offset of the maximum within the
+            // current input channel plane.
+            output_data[ph * output_width + pw] = ele;
+            mask_data[ph * output_width + pw] = index;
+          }
+        }
+        // offset to the next channel plane
+        input_data += input_stride;
+        output_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
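+
+// Backward pass for the functor above. Because the forward pass recorded,
+// for every output element, the flat offset of its argmax within the input
+// channel plane, the gradient of each output element can simply be
+// scatter-added to that single input position; the pooling window never
+// needs to be re-scanned.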
+template <typename T>
+class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_height = input_grad.dims()[2];
+    const int input_width = input_grad.dims()[3];
+    const int output_channels = output_grad.dims()[1];
+    const int output_height = output_grad.dims()[2];
+    const int output_width = output_grad.dims()[3];
+    const int input_stride = input_height * input_width;
+    const int output_stride = output_height * output_width;
+
+    const T* mask_data = mask.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int n = 0; n < batch_size; ++n) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int ph = 0; ph < output_height; ++ph) {
+          for (int pw = 0; pw < output_width; ++pw) {
+            const int output_idx = ph * output_width + pw;
+            const int input_idx = static_cast<int>(mask_data[output_idx]);
+            input_grad_data[input_idx] += output_grad_data[output_idx];
+          }
+        }
+        // offset to the next channel plane
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
+template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
+template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
+template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
+
+template <typename T>
+class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_depth = input.dims()[2];
+    const int input_height = input.dims()[3];
+    const int input_width = input.dims()[4];
+    const int output_channels = output.dims()[1];
+    const int output_depth = output.dims()[2];
+    const int output_height = output.dims()[3];
+    const int output_width = output.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+    const int input_stride = input_depth * input_height * input_width;
+    const int output_stride = output_depth * output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; i++) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int pd = 0; pd < output_depth; ++pd) {
+          int dstart = pd * stride_depth - padding_depth;
+          int dend = std::min(dstart + ksize_depth, input_depth);
+          dstart = std::max(dstart, 0);
+          for (int ph = 0; ph < output_height; ++ph) {
+            int hstart = ph * stride_height - padding_height;
+            int hend = std::min(hstart + ksize_height, input_height);
+            hstart = std::max(hstart, 0);
+            for (int pw = 0; pw < output_width; ++pw) {
+              int wstart = pw * stride_width - padding_width;
+              int wend = std::min(wstart + ksize_width, input_width);
+              wstart = std::max(wstart, 0);
+
+              int output_idx = (pd * output_height + ph) * output_width + pw;
+              T ele = static_cast<T>(-FLT_MAX);
+              int index = -1;
+              for (int d = dstart; d < dend; ++d) {
+                for (int h = hstart; h < hend; ++h) {
+                  for (int w = wstart; w < wend; ++w) {
+                    if (ele <
+                        input_data[(d * input_height + h) * input_width + w]) {
+                      index = (d * input_height + h) * input_width + w;
+                      ele = input_data[(d * input_height + h) * input_width + w];
+                    }
+                  }
+                }
+              }
+              output_data[output_idx] = ele;
+              mask_data[output_idx] = index;
+            }
+          }
+        }
+        // offset to the next channel volume
+        input_data += input_stride;
+        output_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+template <typename T>
+class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_depth = input_grad.dims()[2];
+    const int input_height = input_grad.dims()[3];
+    const int input_width = input_grad.dims()[4];
+    const int output_channels = output_grad.dims()[1];
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
+    const int input_stride = input_depth * input_height * input_width;
+    const int output_stride = output_depth * output_height * output_width;
+
+    const T* mask_data = mask.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int n = 0; n < batch_size; ++n) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int pd = 0; pd < output_depth; ++pd) {
+          for (int ph = 0; ph < output_height; ++ph) {
+            for (int pw = 0; pw < output_width; ++pw) {
+              const int output_idx =
+                  (pd * output_height + ph) * output_width + pw;
+              const int input_idx = static_cast<int>(mask_data[output_idx]);
+              input_grad_data[input_idx] += output_grad_data[output_idx];
+            }
+          }
+        }
+        // offset to the next channel volume
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+        mask_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
+template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
+template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
+template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f32e6a26d0db0bc2b2d493f5c5f5ba292f76b62c
--- /dev/null
+++ b/paddle/operators/math/pooling.cu
@@ -0,0 +1,387 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/pooling.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+__global__ void KernelMaxPool2dWithIdxForward(
+    const int nthreads, const T* input_data, T* output_data, T* mask_data,
+    const int channels, const int input_height, const int input_width,
+    const int output_height, const int output_width, const int ksize_height,
+    const int ksize_width, const int stride_height, const int stride_width,
+    const int padding_height, const int padding_width) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    // One thread per output element; decompose the linear thread index
+    // into (batch, channel, output row, output column).
+    int pw = index % output_width;
+    int ph = (index / output_width) % output_height;
+    int c = (index / output_width / output_height) % channels;
+    int batch_idx = index / output_width / output_height / channels;
+
+    int hstart = ph * stride_height - padding_height;
+    int hend = min(hstart + ksize_height, input_height);
+    hstart = max(hstart, 0);
+
+    int wstart = pw * stride_width - padding_width;
+    int wend = min(wstart + ksize_width, input_width);
+    wstart = max(wstart, 0);
+
+    input_data += (batch_idx * channels + c) * input_height * input_width;
+    T ele = -FLT_MAX;
+    int max_index = -1;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        if (ele < input_data[h * input_width + w]) {
+          max_index = h * input_width + w;
+          ele = input_data[h * input_width + w];
+        }
+      }
+    }
+    output_data[index] = ele;
+    mask_data[index] = max_index;
+  }
+}
+
+template <typename T>
+__global__ void KernelMaxPool2DWithIdxBackward(
+    const int nthreads, T* input_grad, const T* output_grad,
+    const T* mask_data, const int channels, const int input_height,
+    const int input_width, const int output_height, const int output_width,
+    const int ksize_height, const int ksize_width, const int stride_height,
+    const int stride_width, const int padding_height,
+    const int padding_width) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    // One thread per *input* element; offsetW/offsetH are the coordinates of
+    // this input element in the padded input plane.
+    int offsetW = index % input_width + padding_width;
+    int offsetH = (index / input_width) % input_height + padding_height;
+    int offsetC = (index / input_width / input_height) % channels;
+    int batch_idx = index / input_width / input_height / channels;
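+
+    // Invert the forward mapping: find the range of output positions
+    // [phstart, phend) x [pwstart, pwend) whose pooling windows contain this
+    // input element, then accumulate gradient from every such output whose
+    // recorded argmax is exactly this element.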
+    int phstart = (offsetH < ksize_height)
+                      ? 0
+                      : (offsetH - ksize_height) / stride_height + 1;
+    int pwstart = (offsetW < ksize_width)
+                      ? 0
+                      : (offsetW - ksize_width) / stride_width + 1;
+    int phend = min(offsetH / stride_height + 1, output_height);
+    int pwend = min(offsetW / stride_width + 1, output_width);
+
+    T gradient = 0;
+    int output_idx =
+        (batch_idx * channels + offsetC) * output_height * output_width;
+    mask_data += output_idx;
+    output_grad += output_idx;
+    // The mask stores unpadded coordinates, so strip the padding again
+    // before comparing.
+    int input_current_idx =
+        (offsetH - padding_height) * input_width + (offsetW - padding_width);
+    for (int ph = phstart; ph < phend; ++ph) {
+      for (int pw = pwstart; pw < pwend; ++pw) {
+        if (input_current_idx == mask_data[ph * output_width + pw])
+          gradient += output_grad[ph * output_width + pw];
+      }
+    }
+    input_grad[index] = gradient;
+  }
+}
+
+template <typename T>
+class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+    // 1-D launch: 1024 threads per block, enough blocks to cover all outputs.
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool2dWithIdxForward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(nthreads, input_data, output_data, mask_data,
+                              input_channels, input_height, input_width,
+                              output_height, output_width, ksize_height,
+                              ksize_width, stride_height, stride_width,
+                              padding_height, padding_width);
+  }
+};
+
+template <typename T>
+class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_channels = input_grad.dims()[1];
+    const int input_height = input_grad.dims()[2];
+    const int input_width = input_grad.dims()[3];
+    const int output_channels = output_grad.dims()[1];
+    const int output_height = output_grad.dims()[2];
+    const int output_width = output_grad.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* mask_data = mask.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
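+    // Backward launches one thread per input element (the forward pass uses
+    // one per output element), so each input gradient is written exactly
+    // once and no atomics are needed.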
+    int nthreads = batch_size * input_channels * input_height * input_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool2DWithIdxBackward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(nthreads, input_grad_data, output_grad_data,
+                              mask_data, input_channels, input_height,
+                              input_width, output_height, output_width,
+                              ksize_height, ksize_width, stride_height,
+                              stride_width, padding_height, padding_width);
+  }
+};
+
+template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float>;
+template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float>;
+template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>;
+template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;
+
+template <typename T>
+__global__ void KernelMaxPool3DWithIdxForward(
+    const int nthreads, const T* input_data, T* output_data, T* mask_data,
+    const int channels, const int input_depth, const int input_height,
+    const int input_width, const int output_depth, const int output_height,
+    const int output_width, const int ksize_depth, const int ksize_height,
+    const int ksize_width, const int stride_depth, const int stride_height,
+    const int stride_width, const int padding_depth, const int padding_height,
+    const int padding_width) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
+    int pw = index % output_width;
+    int ph = (index / output_width) % output_height;
+    int pd = (index / output_width / output_height) % output_depth;
+    int c = (index / output_width / output_height / output_depth) % channels;
+    int batch_idx =
+        index / output_width / output_height / output_depth / channels;
+
+    int dstart = pd * stride_depth - padding_depth;
+    int hstart = ph * stride_height - padding_height;
+    int wstart = pw * stride_width - padding_width;
+    int dend = min(dstart + ksize_depth, input_depth);
+    int hend = min(hstart + ksize_height, input_height);
+    int wend = min(wstart + ksize_width, input_width);
+    dstart = max(dstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+
+    T ele = -FLT_MAX;
+    int max_index = -1;
+    // Use a local pointer; the parameter must not be advanced cumulatively
+    // across iterations of the grid-stride loop.
+    const T* input_ptr =
+        input_data +
+        (batch_idx * channels + c) * input_depth * input_height * input_width;
+
+    for (int d = dstart; d < dend; ++d) {
+      for (int h = hstart; h < hend; ++h) {
+        for (int w = wstart; w < wend; ++w) {
+          if (ele < input_ptr[(d * input_height + h) * input_width + w]) {
+            max_index = (d * input_height + h) * input_width + w;
+            ele = input_ptr[(d * input_height + h) * input_width + w];
+          }
+        }
+      }
+    }
+    output_data[index] = ele;
+    mask_data[index] = max_index;
+  }
+}
+
+template <typename T>
+__global__ void KernelMaxPool3DWithIdxBackward(
+    const int nthreads, T* input_grad, const T* output_grad, const T* mask,
+    const int channels, const int input_depth, const int input_height,
+    const int input_width, const int output_depth, const int output_height,
+    const int output_width, const int ksize_depth, const int ksize_height,
+    const int ksize_width, const int stride_depth, const int stride_height,
+    const int stride_width, const int padding_depth, const int padding_height,
+    const int padding_width) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
+    int offsetW = index % input_width + padding_width;
+    int offsetH = (index / input_width) % input_height + padding_height;
+    int offsetD =
+        (index / input_width / input_height) % input_depth + padding_depth;
+    int offsetC =
+        (index / input_width / input_height / input_depth) % channels;
+    int batch_idx =
+        index / input_width / input_height / input_depth / channels;
+
+    // Same window inversion as in the 2D backward kernel, extended to depth.
+    int pdstart = (offsetD < ksize_depth)
+                      ? 0
+                      : (offsetD - ksize_depth) / stride_depth + 1;
+    int phstart = (offsetH < ksize_height)
+                      ? 0
+                      : (offsetH - ksize_height) / stride_height + 1;
+    int pwstart = (offsetW < ksize_width)
+                      ? 0
+                      : (offsetW - ksize_width) / stride_width + 1;
+    int pdend = min(offsetD / stride_depth + 1, output_depth);
+    int phend = min(offsetH / stride_height + 1, output_height);
+    int pwend = min(offsetW / stride_width + 1, output_width);
+
+    T gradient = 0;
+    int output_idx = (batch_idx * channels + offsetC) * output_depth *
+                     output_height * output_width;
+    mask += output_idx;
+    output_grad += output_idx;
+    // The mask stores unpadded coordinates, so strip the padding again
+    // before comparing.
+    int input_current_idx = ((offsetD - padding_depth) * input_height +
+                             (offsetH - padding_height)) *
+                                input_width +
+                            (offsetW - padding_width);
+
+    for (int pd = pdstart; pd < pdend; ++pd) {
+      for (int ph = phstart; ph < phend; ++ph) {
+        for (int pw = pwstart; pw < pwend; ++pw) {
+          if (input_current_idx ==
+              mask[(pd * output_height + ph) * output_width + pw])
+            gradient +=
+                output_grad[(pd * output_height + ph) * output_width + pw];
+        }
+      }
+    }
+    input_grad[index] = gradient;
+  }
+}
+
+template <typename T>
+class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_depth = input.dims()[2];
+    const int input_height = input.dims()[3];
+    const int input_width = input.dims()[4];
+    const int output_channels = output.dims()[1];
+    const int output_depth = output.dims()[2];
+    const int output_height = output.dims()[3];
+    const int output_width = output.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+
+    const T* input_data = input.data<T>();
+    T* output_data = output.mutable_data<T>(context.GetPlace());
+    // The mask needs its own buffer; writing it into `output` would clobber
+    // the pooled values.
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_depth *
+                   output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool3DWithIdxForward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_data, output_data, mask_data, input_channels,
+        input_depth, input_height, input_width, output_depth, output_height,
+        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
+        stride_height, stride_width, padding_depth, padding_height,
+        padding_width);
+  }
+};
+
+template <typename T>
+class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings) {
+    const int batch_size = input_grad.dims()[0];
+    const int input_channels = input_grad.dims()[1];
+    const int input_depth = input_grad.dims()[2];
+    const int input_height = input_grad.dims()[3];
+    const int input_width = input_grad.dims()[4];
+    // Output dimensions must come from output_grad, not input_grad.
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+
+    const T* output_grad_data = output_grad.data<T>();
+    const T* mask_data = mask.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    int nthreads =
+        batch_size * input_channels * input_depth * input_height * input_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool3DWithIdxBackward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_grad_data, output_grad_data, mask_data, input_channels,
+        input_depth, input_height, input_width, output_depth, output_height,
+        output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
+        stride_height, stride_width, padding_depth, padding_height,
+        padding_width);
+  }
+};
+
+template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float>;
+template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float>;
+template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double>;
+template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h
new file mode 100644
index 0000000000000000000000000000000000000000..3a05cd98fed3a4baa394dd50bf3a36c7d32e2e09
--- /dev/null
+++ b/paddle/operators/math/pooling.h
@@ -0,0 +1,68 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// FLT_MAX might better live in a shared header; kept here for now.
+#define FLT_MAX __FLT_MAX__
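+
+/*
+ * All functors below assume NCHW (2D) or NCDHW (3D) layout. `output` and
+ * `mask` have the same shape; for every output element, `mask` records the
+ * flat offset of the argmax inside the corresponding input channel plane,
+ * i.e. h * input_width + w, or (d * input_height + h) * input_width + w.
+ * `ksize`, `strides` and `paddings` hold one entry per spatial dimension.
+ */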
+
+template <typename Place, typename T>
+class MaxPool2dWithIndexFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool2dWithIndexGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool3dWithIndexFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor& output,
+                  framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+template <typename Place, typename T>
+class MaxPool3dWithIndexGradFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output_grad,
+                  const framework::Tensor& mask, std::vector<int>& ksize,
+                  std::vector<int>& strides, std::vector<int>& paddings);
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d7a07a403dbd4ff18d428f387e07adfeabb6b6cc
--- /dev/null
+++ b/paddle/operators/pool_with_index_op.cc
@@ -0,0 +1,198 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/pool_with_index_op.h"
+
+namespace paddle {
+namespace operators {
+
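+// Output size of one pooled dimension. For example, input_size = 7,
+// filter_size = 3, padding = 1 and stride = 1 give (7 - 3 + 2) / 1 + 1 = 7.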
+int OutputSizeMaxPool(int input_size, int filter_size, int padding,
+                      int stride) {
+  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
+  return output_size;
+}
+
+class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "X(Input) of Pooling should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Out(Output) of Pooling should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Mask"),
+                   "Mask(Output) of Pooling should not be null.");
+
+    auto in_x_dims = ctx->GetInputDim("X");
+
+    std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
+    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+
+    PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
+                   "Pooling input should be 4-D or 5-D.");
+
+    if (ctx->Attrs().Get<bool>("globalPooling")) {
+      ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
+      for (size_t i = 0; i < ksize.size(); ++i)
+        ksize[i] = static_cast<int>(in_x_dims[i + 2]);
+    }
+
+    PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
+                   "Pooling input size and pooling size should be consistent.");
+    PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
+                   "Pooling size should contain 2 or 3 elements.");
+    PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
+                      "Strides size and pooling size should be the same.");
+    PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
+                      "Paddings size and pooling size should be the same.");
+
+    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
+    for (size_t i = 0; i < ksize.size(); ++i) {
+      output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i],
+                                               paddings[i], strides[i]));
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+    ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
+  }
+};
+
+class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContextBase *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "X(Input) of MaxPoolWithIndexOpGrad should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput(framework::GradVarName("X")),
+        "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+
+class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
+                            framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "The input tensor of pooling operator. "
+        "The format of input tensor is NCHW, where N is batch size, C is "
+        "the number of channels, and H and W are the height and width of "
+        "the image.");
+    AddOutput("Out",
+              "The output tensor of pooling operator. "
+              "The format of the output tensor is also NCHW.");
+    AddOutput("Mask",
+              "The Mask tensor of pooling operator. "
+              "The format of the Mask tensor is also NCHW.");
+
+    AddAttr<std::vector<int>>(
+        "ksize", "Pooling size (height, width) of the pooling operator.");
+    AddAttr<bool>(
+        "globalPooling",
+ "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>("strides", + "strides(height, width) of pooling operator." + "default {1,1}") + .SetDefault({1, 1}); + AddAttr>("paddings", + "paddings(height, width) of pooling operator." + "default {0,0}") + .SetDefault({0, 0}); + + AddComment(R"DOC( +The maxPooling2d with index operation calculates the output and the mask based on +the input and ksize, strides, paddings parameters. +)DOC"); + } +}; + +class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxPool3dWithIndexOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "The input tensor of pooling operator. " + "The format of input tensor is NCDHW. Where N is batch size, C is " + "the number of channels, D, H and W is the depth, height and width of " + "image."); + AddOutput("Out", + "The output tensor of pooling operator." + "The format of output tensor is also NCDHW."); + AddOutput("Mask", + "The Mask tensor of pooling operator." + "The format of output tensor is also NCDHW."); + + AddAttr>( + "ksize", "pooling size(depth, height, width) of pooling operator."); + AddAttr( + "globalPooling", + "whether to use the globalPooling." + "int constant equal to false or true" + "default false" + "If globalPooling = true, ksize is ignored and need not be specified.") + .SetDefault(false); + AddAttr>( + "strides", + "strides(depth, height, width) of pooling operator." + "default {1,1,1}") + .SetDefault({1, 1, 1}); + AddAttr>( + "paddings", + "paddings(depth, height, width) of pooling operator." + "default {0,0,0}") + .SetDefault({0, 0, 0}); + AddComment(R"DOC( +The maxpooling3d with index operation calculates the output and the mask based on +the input and ksize, strides, paddings parameters. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP(maxPool2dWithIndex, ops::MaxPoolWithIndexOp, + ops::MaxPool2dWithIndexOpMaker, maxPool2dWithIndex_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + maxPool2dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + maxPool2dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) + +REGISTER_OP(maxPool3dWithIndex, ops::MaxPoolWithIndexOp, + ops::MaxPool3dWithIndexOpMaker, maxPool3dWithIndex_grad, + ops::MaxPoolWithIndexOpGrad); + +REGISTER_OP_CPU_KERNEL( + maxPool3dWithIndex, + ops::MaxPoolWithIndexKernel); +REGISTER_OP_CPU_KERNEL( + maxPool3dWithIndex_grad, + ops::MaxPoolWithIndexGradKernel) diff --git a/paddle/operators/pool_with_index_op.cu b/paddle/operators/pool_with_index_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..8007fc7ccfb3c1565ffd3b826c81da7844e75125 --- /dev/null +++ b/paddle/operators/pool_with_index_op.cu @@ -0,0 +1,31 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/pool_with_index_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(
+    maxPool2dWithIndex,
+    ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    maxPool2dWithIndex_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>);
+
+REGISTER_OP_GPU_KERNEL(
+    maxPool3dWithIndex,
+    ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    maxPool3dWithIndex_grad,
+    ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..91abeed016fe9d84bef321e9e7633ad658eb120c
--- /dev/null
+++ b/paddle/operators/pool_with_index_op.h
@@ -0,0 +1,99 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/pooling.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class MaxPoolWithIndexKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* in_x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    Tensor* mask = context.Output<Tensor>("Mask");
+
+    bool global_pooling = context.Attr<bool>("globalPooling");
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (global_pooling) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
+      }
+    }
+
+    switch (ksize.size()) {
+      case 2: {
+        paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
+            pool2d_forward;
+        pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
+                       strides, paddings);
+      } break;
+      case 3: {
+        paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
+            pool3d_forward;
+        pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
+                       strides, paddings);
+      } break;
+    }
+  }
+};
+
+template <typename Place, typename T>
+class MaxPoolWithIndexGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* mask = context.Input<Tensor>("Mask");
+    const Tensor* out_grad =
+        context.Input<Tensor>(framework::GradVarName("Out"));
+    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
+
+    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
+    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+
+    if (in_x_grad) {
+      in_x_grad->mutable_data<T>(context.GetPlace());
+      // Zero the gradient first; the functors scatter-add into it.
+      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
+      temp.device(context.GetEigenDevice<Place>()) =
+          temp.constant(static_cast<T>(0));
+
+      switch (ksize.size()) {
+        case 2: {
+          paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
+              pool2d_backward;
+          pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
+                          *mask, ksize, strides, paddings);
+        } break;
+        case 3: {
+          paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
+              pool3d_backward;
+          pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
+                          *mask, ksize, strides, paddings);
+        } break;
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..2945c8b7a4eecf6cd04d69d25b0686c8add5f196
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_pool_max_op.py
@@ -0,0 +1,125 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0, 0],
+                             global_pool=0):
+    N, C, D, H, W = x.shape
+    if global_pool == 1:
+        ksize = [D, H, W]
+    D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
+    H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
+    W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
+    out = np.zeros((N, C, D_out, H_out, W_out))
+    mask = np.zeros((N, C, D_out, H_out, W_out))
+    for k in xrange(D_out):
+        d_start = np.max((k * strides[0] - paddings[0], 0))
+        d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
+        for i in xrange(H_out):
+            h_start = np.max((i * strides[1] - paddings[1], 0))
+            h_end = np.min((i * strides[1] + ksize[1] - paddings[1], H))
+            for j in xrange(W_out):
+                w_start = np.max((j * strides[2] - paddings[2], 0))
+                w_end = np.min((j * strides[2] + ksize[2] - paddings[2], W))
+                x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
+
+                out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
+                # mask[:, :, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
+    return out
+
+
+def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0],
+                             global_pool=0):
+    N, C, H, W = x.shape
+    if global_pool == 1:
+        ksize = [H, W]
+    H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
+    W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
+    out = np.zeros((N, C, H_out, W_out))
+    mask = np.zeros((N, C, H_out, W_out))
+    for i in xrange(H_out):
+        for j in xrange(W_out):
+            r_start = np.max((i * strides[0] - paddings[0], 0))
+            r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
+            c_start = np.max((j * strides[1] - paddings[1], 0))
+            c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
+            x_masked = x[:, :, r_start:r_end, c_start:c_end]
+
+            out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
+            # mask[:, :, i, j] = np.argmax(x_masked, axis=(2, 3))
+    return out
+
+
+class TestMaxPoolWithIndex_Op(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        input = np.random.random(self.shape).astype("float32")
+        output = self.pool_forward_naive(input, self.ksize, self.strides,
+                                         self.paddings, self.global_pool)
+        # mask = np.zeros(output.shape)
+
+        self.attrs = {
+            'strides': self.strides,
+            'paddings': self.paddings,
+            'ksize': self.ksize,
+            'globalPooling': self.global_pool,
+        }
+
+        self.inputs = {'X': input}
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output()
+
+    # def test_check_grad(self):
+    #     self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
+
+    def initTestCase(self):
+        # op_type is set here (not in setUp) so that subclasses can override it.
+        self.global_pool = 0
+        self.op_type = "maxPool3dWithIndex"
+        self.pool_forward_naive = max_pool3D_forward_naive
+        self.shape = [2, 3, 7, 7, 7]
+        self.ksize = [3, 3, 3]
+        self.strides = [1, 1, 1]
+        self.paddings = [1, 1, 1]
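+
+# Illustrative only (not used by the tests): mask is meant to hold flat
+# offsets into each channel plane, so for every (n, c) the pooled values
+# could be recovered from the input with
+#
+#     flat = x.reshape(N, C, -1)   # (N, C, H*W) or (N, C, D*H*W)
+#     out[n, c].flat[k] == flat[n, c][int(mask[n, c].flat[k])]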
+
+
+class TestCase1(TestMaxPoolWithIndex_Op):
+    def initTestCase(self):
+        self.global_pool = 1
+        self.op_type = "maxPool3dWithIndex"
+        self.pool_forward_naive = max_pool3D_forward_naive
+        self.shape = [2, 3, 5, 5, 5]
+        self.ksize = [3, 3, 3]
+        self.strides = [1, 1, 1]
+        self.paddings = [0, 0, 0]
+
+
+class TestCase2(TestMaxPoolWithIndex_Op):
+    def initTestCase(self):
+        self.global_pool = 0
+        self.op_type = "maxPool2dWithIndex"
+        self.pool_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+
+
+class TestCase3(TestMaxPoolWithIndex_Op):
+    def initTestCase(self):
+        self.global_pool = 1
+        self.op_type = "maxPool2dWithIndex"
+        self.pool_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 5, 5]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+
+
+if __name__ == '__main__':
+    unittest.main()