提交 0f1d3af4 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #4461 from chengduoZH/Add_maxpool_withIdx_only

Add max pool op (with index)
......@@ -55,12 +55,20 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
# pool_op contains several operators
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
# pool_with_index_op contains several operators
if ("${TARGET}" STREQUAL "pool_with_index_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
......
......@@ -18,6 +18,11 @@ namespace paddle {
namespace operators {
namespace math {
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent height
* and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public:
......@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
};
template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
......@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
......@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public:
......@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
};
template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
......@@ -458,6 +488,253 @@ template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
if (ele < input_data[h * input_width + w]) {
ele = input_data[h * input_width + w];
index = h * input_width + w;
}
}
}
output_data[ph * output_width + pw] = ele;
mask_data[ph * output_width + pw] = index;
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_height = input_grad.dims()[2];
const int input_width = input_grad.dims()[3];
const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
for (int pw = 0; pw < output_width; ++pw) {
const int output_idx = ph * output_width + pw;
const int input_idx = static_cast<int>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = static_cast<T>(-FLT_MAX);
int index = -1;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
if (ele < input_data[input_idx]) {
index = input_idx;
ele = input_data[input_idx];
}
}
}
}
output_data[output_idx] = ele;
mask_data[output_idx] = index;
}
}
}
// offset
input_data += input_stride;
output_data += output_stride;
mask_data += output_stride;
}
}
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input_grad.dims()[0];
const int input_depth = input_grad.dims()[2];
const int input_height = input_grad.dims()[3];
const int input_width = input_grad.dims()[4];
const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3];
const int output_width = output_grad.dims()[4];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
for (int ph = 0; ph < output_height; ++ph) {
for (int pw = 0; pw < output_width; ++pw) {
const int output_idx =
(pd * output_height + ph) * output_width + pw;
const int input_idx = static_cast<int>(mask_data[output_idx]);
input_grad_data[input_idx] += output_grad_data[output_idx];
}
}
}
// offset
input_grad_data += input_stride;
output_grad_data += output_stride;
mask_data += output_stride;
}
}
}
};
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
此差异已折叠。
......@@ -21,15 +21,27 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__ //
#define FLT_MAX \
__FLT_MAX__ // It might need to be placed in another file, but I'm still
// wondering where to put it.
/*
* \brief Extracting simple operations from pooling.
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
* operation.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template <class T>
class MaxPool {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void finalize(T& y, const T& poo_size) {}
DEVICE inline void finalize(T& y, const T& pool_field) {}
};
template <class T>
......@@ -37,8 +49,9 @@ class AvgPool {
public:
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
DEVICE inline void finalize(T& y, const T& pool_field) { y /= pool_field; }
};
template <class T>
class MaxPoolGrad {
public:
......@@ -57,6 +70,20 @@ class AvgPoolGrad {
}
};
/*
* \brief Getting pooling results, and calculating gradient.
*
* In pool2d, all tensors are in NCHW format. Where N is batch size, C is the
* number of channels, H and W is the height and width of feature.
* In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the
* number of channels, D, H and W is the depth, height and width of feature.
*
* In max pooling, it is possible that the pooling region has multiple maximum
* elements. In this case, we should compute the gradient of the first maximum
* element.
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
template <typename Place, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
......@@ -117,6 +144,51 @@ class MaxPool3dGradFunctor {
std::vector<int>& strides, std::vector<int>& paddings);
};
/*
* \brief Getting max pooling results and corresponding max index, and
* calculating gradient.
* In up-sampling-pooling, it is necessary to know max element index.
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
* NCDHW format.
*/
template <typename Place, typename T>
class MaxPool2dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool2dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename T>
class MaxPool3dWithIndexGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace paddle {
namespace operators {
inline int OutputSizeMaxPool(int input_size, int filter_size, int padding,
int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mask"),
"Mask(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Intput size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(OutputSizeMaxPool(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Mask", framework::make_ddim(output_shape));
}
};
class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool2dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(height, width) of pooling operator."
"Default {1,1}.")
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>("paddings",
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The maxPooling2d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters. Input(X) and
output(Out, Mask) are in NCHW format. Where N is batch size, C is the
number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
)DOC");
}
};
class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MaxPool3dWithIndexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"image.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image.");
AddOutput("Mask",
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map");
AddAttr<std::vector<int>>(
"ksize",
"The pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment(R"DOC(
The maxpooling3d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters.
Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch
size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(max_pool2d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool2dWithIndexOpMaker, max_pool2d_with_index_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP(max_pool3d_with_index, ops::MaxPoolWithIndexOp,
ops::MaxPool3dWithIndexOpMaker, max_pool3d_with_index_grad,
ops::MaxPoolWithIndexOpGrad);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
max_pool2d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index,
ops::MaxPoolWithIndexKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
max_pool3d_with_index_grad,
ops::MaxPoolWithIndexGradKernel<paddle::platform::GPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
Tensor* mask = context.Output<Tensor>("Mask");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
pool2d_forward;
pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
pool3d_forward;
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize,
strides, paddings);
} break;
}
}
};
template <typename Place, typename T>
class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
}
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) {
case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x_grad, *out_grad,
*mask, ksize, strides, paddings);
} break;
}
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x,
ksize,
strides,
paddings=[0, 0, 0],
global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
mask = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in xrange(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
for j in xrange(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :, :]
index = np.where(arr == np.max(arr))
sub_deep = index[0][0]
sub_row = index[1][0]
sub_col = index[2][0]
index = ((d_start + sub_deep) * H +
(h_start + sub_row)) * W + w_start + sub_col
mask[n, c, k, i, j] = index
return out, mask
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
mask = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
for n in xrange(N):
for c in xrange(C):
arr = x_masked[n, c, :, :]
index = np.where(arr == np.max(arr))
sub_row = index[0][0]
sub_col = index[1][0]
index = (r_start + sub_row) * W + c_start + sub_col
mask[n, c, i, j] = index
return out, mask
class TestMaxPoolWithIndex_Op(OpTest):
def setUp(self):
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
output, mask = self.pool_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'globalPooling': self.global_pool,
}
self.inputs = {'X': input}
self.outputs = {'Out': output, "Mask": mask}
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def initTestCase(self):
self.global_pool = True
self.index = "max_pool3d_with_index"
self.op_type = "%s" % self.index
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase1(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase2(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase3(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
class TestCase4(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase5(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool3d_with_index"
self.pool_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [2, 2, 2]
self.paddings = [0, 0, 0]
class TestCase6(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase7(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
class TestCase8(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase9(TestMaxPoolWithIndex_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "max_pool2d_with_index"
self.pool_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [2, 2]
self.paddings = [0, 0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册