Commit bee95fc8, authored by chengduoZH

fix code format and some bug

Parent 6326c40d
@@ -26,7 +26,6 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
                   framework::Tensor& mask, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
     const int output_channels = output.dims()[1];
@@ -112,13 +111,13 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
             input_grad_data[input_idx] += output_grad_data[output_idx];
           }
         }
-      }
         // offset
         input_grad_data += input_stride;
         output_grad_data += output_stride;
         mask_data += output_stride;
       }
     }
+  }
 };

 template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
@@ -152,6 +151,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
     const int padding_width = paddings[2];
     const int input_stride = input_depth * input_height * input_width;
     const int output_stride = output_depth * output_height * output_width;
+
     const T* input_data = input.data<T>();
     T* output_data = output.mutable_data<T>(context.GetPlace());
     T* mask_data = mask.mutable_data<T>(context.GetPlace());
@@ -170,17 +170,17 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
           int wstart = pw * stride_width - padding_width;
           int wend = std::min(wstart + ksize_width, input_width);
           wstart = std::max(wstart, 0);
+
           int output_idx = (pd * output_height + ph) * output_width + pw;
           T ele = static_cast<T>(-FLT_MAX);
           int index = -1;
           for (int d = dstart; d < dend; ++d) {
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
-                if (ele <
-                    input_data[(d * input_height + h) * input_width + w]) {
-                  index = (d * input_height + h) * input_width + w;
-                  ele =
-                      input_data[(d * input_height + h) * input_width + w];
+                int input_idx = (d * input_height + h) * input_width + w;
+                if (ele < input_data[input_idx]) {
+                  index = input_idx;
+                  ele = input_data[input_idx];
                 }
               }
             }
......
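The mask written by these functors stores, for each output element, the argmax position inside its own input feature map, flattened as (d * input_height + h) * input_width + w (or h * input_width + w in the 2-D case). A minimal host-side sketch of that convention, with extents chosen only for illustration:

#include <cstdio>

int main() {
  const int input_height = 5, input_width = 6;  // illustrative extents
  const int d = 2, h = 3, w = 1;                // position of a maximum

  // Flatten (d, h, w) the way the pooling functors build the mask.
  int index = (d * input_height + h) * input_width + w;  // 79

  // Recover (d, h, w) from the flattened mask value.
  int w_r = index % input_width;                   // 1
  int h_r = (index / input_width) % input_height;  // 3
  int d_r = index / (input_width * input_height);  // 2
  printf("index=%d -> (d=%d, h=%d, w=%d)\n", index, d_r, h_r, w_r);
  return 0;
}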
@@ -20,14 +20,14 @@ namespace operators {
 namespace math {

 template <typename T>
-__global__ void KernelMaxPool2dWithIdxForward(
+__global__ void KernelMaxPool2dWithIdx(
     const int nthreads, const T* input_data, T* output_data, T* mask_data,
     const int channels, const int input_height, const int input_width,
     const int output_height, const int output_width, const int ksize_height,
     const int ksize_width, const int stride_height, const int stride_width,
     const int padding_height, const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
@@ -43,51 +43,58 @@ __global__ void KernelMaxPool2dWithIdxForward(
     input_data += (batch_idx * channels + c) * input_height * input_width;
     T ele = -FLT_MAX;
-    int index = -1;
+    int max_index = -1;
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        if (ele < input_data[h * input_width + w]) {
-          index = h * input_width + w;
-          ele = input_data[h * input_width + w];
+        int input_index = h * input_width + w;
+        if (ele < input_data[input_index]) {
+          max_index = input_index;
+          ele = input_data[input_index];
         }
       }
     }
     output_data[index] = ele;
-    mask_data[index] = index;
+    mask_data[index] = max_index;
   }
 }

 template <typename T>
-__global__ void KernelMaxPool2DWithIdxBackward(
+__global__ void KernelMaxPool2DWithIdxGrad(
     const int nthreads, T* input_grad, const T* output_grad, const T* mask_data,
     const int channels, const int input_height, const int input_width,
     const int output_height, const int output_width, const int ksize_height,
     const int ksize_width, const int stride_height, const int stride_width,
     const int padding_height, const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
-    int offsetW = index % input_width + padding_width;
-    int offsetH = (index / input_width) % input_height + padding_height;
-    int offsetC = (index / input_width / input_height) % channels;
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+       index += blockDim.x * gridDim.x) {
+    int w_offset = index % input_width;
+    int h_offset = (index / input_width) % input_height;
+    int c_offset = (index / input_width / input_height) % channels;
     int batch_idx = index / input_width / input_height / channels;

-    int phstart = (offsetH < ksize_height)
-                      ? 0
-                      : (offsetH - ksize_height) / stride_height + 1;
-    int pwstart = (offsetW < ksize_width)
-                      ? 0
-                      : (offsetW - ksize_width) / stride_width + 1;
-    int phend = min(offsetH / stride_height + 1, output_height);
-    int pwend = min(offsetW / stride_width + 1, output_width);
+    int ph_start =
+        (h_offset + padding_height < ksize_height)
+            ? 0
+            : (h_offset + padding_height - ksize_height) / stride_height + 1;
+    int pw_start =
+        (w_offset + padding_width < ksize_width)
+            ? 0
+            : (w_offset + padding_width - ksize_width) / stride_width + 1;
+    int ph_end =
+        min((h_offset + padding_height) / stride_height + 1, output_height);
+    int pw_end =
+        min((w_offset + padding_width) / stride_width + 1, output_width);

     T gradient = 0;
+    int input_current_featuremap_idx = h_offset * input_width + w_offset;
     int output_idx =
-        (batch_idx * channels + offsetC) * output_height * output_width;
+        (batch_idx * channels + c_offset) * output_height * output_width;
     mask_data += output_idx;
     output_grad += output_idx;
-    for (int ph = phstart; ph < phend; ++ph) {
-      for (int pw = pwstart; pw < pwend; ++pw) {
-        if ((offsetH * input_width + offsetW) ==
-            mask_data[ph * output_width + pw])
+    for (int ph = ph_start; ph < ph_end; ++ph) {
+      for (int pw = pw_start; pw < pw_end; ++pw) {
+        if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
          gradient += output_grad[ph * output_width + pw];
       }
     }
@@ -125,7 +132,7 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool2dWithIdxForward<
+    KernelMaxPool2dWithIdx<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_data, output_data, mask_data,
@@ -167,7 +174,7 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool2DWithIdxBackward<
+    KernelMaxPool2DWithIdxGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_grad_data, output_grad_data,
@@ -184,7 +191,7 @@ template class MaxPool2dWithIndexFunctor<platform::GPUPlace, double>;
 template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;

 template <typename T>
-__global__ void KernelMaxPool3DWithIdxForward(
+__global__ void KernelMaxPool3DWithIdx(
     const int nthreads, const T* input_data, T* output_data, T* mask_data,
     const int channels, const int input_depth, const int input_height,
     const int input_width, const int output_depth, const int output_height,
@@ -200,6 +207,7 @@ __global__ void KernelMaxPool3DWithIdxForward(
     int c = (index / output_width / output_height / output_depth) % channels;
     int batch_idx =
         index / output_width / output_height / output_depth / channels;
+
     int dstart = pd * stride_depth - padding_depth;
     int hstart = ph * stride_height - padding_height;
     int wstart = pw * stride_width - padding_width;
@@ -209,8 +217,9 @@ __global__ void KernelMaxPool3DWithIdxForward(
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
+
     T ele = -FLT_MAX;
-    int index = -1;
+    int max_index = -1;
     input_data +=
         (batch_idx * channels + c) * input_depth * input_height * input_width;
@@ -218,19 +227,19 @@ __global__ void KernelMaxPool3DWithIdxForward(
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
           if (ele < input_data[(d * input_height + h) * input_width + w]) {
-            index = (d * input_height + h) * input_width + w;
-            ele = input_data[(d * input_height + h) * input_width + w];
+            max_index = (d * input_height + h) * input_width + w;
+            ele = input_data[max_index];
           }
         }
       }
     }
     output_data[index] = ele;
-    mask_data[index] = index;
+    mask_data[index] = max_index;
   }
 }

 template <typename T>
-__global__ void KernelMaxPool3DWithIdxBackward(
+__global__ void KernelMaxPool3DWithIdxGrad(
     const int nthreads, T* input_grad, const T* output_grad, const T* mask,
     const int channels, const int input_depth, const int input_height,
     const int input_width, const int output_depth, const int output_height,
@@ -240,37 +249,45 @@ __global__ void KernelMaxPool3DWithIdxBackward(
     const int padding_width) {
   for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
        index += blockDim.x * gridDim.x) {
-    int offsetW = index % input_width + padding_width;
-    int offsetH = (index / input_width) % input_height + padding_height;
-    int offsetD =
-        (index / input_width / input_height) % input_depth + padding_depth;
-    int offsetC = (index / input_width / input_height / input_depth) % channels;
+    int w_offset = index % input_width;
+    int h_offset = (index / input_width) % input_height;
+    int d_offset = (index / input_width / input_height) % input_depth;
+    int c_offset =
+        (index / input_width / input_height / input_depth) % channels;
     int batch_idx = index / input_width / input_height / input_depth / channels;

-    int pdstart = (offsetD < ksize_depth)
-                      ? 0
-                      : (offsetD - ksize_depth) / stride_depth + 1;
-    int phstart = (offsetH < ksize_height)
-                      ? 0
-                      : (offsetH - ksize_height) / stride_height + 1;
-    int pwstart = (offsetW < ksize_width)
-                      ? 0
-                      : (offsetW - ksize_width) / stride_width + 1;
-    int pdend = min((offsetD) / stride_depth + 1, output_depth);
-    int phend = min((offsetH) / stride_height + 1, output_height);
-    int pwend = min((offsetW) / stride_width + 1, output_width);
+    int pd_start =
+        (d_offset + padding_depth < ksize_depth)
+            ? 0
+            : (d_offset + padding_depth - ksize_depth) / stride_depth + 1;
+    int ph_start =
+        (h_offset + padding_height < ksize_height)
+            ? 0
+            : (h_offset + padding_height - ksize_height) / stride_height + 1;
+    int pw_start =
+        (w_offset + padding_width < ksize_width)
+            ? 0
+            : (w_offset + padding_width - ksize_width) / stride_width + 1;
+    int pd_end =
+        min((d_offset + padding_depth) / stride_depth + 1, output_depth);
+    int ph_end =
+        min((h_offset + padding_height) / stride_height + 1, output_height);
+    int pw_end =
+        min((w_offset + padding_width) / stride_width + 1, output_width);

     T gradient = 0;
-    int output_idx = (batch_idx * channels + offsetC) * output_depth *
-                     output_height * output_width;
+    int input_current_feature_map_idx =
+        (d_offset * input_height + h_offset) * input_width + w_offset;
+    int output_idx = (batch_idx * channels + c_offset) * output_depth *
+                     output_height * output_width;
     mask += output_idx;
     output_grad += output_idx;
-    for (int pd = pdstart; pd < pdend; ++pd) {
-      for (int ph = phstart; ph < phend; ++ph) {
-        for (int pw = pwstart; pw < pwend; ++pw) {
-          if (((offsetD * input_height + offsetH) * input_width + offsetW) ==
-              mask[(pd * output_height + ph) * output_width + pw])
+    for (int pd = pd_start; pd < pd_end; ++pd) {
+      for (int ph = ph_start; ph < ph_end; ++ph) {
+        for (int pw = pw_start; pw < pw_end; ++pw) {
+          if (mask[(pd * output_height + ph) * output_width + pw] ==
+              input_current_feature_map_idx)
             gradient +=
                 output_grad[(pd * output_height + ph) * output_width + pw];
         }
@@ -308,7 +325,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
     const T* input_data = input.data<T>();
     T* output_data = output.mutable_data<T>(context.GetPlace());
-    T* mask_data = output.mutable_data<T>(context.GetPlace());
+    T* mask_data = mask.mutable_data<T>(context.GetPlace());

     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
@@ -316,7 +333,7 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool3DWithIdxForward<
+    KernelMaxPool3DWithIdx<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -341,10 +358,10 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
     const int input_depth = input_grad.dims()[2];
     const int input_height = input_grad.dims()[3];
     const int input_width = input_grad.dims()[4];
-    const int output_channels = input_grad.dims()[1];
-    const int output_depth = input_grad.dims()[2];
-    const int output_height = input_grad.dims()[3];
-    const int output_width = input_grad.dims()[4];
+    const int output_channels = output_grad.dims()[1];
+    const int output_depth = output_grad.dims()[2];
+    const int output_height = output_grad.dims()[3];
+    const int output_width = output_grad.dims()[4];
     const int ksize_depth = ksize[0];
     const int ksize_height = ksize[1];
     const int ksize_width = ksize[2];
@@ -365,7 +382,7 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool3DWithIdxBackward<
+    KernelMaxPool3DWithIdxGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
......
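Two recurring changes in this file are worth noting. First, the kernels move from a one-thread-per-element guard (if (index < nthreads)) to a grid-stride loop, so a fixed-size grid covers any nthreads. Second, the Grad kernels derive the range of output windows that can contain a given input element from hstart = ph * stride - padding and hend = hstart + ksize: row h is covered by output row ph exactly when ph * stride - padding <= h < ph * stride - padding + ksize, which rearranges into the ph_start/ph_end expressions above, now with the padding added in explicitly. A standalone sketch of the grid-stride pattern (the loop body is placeholder work, not the pooling code):

#include <cstdio>

// Minimal grid-stride loop: each thread handles indices
// index, index + blockDim.x * gridDim.x, ... until nthreads.
__global__ void GridStrideScale(const int nthreads, const float* in,
                                float* out) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
       index += blockDim.x * gridDim.x) {
    out[index] = 2.0f * in[index];  // placeholder work
  }
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  // A grid much smaller than n still covers every element via the stride.
  GridStrideScale<<<128, 256>>>(n, in, out);
  cudaDeviceSynchronize();
  printf("out[0]=%f out[n-1]=%f\n", out[0], out[n - 1]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}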
@@ -23,7 +23,6 @@ namespace operators {
 namespace math {
 //////////////////////
 #define FLT_MAX __FLT_MAX__
-/////////////////////

 template <typename Place, typename T>
 class MaxPool2dWithIndexFunctor {
......
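The header defines FLT_MAX from the compiler builtin __FLT_MAX__, the same value <cfloat> provides. A macro-free way to get the -FLT_MAX sentinel the pooling loops seed ele with (shown only as a sketch of an alternative, not what this header does) is std::numeric_limits:

#include <cfloat>
#include <cstdio>
#include <limits>

int main() {
  // -FLT_MAX, the sentinel used to seed the running maximum (ele).
  float from_macro = -FLT_MAX;
  float from_limits = std::numeric_limits<float>::lowest();  // C++11
  printf("%d\n", from_macro == from_limits);                 // prints 1
  return 0;
}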
@@ -76,8 +76,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContextBase *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("X")),
-                   "X(Input) of MaxPoolWithIndexOpGrad should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "X(Input) of Pooling should not be null.");
     PADDLE_ENFORCE(
         ctx->HasOutput(framework::GradVarName("X")),
         "X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
@@ -97,28 +97,37 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
              "number of channels, H and W is the height and width of image.");
     AddOutput("Out",
               "The output tensor of pooling operator."
-              "The format of output tensor is also NCHW.");
+              "The format of output tensor is also NCHW."
+              "Where N is batch size, C is "
+              "the number of channels, H and W is the height and "
+              "width of image.");
     AddOutput("Mask",
               "The Mask tensor of pooling operator."
-              "The format of output tensor is also NCHW.");
+              "The format of output tensor is also NCHW."
+              "Where N is batch size, C is the number of channels, H and W "
+              "is the height and width of image."
+              "The value in it is the index in current feature map");
     AddAttr<std::vector<int>>(
-        "ksize", "pooling size(height, width) of pooling operator.");
+        "ksize",
+        "Pooling size(height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)
     AddAttr<bool>(
         "globalPooling",
-        "whether to use the globalPooling."
-        "int constant equal to false or true"
-        "default false"
+        "Whether to use the globalPooling."
+        "Bool constant equal to false or true."
+        "Default false."
         "If globalPooling = true, ksize is ignored and need not be specified.")
         .SetDefault(false);
     AddAttr<std::vector<int>>("strides",
-                              "strides(height, width) of pooling operator."
-                              "default {1,1}")
-        .SetDefault({1, 1});
+                              "Strides(height, width) of pooling operator."
+                              "Default {1,1}.")
+        .SetDefault({1, 1});  // TODO(Add checker)
     AddAttr<std::vector<int>>("paddings",
-                              "paddings(height, width) of pooling operator."
-                              "default {0,0}")
-        .SetDefault({0, 0});
+                              "Paddings(height, width) of pooling operator."
+                              "Default {0,0}.")
+        .SetDefault({0, 0});  // TODO(Add checker)

     AddComment(R"DOC(
 The maxPooling2d with index operation calculates the output and the mask based on
@@ -140,30 +149,40 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
              "image.");
     AddOutput("Out",
               "The output tensor of pooling operator."
-              "The format of output tensor is also NCDHW.");
+              "The format of output tensor is also NCDHW."
+              "Where N is batch size, C is "
+              "the number of channels, D, H and W is the depth, height and "
+              "width of image.");
     AddOutput("Mask",
               "The Mask tensor of pooling operator."
-              "The format of output tensor is also NCDHW.");
+              "The format of output tensor is also NCDHW."
+              "Where N is batch size, C is the number of channels, D, H and W "
+              "is the depth, height and width of image."
+              "The value in it is the index in current feature map");
     AddAttr<std::vector<int>>(
-        "ksize", "pooling size(depth, height, width) of pooling operator.");
+        "ksize",
+        "Pooling size(depth, height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)
     AddAttr<bool>(
         "globalPooling",
-        "whether to use the globalPooling."
-        "int constant equal to false or true"
-        "default false"
+        "Whether to use the globalPooling."
+        "Bool constant equal to false or true."
+        "Default false."
        "If globalPooling = true, ksize is ignored and need not be specified.")
        .SetDefault(false);
     AddAttr<std::vector<int>>(
         "strides",
-        "strides(depth, height, width) of pooling operator."
-        "default {1,1,1}")
-        .SetDefault({1, 1, 1});
+        "Strides(depth, height, width) of pooling operator."
+        "Default {1,1,1}.")
+        .SetDefault({1, 1, 1});  // TODO(Add checker)
     AddAttr<std::vector<int>>(
         "paddings",
-        "paddings(depth, height, width) of pooling operator."
-        "default {0,0,0}")
-        .SetDefault({0, 0, 0});
+        "Paddings(depth, height, width) of pooling operator."
+        "Default {0,0,0}.")
+        .SetDefault({0, 0, 0});  // TODO(Add checker)

     AddComment(R"DOC(
 The maxpooling3d with index operation calculates the output and the mask based on
 the input and ksize, strides, paddings parameters.
......
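The DOC strings above say Out and Mask are computed from the input together with the ksize, strides, and paddings attributes; under the usual max-pooling convention (an assumption here, not quoted from the op itself) each spatial extent follows output = (input + 2 * padding - ksize) / stride + 1. A small sketch:

#include <cstdio>

// Hypothetical helper illustrating the usual pooling size rule;
// it is not part of the operator code above.
int PoolOutputSize(int input, int ksize, int padding, int stride) {
  return (input + 2 * padding - ksize) / stride + 1;
}

int main() {
  // e.g. input extent 7, ksize 3, padding 1, stride 1 -> output extent 7
  // (values chosen for illustration).
  printf("%d\n", PoolOutputSize(7, 3, 1, 1));
  return 0;
}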
@@ -32,11 +32,10 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
     Tensor* out = context.Output<Tensor>("Out");
     Tensor* mask = context.Output<Tensor>("Mask");

-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
@@ -63,7 +62,7 @@ template <typename Place, typename T>
 class MaxPoolWithIndexGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* mask = context.Input<Tensor>("Maks");
+    const Tensor* mask = context.Input<Tensor>("Mask");
     const Tensor* out_grad =
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
@@ -71,6 +70,11 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
+    if (context.Attr<bool>("globalPooling")) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
+      }
+    }

     if (in_x_grad) {
       in_x_grad->mutable_data<T>(context.GetPlace());
......
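The block added to MaxPoolWithIndexGradKernel mirrors the forward kernel: when globalPooling is set, ksize is overwritten with the tensor's spatial extents, so the pooling window spans the whole feature map. A standalone sketch of that override, using made-up NCDHW dims:

#include <cstdio>
#include <vector>

int main() {
  // NCDHW dims of a hypothetical input-gradient tensor.
  std::vector<int> dims = {2, 3, 5, 5, 5};
  std::vector<int> ksize = {3, 3, 3};
  bool global_pooling = true;

  // Mirror of the added kernel logic: with globalPooling, the window
  // covers the whole feature map, so ksize[i] = dims[i + 2].
  if (global_pooling) {
    for (size_t i = 0; i < ksize.size(); ++i) {
      ksize[i] = dims[i + 2];
    }
  }
  printf("ksize = {%d, %d, %d}\n", ksize[0], ksize[1], ksize[2]);  // {5, 5, 5}
  return 0;
}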
@@ -3,7 +3,11 @@ import numpy as np
 from op_test import OpTest


-def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
+def max_pool3D_forward_naive(x,
+                             ksize,
+                             strides,
+                             paddings=[0, 0, 0],
+                             global_pool=0):
     N, C, D, H, W = x.shape
     if global_pool == 1:
@@ -25,8 +29,19 @@ def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
                 x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]

                 out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
-                # mask[:,:, k, i, j] = np.argmax(x_masked, axis=(2, 3, 4))
-    return out
+                for n in xrange(N):
+                    for c in xrange(C):
+                        arr = x_masked[n, c, :, :, :]
+                        index = np.where(arr == np.max(arr))
+                        sub_deep = index[0][0]
+                        sub_row = index[1][0]
+                        sub_col = index[2][0]
+                        index = ((d_start + sub_deep) * H +
+                                 (h_start + sub_row)) * W + w_start + sub_col
+                        mask[n, c, k, i, j] = index
+
+    return out, mask


 def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
@@ -47,19 +62,25 @@ def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
             x_masked = x[:, :, r_start:r_end, c_start:c_end]

             out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
-            # mask[:,:, i, j] = np.argmax(x_masked, axis=(2, 3))
-    return out
+            for n in xrange(N):
+                for c in xrange(C):
+                    arr = x_masked[n, c, :, :]
+                    index = np.where(arr == np.max(arr))
+                    sub_row = index[0][0]
+                    sub_col = index[1][0]
+                    index = (r_start + sub_row) * W + c_start + sub_col
+                    mask[n, c, i, j] = index
+
+    return out, mask


 class TestMaxPoolWithIndex_Op(OpTest):
     def setUp(self):
         self.initTestCase()
-        self.op_type = "maxPool3dWithIndex"
         input = np.random.random(self.shape).astype("float32")
-        output = self.pool_forward_naive(input, self.ksize, self.strides,
-                                         self.paddings, self.global_pool)
-        # mask = np.zeros(output.shape)
+        output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
+                                               self.paddings, self.global_pool)

         self.attrs = {
             'strides': self.strides,
@@ -69,7 +90,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
         }

         self.inputs = {'X': input}
-        self.outputs = {'Out': output}
+        self.outputs = {'Out': output, "Mask": mask}

     def test_check_output(self):
         self.check_output()
@@ -78,7 +99,8 @@ class TestMaxPoolWithIndex_Op(OpTest):
     #     self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)

     def initTestCase(self):
-        self.global_pool = 0
+        self.global_pool = False
+        self.op_type = "maxPool3dWithIndex"
         self.pool_forward_naive = max_pool3D_forward_naive
         self.shape = [2, 3, 7, 7, 7]
         self.ksize = [3, 3, 3]
@@ -86,10 +108,9 @@ class TestMaxPoolWithIndex_Op(OpTest):
         self.paddings = [1, 1, 1]


-""""
 class TestCase1(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 1
+        self.global_pool = True
         self.op_type = "maxPool3dWithIndex"
         self.pool_forward_naive = max_pool3D_forward_naive
         self.shape = [2, 3, 5, 5, 5]
@@ -100,7 +121,7 @@ class TestCase1(TestMaxPoolWithIndex_Op):
 class TestCase2(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 0
+        self.global_pool = False
         self.op_type = "maxPool2dWithIndex"
         self.pool_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
@@ -111,7 +132,7 @@ class TestCase2(TestMaxPoolWithIndex_Op):
 class TestCase3(TestMaxPoolWithIndex_Op):
     def initTestCase(self):
-        self.global_pool = 1
+        self.global_pool = True
         self.op_type = "maxPool2dWithIndex"
         self.pool_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 5, 5]
@@ -122,4 +143,3 @@ class TestCase3(TestMaxPoolWithIndex_Op):

 if __name__ == '__main__':
     unittest.main()
-"""