提交 6326c40d 编写于 作者: C chengduoZH

Add max pool with index

上级 38bca7d3
......@@ -62,6 +62,12 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(maxPool2dWithIndex);\n")
endif()
# pybind USE_NO_KERNEL_OP
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
......
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator)
im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
nv_library(softmax_function SRCS softmax.cc softmax.cu
DEPS operator)
nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
DEPS operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc
DEPS cblas device_context operator)
cc_library(softmax_function SRCS softmax.cc DEPS operator)
cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[h * input_width + w]) {
ele = input_data[h * input_width + w];
index = h * input_width + w;
}
}
}
output_data[ph * output_width + pw] = ele;
mask_data[ph * output_width + pw] = index;
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_height = input_grad.dims()[2];
const int input_width = input_grad.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (size_t n = 0; n < batch_size; ++n) {
for (size_t c = 0; c < output_channels; ++c) {
for (size_t ph = 0; ph < output_height; ++ph) {
for (size_t pw = 0; pw < output_width; ++pw) {
const size_t output_idx = ph * output_width + pw;
const size_t input_idx = static_cast<size_t>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
};
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele <
input_data[(d * input_height + h) * input_width + w]) {
index = (d * input_height + h) * input_width + w;
ele =
input_data[(d * input_height + h) * input_width + w];
}
}
}
}
output_data[output_idx] = ele;
mask_data[output_idx] = index;
}
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3];
const int output_width = output_grad.dims()[4];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (size_t n = 0; n < batch_size; ++n) {
for (size_t c = 0; c < output_channels; ++c) {
for (size_t pd = 0; pd < output_depth; ++pd) {
for (size_t ph = 0; ph < output_height; ++ph) {
for (size_t pw = 0; pw < output_width; ++pw) {
const size_t output_idx =
(pd * output_height + ph) * output_width + pw;
const size_t input_idx =
static_cast<size_t>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void KernelMaxPool2dWithIdxForward(
const int nthreads, const T* input_data, T* output_data, T* mask_data,
const int channels, const int input_height, const int input_width,
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
int batch_idx = index / output_width / output_height / channels;
int hstart = ph * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
int wstart = pw * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = -FLT_MAX;
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[h * input_width + w]) {
index = h * input_width + w;
ele = input_data[h * input_width + w];
}
}
}
output_data[index] = ele;
mask_data[index] = index;
}
}
template <typename T>
__global__ void KernelMaxPool2DWithIdxBackward(
const int nthreads, T* input_grad, const T* output_grad, const T* mask_data,
const int channels, const int input_height, const int input_width,
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetC = (index / input_width / input_height) % channels;
int batch_idx = index / input_width / input_height / channels;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int phend = min(offsetH / stride_height + 1, output_height);
int pwend = min(offsetW / stride_width + 1, output_width);
T gradient = 0;
int output_idx =
(batch_idx * channels + offsetC) * output_height * output_width;
mask_data += output_idx;
output_grad += output_idx;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if ((offsetH * input_width + offsetW) ==
mask_data[ph * output_width + pw])
gradient += output_grad[ph * output_width + pw];
}
}
input_grad[index] = gradient;
}
}
template <typename T>
class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * output_height * output_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool2dWithIdxForward<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_data, output_data, mask_data,
input_channels, input_height, input_width,
output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width,
padding_height, padding_width);
}
};
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_channels = input_grad.dims()[1];
const int input_height = input_grad.dims()[2];
const int input_width = input_grad.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool2DWithIdxBackward<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(nthreads, input_grad_data, output_grad_data,
mask_data, input_channels, input_height,
input_width, output_height, output_width,
ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width);
}
};
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;
template <typename T>
__global__ void KernelMaxPool3DWithIdxForward(
const int nthreads, const T* input_data, T* output_data, T* mask_data,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
const int output_width, const int ksize_depth, const int ksize_height,
const int ksize_width, const int stride_depth, const int stride_height,
const int stride_width, const int padding_depth, const int padding_height,
const int padding_width) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int pd = (index / output_width / output_height) % output_depth;
int c = (index / output_width / output_height / output_depth) % channels;
int batch_idx =
index / output_width / output_height / output_depth / channels;
int dstart = pd * stride_depth - padding_depth;
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int dend = min(dstart + ksize_depth, input_depth);
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
T ele = -FLT_MAX;
int index = -1;
input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[(d * input_height + h) * input_width + w]) {
index = (d * input_height + h) * input_width + w;
ele = input_data[(d * input_height + h) * input_width + w];
}
}
}
}
output_data[index] = ele;
mask_data[index] = index;
}
}
template <typename T>
__global__ void KernelMaxPool3DWithIdxBackward(
const int nthreads, T* input_grad, const T* output_grad, const T* mask,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
const int output_width, const int ksize_depth, const int ksize_height,
const int ksize_width, const int stride_depth, const int stride_height,
const int stride_width, const int padding_depth, const int padding_height,
const int padding_width) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetD =
(index / input_width / input_height) % input_depth + padding_depth;
int offsetC = (index / input_width / input_height / input_depth) % channels;
int batch_idx = index / input_width / input_height / input_depth / channels;
int pdstart = (offsetD < ksize_depth)
? 0
: (offsetD - ksize_depth) / stride_depth + 1;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int pdend = min((offsetD) / stride_depth + 1, output_depth);
int phend = min((offsetH) / stride_height + 1, output_height);
int pwend = min((offsetW) / stride_width + 1, output_width);
T gradient = 0;
int output_idx = (batch_idx * channels + offsetC) * output_depth *
output_height * output_width;
mask += output_idx;
output_grad += output_idx;
for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (((offsetD * input_height + offsetH) * input_width + offsetW) ==
mask[(pd * output_height + ph) * output_width + pw])
gradient +=
output_grad[(pd * output_height + ph) * output_width + pw];
}
}
}
input_grad[index] = gradient;
}
}
template <typename T>
class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = output.mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * output_depth * output_height *
output_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool3DWithIdxForward<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_data, output_data, mask_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
stride_height, stride_width, padding_depth, padding_height,
padding_width);
}
};
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_channels = input_grad.dims()[1];
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
const int output_channels = input_grad.dims()[1];
const int output_depth = input_grad.dims()[2];
const int output_height = input_grad.dims()[3];
const int output_width = input_grad.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const T* output_grad_data = output_grad.data<T>();
const T* mask_data = mask.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
int nthreads =
batch_size * input_channels * input_depth * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool3DWithIdxBackward<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
nthreads, input_grad_data, output_grad_data, mask_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
stride_height, stride_width, padding_depth, padding_height,
padding_width);
}
};
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::GPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__
/////////////////////
template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool2dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace paddle {
namespace operators {
int OutputSizeMaxPool(int input_size, int filter_size, int padding,
int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mask"),
"Out(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Pooling intput size and pooling size should be consistent");
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
"Pooling size size should be 2 elements. or 3 elements.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
}
};
class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")),
"X(Input) of MaxPoolWithIndexOpGrad should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput(framework::GradVarName("X")),
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>(
"ksize", "pooling size(height, width) of pooling operator.");
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator."
"default {1,1}")
.SetDefault({1, 1});
AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator."
"default {0,0}")
.SetDefault({0, 0});
AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask based on
the input and ksize, strides, paddings parameters.
)DOC");
}
};
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>(
"ksize", "pooling size(depth, height, width) of pooling operator.");
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"strides(depth, height, width) of pooling operator."
"default {1,1,1}")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>(
"paddings",
"paddings(depth, height, width) of pooling operator."
"default {0,0,0}")
.SetDefault({0, 0, 0});
AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask based on
the input and ksize, strides, paddings parameters.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(maxPool2dWithIndex, ops::MaxPoolWithIndexOp,
ops::MaxPool2dWithIndexOpMaker, maxPool2dWithIndex_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
maxPool2dWithIndex,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
maxPool2dWithIndex_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP(maxPool3dWithIndex, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, maxPool3dWithIndex_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
maxPool3dWithIndex,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
maxPool3dWithIndex_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
maxPool2dWithIndex,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
maxPool2dWithIndex_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
REGISTER_OP_GPU_KERNEL(
maxPool3dWithIndex,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
maxPool3dWithIndex_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MaxPoolWithIndexKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask");
bool global_pooling = context.Attr<bool>("globalPooling");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
pool2d_forward;
pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
pool3d_forward;
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
}
}
};
template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Maks");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
mask = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in xrange(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
for j in xrange(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
# mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
return out
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
# mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
return out
class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "maxPool3dWithIndex"
input = np.random.random(self.shape).astype("float32")
output = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
# mask = np.zeros(output.shape)
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'globalPooling': self.global_pool,
}
self.inputs = {'X': input}
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self):
self.global_pool = 0
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
""""
class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.op_type = "maxPool3dWithIndex"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 0
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = 1
self.op_type = "maxPool2dWithIndex"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
if __name__ == '__main__':
unittest.main()
"""
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册