Commit 24010472 authored by liym27, committed by Tao Luo

fix pool2d pool3d,support asymmetric padding and channel_last (#19739)

* fix pool2d pool3d:
1. support asymmetric padding;
2. support padding algorithm:"SAME" and "VALID";
3. support channel_last: data_format NHWC and NDHWC;
4. support inferring shape when input with negative dims in compile time;
5. change doc of python API and c++;
6. fix bug in cuda kernel when Attr(adaptive) is true.

test=develop,test=document_preview

* fix 'tensors' to 'Tensors'. test=develop,test=document_preview

* add test for coverage of ValueError. test=develop,test=document_preview

* resolve conflict in test_pool2d. test=develop
Parent fe581b0e
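As a quick orientation before the diff, here is a minimal sketch of the behaviour this commit adds at the Python level, assuming the `paddle.fluid` API as it appears in the hunks below; the variable names and shapes are illustrative only.

import paddle.fluid as fluid

# Asymmetric padding [top, bottom, left, right] and a channel_last layout,
# both introduced for pool2d by this commit.
data = fluid.layers.data(
    name='img', shape=[8, 32, 32, 3], append_batch_size=False, dtype='float32')
out = fluid.layers.pool2d(
    input=data,
    pool_size=3,
    pool_type='avg',
    pool_stride=1,
    pool_padding=[1, 2, 0, 3],   # pad_height_top, pad_height_bottom, pad_width_left, pad_width_right
    data_format='NHWC')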
......@@ -144,8 +144,8 @@ paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82'))
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'eaa9d0bbd3d4e017c8bc4ecdac483711'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5'))
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', 'be7e530dcbd603962e25573a63eb145e'))
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', '053b1a855f13a066d005759171724bc6'))
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive', 'data_format'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True, 'NCHW')), ('document', '630cae697d46b4b575b15d56cf8be25a'))
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive', 'data_format'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True, 'NCDHW')), ('document', 'db0035a3132b1dfb12e53c57591fb9f6'))
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '52343203de40afe29607397e13aaf0d2'))
paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '55db6ae7275fb9678a6814aebab81a9c'))
paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '9e5a9f4f6d82d34a33d9ca632379cbcc'))
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -83,10 +84,11 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
/*
* \brief Getting pooling results, and calculating gradient.
*
* In pool2d, all tensors are in NCHW format. Where N is batch size, C is the
* number of channels, H and W is the height and width of feature.
* In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the
* number of channels, D, H and W is the depth, height and width of feature.
* In pool2d, all Tensors are in NCHW or NHWC format, where N is the batch
* size, C is the number of channels, and H and W are the height and width of
* the feature.
* In pool3d, all Tensors are in NCDHW or NDHWC format, where N is the batch
* size, C is the number of channels, and D, H and W are the depth, height and
* width of the feature.
*
* In max pooling, it is possible that the pooling region has multiple maximum
* elements. In this case, we should compute the gradient of the first maximum
......@@ -115,6 +117,14 @@ class Pool2dFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -127,6 +137,15 @@ class Pool2dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......@@ -139,6 +158,14 @@ class MaxPool2dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, framework::Tensor* input_grad);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -149,6 +176,13 @@ class Pool3dFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -161,6 +195,15 @@ class Pool3dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......@@ -173,6 +216,14 @@ class MaxPool3dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, framework::Tensor* input_grad);
};
/*
......
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -27,47 +29,117 @@ using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
DataLayout getLayoutFromStr(std::string data_format) {
if (data_format == "NHWC") {
return DataLayout::kNHWC;
} else if (data_format == "NCHW") {
return DataLayout::kNCHW;
} else if (data_format == "NCDHW") {
return DataLayout::kNCDHW;
} else {
return DataLayout::kNCDHW;
}
}
template <typename T>
class PoolCUDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>(ctx.GetPlace());
output->mutable_data<T>(ctx.GetPlace());
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
bool adaptive = ctx.Attr<bool>("adaptive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]);
std::string data_format = ctx.Attr<std::string>("data_format");
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = input->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
const std::string str_NCHW = "NCHW", str_NHWC = "NHWC";
const std::string str_NCDHW = "NCDHW", str_NDHWC = "NDHWC";
// -----------------transformed tensor ------------------------
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
DataLayout layout;
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
if (data_format == str_NDHWC) {
layout = DataLayout::kNCDHW;
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
// input
transformed_input.Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input.Resize(framework::make_ddim(in_dims_vec));
transformed_input.mutable_data(ctx.GetPlace(), input->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, &transformed_input, axis);
// output
transformed_output.Resize(output->dims());
auto out_dims_vec = framework::vectorize(output->dims());
out_dims_vec[1] = output->dims()[4];
out_dims_vec[2] = output->dims()[1];
out_dims_vec[3] = output->dims()[2];
out_dims_vec[4] = output->dims()[3];
transformed_output.Resize(framework::make_ddim(out_dims_vec));
} else {
layout = getLayoutFromStr(data_format);
transformed_input = *input;
transformed_output = *output;
}
const T *transformed_input_data = transformed_input.data<T>();
T *transformed_output_data = transformed_output.mutable_data<T>(
transformed_output.dims(), ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(input->dims()));
layout, framework::vectorize<int>(transformed_input.dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(output->dims()));
layout, framework::vectorize<int>(transformed_output.dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
......@@ -83,9 +155,19 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
CUDNN_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
cudnn_output_desc, output_data));
handle, cudnn_pool_desc, &alpha, cudnn_input_desc,
transformed_input_data, &beta, cudnn_output_desc,
transformed_output_data));
// transpose the pooled NCDHW result back to NDHWC when needed
if (data_format == str_NDHWC) {
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v2;
trans5_v2(dev_ctx, transformed_output, output, axis);
}
}
};
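For the NDHWC path above, the kernel transposes the input to NCDHW with axis order {0, 4, 1, 2, 3}, runs cuDNN pooling, and transposes the result back with {0, 2, 3, 4, 1}. A small NumPy sketch (a stand-in for `math::Transpose`, not the actual functor) shows that the two permutations are inverses and how the dimensions are remapped:

import numpy as np

x_ndhwc = np.random.rand(2, 8, 16, 16, 3)          # N, D, H, W, C
x_ncdhw = np.transpose(x_ndhwc, (0, 4, 1, 2, 3))   # -> N, C, D, H, W (matches in_dims_vec above)
restored = np.transpose(x_ncdhw, (0, 2, 3, 4, 1))  # inverse permutation used after pooling
assert restored.shape == x_ndhwc.shape
assert np.array_equal(restored, x_ndhwc)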
......@@ -93,8 +175,8 @@ template <typename T>
class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
const Tensor *output = ctx.Input<Tensor>("Out");
......@@ -104,37 +186,109 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
bool adaptive = ctx.Attr<bool>("adaptive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(input->dims()[i + 2]);
std::string data_format = ctx.Attr<std::string>("data_format");
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = input->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
const T *input_data = input->data<T>();
const T *output_data = output->data<T>();
const T *output_grad_data = output_grad->data<T>();
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
// ------- tensor grad --------------
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
Tensor transformed_output_grad(output_grad->type());
input_grad->mutable_data<T>(ctx.GetPlace());
Tensor transformed_input_grad(input_grad->type());
DataLayout layout;
const std::string str_NCHW = "NCHW", str_NHWC = "NHWC";
const std::string str_NCDHW = "NCDHW", str_NDHWC = "NDHWC";
if (data_format == str_NDHWC) {
layout = DataLayout::kNCDHW;
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
// input
transformed_input.Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input.Resize(framework::make_ddim(in_dims_vec));
transformed_input.mutable_data(ctx.GetPlace(), input->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, &transformed_input, axis);
// output
transformed_output.Resize(output->dims());
auto out_dims_vec = framework::vectorize(output->dims());
out_dims_vec[1] = output->dims()[4];
out_dims_vec[2] = output->dims()[1];
out_dims_vec[3] = output->dims()[2];
out_dims_vec[4] = output->dims()[3];
transformed_output.Resize(framework::make_ddim(out_dims_vec));
transformed_output.mutable_data(ctx.GetPlace(), output->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v2;
trans5_v2(dev_ctx, *output, &transformed_output, axis);
// output grad
transformed_output_grad.Resize(framework::make_ddim(out_dims_vec));
transformed_output_grad.mutable_data(ctx.GetPlace(), output_grad->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v3;
trans5_v3(dev_ctx, *output_grad, &transformed_output_grad, axis);
// input grad
transformed_input_grad.Resize(framework::make_ddim(in_dims_vec));
if (strides.size() == 2U) {
layout = DataLayout::kNCHW;
} else {
layout = DataLayout::kNCDHW;
layout = getLayoutFromStr(data_format);
transformed_input = *input;
transformed_output = *output;
transformed_output_grad = *output_grad;
transformed_input_grad = *input_grad;
}
const T *input_data = transformed_input.data<T>();
const T *output_data = transformed_output.data<T>();
const T *output_grad_data = transformed_output_grad.data<T>();
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(input->dims()));
layout, framework::vectorize<int>(transformed_input.dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(output->dims()));
layout, framework::vectorize<int>(transformed_output.dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
......@@ -155,13 +309,21 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
auto handle = ctx.cuda_device_context().cudnn_handle();
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
T *input_grad_data = transformed_input_grad.mutable_data<T>(
transformed_input_grad.dims(), ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
CUDNN_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
&beta, cudnn_input_desc, input_grad_data));
if (data_format == str_NDHWC) {
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v4;
trans5_v4(dev_ctx, transformed_input_grad, input_grad, axis);
}
}
}
};
......
......@@ -24,29 +24,32 @@ limitations under the License. */
namespace paddle {
namespace operators {
int PoolOutputSize(int input_size, int filter_size, int padding, int stride,
bool ceil_mode) {
int PoolOutputSize(int input_size, int filter_size, int padding_1,
int padding_2, int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size = (input_size - filter_size + 2 * padding) / stride + 1;
output_size =
(input_size - filter_size + padding_1 + padding_2) / stride + 1;
} else {
output_size =
(input_size - filter_size + 2 * padding + stride - 1) / stride + 1;
(input_size - filter_size + padding_1 + padding_2 + stride - 1) /
stride +
1;
}
PADDLE_ENFORCE(output_size > 0,
"Due to the settings of padding(%d), filter_size(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding, filter_size, stride, input_size);
PADDLE_ENFORCE_GT(
output_size, 0,
"Due to the settings of padding(%d,%d), filter_size(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding_1, padding_2, filter_size, stride, input_size);
return output_size;
}
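A worked check of the asymmetric-padding output-size formula above, written as a hypothetical Python helper that mirrors PoolOutputSize (illustrative only):

def pool_output_size(input_size, filter_size, padding_1, padding_2, stride, ceil_mode):
    # padding_1 / padding_2 are the paddings on the two sides of one spatial dimension.
    if not ceil_mode:
        return (input_size - filter_size + padding_1 + padding_2) // stride + 1
    return (input_size - filter_size + padding_1 + padding_2 + stride - 1) // stride + 1

assert pool_output_size(32, 3, 1, 2, 2, ceil_mode=False) == 17  # asymmetric padding (1, 2)
assert pool_output_size(5, 2, 0, 0, 2, ceil_mode=False) == 2    # floor mode
assert pool_output_size(5, 2, 0, 0, 2, ceil_mode=True) == 3     # ceil mode rounds up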
void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) of Pooling should not be null.");
std::string pooling_type = ctx->Attrs().Get<std::string>("pooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
......@@ -54,38 +57,60 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
bool ceil_mode = ctx->Attrs().Get<bool>("ceil_mode");
bool adaptive = ctx->Attrs().Get<bool>("adaptive");
bool global_pooling = ctx->Attrs().Get<bool>("global_pooling");
std::string data_format = ctx->Attrs().Get<std::string>("data_format");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D tensor.");
if (ctx->Attrs().Get<bool>("global_pooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
}
auto in_x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4 || in_x_dims.size() == 5, true,
"Pooling intput should be 4-D or 5-D tensor.");
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(in_x_dims.size() - ksize.size(), 2U,
"Input size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"Paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings if "SAME" or global_pooling
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
std::vector<int64_t> output_shape;
if (adaptive) {
output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
} else {
for (size_t i = 0; i < ksize.size(); ++i) {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
for (size_t i = 0; i < data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) && (data_dims[i] < 0)) {
output_shape.push_back(in_x_dims[i]);
} else {
output_shape.push_back(PoolOutputSize(
in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode));
output_shape.push_back(
PoolOutputSize(data_dims[i], ksize[i], paddings[2 * i],
paddings[2 * i + 1], strides[i], ceil_mode));
}
}
}
// output_N = input_N
output_shape.insert(output_shape.begin(), in_x_dims[0]);
// output_C = input_C
if (channel_last) {
output_shape.push_back(in_x_dims[in_x_dims.size() - 1]);
} else {
output_shape.insert(output_shape.begin() + 1, in_x_dims[1]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
}
......@@ -93,7 +118,9 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// std::string data_format = ctx.Attr<std::string>("data_format"); // change:
// delete
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
......@@ -114,16 +141,18 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
}
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// std::string data_format = ctx.Attr<std::string>("data_format"); //
// change:delete
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
......@@ -186,8 +215,8 @@ void Pool2dOpMaker::Make() {
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default {0,0}), paddings(height, width) of pooling "
"operator."
"(vector<int>, default {0,0}), paddings(height_top, height_bottom, "
"width_left, wifth_right) of pooling operator."
"If global_pooling = true, paddings and kernel size will be ignored.")
.SetDefault({0, 0});
AddAttr<bool>(
......@@ -206,7 +235,7 @@ void Pool2dOpMaker::Make() {
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
"(bool, default false) Only used in cudnn kernel, need install cudnn.")
.SetDefault(false);
AddAttr<bool>(
"ceil_mode",
......@@ -215,7 +244,7 @@ void Pool2dOpMaker::Make() {
"the floor function will be used.")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
"(bool, default false) Only used in mkldnn kernel.")
.SetDefault(false);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
......@@ -229,18 +258,24 @@ void Pool2dOpMaker::Make() {
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
.SetDefault("NCHW");
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
The pooling2d operation calculates the output based on
the input, pooling_type and ksize, strides, paddings parameters.
Input(X) and output(Out) are in NCHW format, where N is batch size, C is the
Input(X) and output(Out) are in NCHW or NHWC format, where N is batch size, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Parameters (ksize, strides, paddings) are two-element lists.
These two elements represent height and width, respectively.
......@@ -256,30 +291,47 @@ Example:
Out shape: $(N, C, H_{out}, W_{out})$
For pool_padding = "SAME":
$$
H_{out} = \\frac{(H_{in} + strides[0] - 1)}{strides[0]}
$$
$$
W_{out} = \\frac{(W_{in} + strides[1] - 1)}{strides[1]}
$$
For pool_padding = "VALID":
$$
H_{out} = \\frac{(H_{in} - ksize[0] + strides[0])}{strides[0]}
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + strides[1])}{strides[1]}
$$
For ceil_mode = false:
$$
H_{out} = \\frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1
H_{out} = \\frac{(H_{in} - ksize[0] + pad_height_top + pad_height_bottom)}{strides[0]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1
W_{out} = \\frac{(W_{in} - ksize[1] + pad_width_left + pad_width_right)}{strides[1]} + 1
$$
For ceil_mode = true:
$$
H_{out} = \\frac{(H_{in} - ksize[0] + 2 * paddings[0] + strides[0] - 1)}{strides[0]} + 1
H_{out} = \\frac{(H_{in} - ksize[0] + pad_height_top + pad_height_bottom + strides[0] - 1)}{strides[0]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
W_{out} = \\frac{(W_{in} - ksize[1] + pad_width_left + pad_width_right + strides[1] - 1)}{strides[1]} + 1
$$
For exclusive = false:
$$
hstart = i * strides[0] - paddings[0]
hstart = i * strides[0] - pad_height_top
$$
$$
hend = hstart + ksize[0]
$$
$$
wstart = j * strides[1] - paddings[1]
wstart = j * strides[1] - pad_width_left
$$
$$
wend = wstart + ksize[1]
......@@ -290,13 +342,13 @@ Example:
For exclusive = true:
$$
hstart = max(0, i * strides[0] - paddings[0])
hstart = max(0, i * strides[0] - pad_height_top)
$$
$$
hend = min(H, hstart + ksize[0])
$$
$$
wstart = max(0, j * strides[1] - paddings[1])
wstart = max(0, j * strides[1] - pad_width_left)
$$
$$
wend = min(W, wstart + ksize[1])
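The "SAME" and "VALID" formulas added to the doc above can be sanity-checked with a few lines of Python (numbers chosen for illustration):

H_in, k, s = 13, 3, 2
H_same = (H_in + s - 1) // s    # SAME: ceil(H_in / stride), independent of ksize
H_valid = (H_in - k + s) // s   # VALID: floor((H_in - ksize) / stride) + 1
assert (H_same, H_valid) == (7, 6)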
......@@ -319,13 +371,14 @@ class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
void Pool3dOpMaker::Make() {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW, where N is batch size, C is "
"The format of input tensor is NCDHW or NDHWC, where N is batch "
"size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of "
"the feature, respectively.");
AddOutput("Out",
"(Tensor) The output tensor of pooling operator."
"The format of output tensor is also NCDHW, "
"The format of output tensor is also NCDHW or NDHWC, "
"where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively.");
......@@ -355,8 +408,10 @@ void Pool3dOpMaker::Make() {
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default {0,0,0}), paddings(depth, height, "
"width) of pooling operator. "
"(vector<int>, default {0,0,0}), paddings(pad_depth_front, "
"pad_depth_back, "
"pad_height_top, pad_height_bottom, pad_width_left, pad_width_right"
") of pooling operator. "
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
......@@ -376,7 +431,7 @@ void Pool3dOpMaker::Make() {
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
"(bool, default false) Only used in cudnn kernel, need install cudnn.")
.SetDefault(false);
AddAttr<bool>(
"ceil_mode",
......@@ -389,11 +444,17 @@ void Pool3dOpMaker::Make() {
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"(string, default NCDHW) Only used in "
"An optional string from: \"NDHWC\", \"NCDHW\". "
"Defaults to \"NDHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
.SetDefault("NCDHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
......@@ -401,7 +462,7 @@ Pool3d Operator.
The pooling3d operation calculates the output based on
the input, pooling_type, ksize, strides, and paddings parameters.
Input(X) and output(Out) are in NCDHW format, where N is batch
Input(X) and output(Out) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels, and D, H and W are the depth, height and
width of the feature, respectively. Parameters(ksize, strides, paddings)
are three elements. These three elements represent depth, height and
......@@ -412,42 +473,65 @@ Example:
X shape: $(N, C, D_{in}, H_{in}, W_{in})$
Output:
Out shape: $(N, C, D_{out}, H_{out}, W_{out})$
For pool_padding = "SAME":
$$
D_{out} = \\frac{(D_{in} + strides[0] - 1)}{strides[0]}
$$
$$
H_{out} = \\frac{(H_{in} + strides[1] - 1)}{strides[1]}
$$
$$
W_{out} = \\frac{(W_{in} + strides[2] - 1)}{strides[2]}
$$
For pool_padding = "VALID":
$$
D_{out} = \\frac{(D_{in} - ksize[0] + strides[0])}{strides[0]}
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + strides[1])}{strides[1]}
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + strides[2])}{strides[2]}
$$
For ceil_mode = false:
$$
D_{out} = \\frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1
D_{out} = \\frac{(D_{in} - ksize[0] + pad_depth_front + pad_depth_back)}{strides[0]} + 1
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[2]} + 1
H_{out} = \\frac{(H_{in} - ksize[1] + pad_height_top + pad_height_bottom)}{strides[1]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1
W_{out} = \\frac{(W_{in} - ksize[2] + pad_width_left + pad_width_right)}{strides[2]} + 1
$$
For ceil_mode = true:
$$
D_{out} = \\frac{(D_{in} - ksize[0] + 2 * paddings[0] + strides[0] -1)}{strides[0]} + 1
D_{out} = \\frac{(D_{in} - ksize[0] + pad_depth_front + pad_depth_back + strides[0] -1)}{strides[0]} + 1
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + 2 * paddings[1] + strides[1] -1)}{strides[1]} + 1
H_{out} = \\frac{(H_{in} - ksize[1] + pad_height_top + pad_height_bottom + strides[1] -1)}{strides[1]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + 2 * paddings[2] + strides[2] -1)}{strides[2]} + 1
W_{out} = \\frac{(W_{in} - ksize[2] + pad_width_left + pad_width_right + strides[2] -1)}{strides[2]} + 1
$$
For exclusive = false:
$$
dstart = i * strides[0] - paddings[0]
dstart = i * strides[0] - pad_depth_front
$$
$$
dend = dstart + ksize[0]
$$
$$
hstart = j * strides[1] - paddings[1]
hstart = j * strides[1] - pad_height_top
$$
$$
hend = hstart + ksize[1]
$$
$$
wstart = k * strides[2] - paddings[2]
wstart = k * strides[2] - pad_width_left
$$
$$
wend = wstart + ksize[2]
......@@ -458,16 +542,19 @@ Example:
For exclusive = true:
$$
dstart = max(0, i * strides[0] - paddings[0])
dstart = max(0, i * strides[0] - pad_depth_front)
$$
$$
dend = min(D, dstart + ksize[0])
$$
$$
hstart = max(0, j * strides[1] - pad_height_top)
$$
$$
hend = min(H, hstart + ksize[1])
$$
$$
wstart = max(0, k * strides[2] - paddings[2])
wstart = max(0, k * strides[2] - pad_width_left)
$$
$$
wend = min(W, wstart + ksize[2])
......
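The corresponding 3-D "SAME" / "VALID" formulas behave the same way per dimension; a short illustrative check:

D_in, k, s = 15, 4, 3
D_same = (D_in + s - 1) // s    # ceil(15 / 3) = 5
D_valid = (D_in - k + s) // s   # floor((15 - 4) / 3) + 1 = 4
assert (D_same, D_valid) == (5, 4)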
......@@ -14,13 +14,13 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace operators {
......@@ -57,6 +57,57 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
const bool adaptive,
const std::string padding_algorithm,
const framework::DDim data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
// set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims);
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
PADDLE_ENFORCE_EQ(
data_dims.size() * 2, paddings->size(),
"Paddings size should be the same or twice as the pooling size.");
}
// when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
int pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
// if global_pooling == true or adaptive == true, padding will be ignored
if (global_pooling || adaptive) {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline void UpdateKsize(std::vector<int>* ksize,
const framework::DDim data_dims) {
ksize->resize(static_cast<size_t>(data_dims.size()));
for (size_t i = 0; i < ksize->size(); ++i) {
*(ksize->begin() + i) = static_cast<int>(data_dims[i]);
}
}
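UpdatePadding above normalizes the padding list to one (before, after) pair per spatial dimension and applies the "SAME"/"VALID" rules, while UpdateKsize replaces the kernel size with the spatial dims for global pooling. A hypothetical Python restatement of the padding logic, for illustration only (no error handling):

def update_padding(paddings, data_dims, strides, ksize,
                   padding_algorithm="EXPLICIT",
                   global_pooling=False, adaptive=False):
    # Expand [p_h, p_w] -> [p_h, p_h, p_w, p_w] so each spatial dim has a (before, after) pair.
    if len(paddings) == len(data_dims):
        paddings = [p for p in paddings for _ in (0, 1)]
    if padding_algorithm == "SAME":
        paddings = []
        for d, s, k in zip(data_dims, strides, ksize):
            out_size = (d + s - 1) // s
            pad_sum = max((out_size - 1) * s + k - d, 0)
            paddings += [pad_sum // 2, pad_sum - pad_sum // 2]
    elif padding_algorithm == "VALID":
        paddings = [0] * (2 * len(data_dims))
    if global_pooling or adaptive:   # padding is ignored in these modes
        paddings = [0] * (2 * len(data_dims))
    return paddings

assert update_padding([1, 2], [13, 13], [2, 2], [3, 3]) == [1, 1, 2, 2]
assert update_padding([0, 0], [13, 13], [2, 2], [3, 3], "SAME") == [1, 1, 1, 1]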
template <typename DeviceContext, typename T>
class PoolKernel : public framework::OpKernel<T> {
......@@ -69,14 +120,36 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::string data_format = context.Attr<std::string>("data_format");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
bool global_pooling = context.Attr<bool>("global_pooling");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = in_x->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>();
switch (ksize.size()) {
case 2: {
......@@ -85,16 +158,16 @@ class PoolKernel : public framework::OpKernel<T> {
DeviceContext, paddle::operators::math::MaxPool<T>, T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
true, false, out);
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
exclusive, adaptive, out);
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, exclusive, adaptive, out);
}
} break;
case 3: {
......@@ -103,15 +176,16 @@ class PoolKernel : public framework::OpKernel<T> {
DeviceContext, paddle::operators::math::MaxPool<T>, T>
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
true, false, out);
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
exclusive, adaptive, out);
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, exclusive, adaptive, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......@@ -135,13 +209,33 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
std::string data_format = context.Attr<std::string>("data_format");
bool global_pooling = context.Attr<bool>("global_pooling");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
// update paddings
auto in_x_dims = in_x->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>();
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
......@@ -154,15 +248,15 @@ class PoolGradKernel : public framework::OpKernel<T> {
paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
pool2d_backward;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, in_x_grad);
paddings, data_format, in_x_grad);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor<
DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, adaptive,
in_x_grad);
paddings, data_format, pool_process, exclusive,
adaptive, in_x_grad);
}
} break;
case 3: {
......@@ -170,15 +264,15 @@ class PoolGradKernel : public framework::OpKernel<T> {
paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
pool3d_backward;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, in_x_grad);
paddings, data_format, in_x_grad);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor<
DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, adaptive,
in_x_grad);
paddings, data_format, pool_process, exclusive,
adaptive, in_x_grad);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
......@@ -72,6 +72,7 @@ enum class DataLayout { // Not use
kNHWC,
kNCHW,
kNCDHW,
kNDHWC, // add, liyamei
kNCHW_VECT_C,
};
......@@ -212,6 +213,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat(
return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // NOTE: cudnn treat NdTensor as the same
case DataLayout::kNDHWC:
return CUDNN_TENSOR_NHWC; // add, liyamei
default:
PADDLE_THROW("Unknown cudnn equivalent for order");
}
......@@ -238,14 +241,31 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1];
}
// Update tensor descriptor dims setting if groups > 1
// NOTE: Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy
// NOTE: Here, Assume using NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end());
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
}
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(),
strides.data()));
if (dims.size() == 4) {
if (format == CUDNN_TENSOR_NCHW) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(),
strides.data()));
} else { // CUDNN_TENSOR_NHWC
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensor4dDescriptor(
desc_, format, type, dims[0], dims[3], dims[1], dims[2]));
}
} else if (dims.size() == 5) {
if (format == CUDNN_TENSOR_NCHW) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(),
strides.data()));
} else { // CUDNN_TENSOR_NHWC
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx(
desc_, format, type, dims.size(), dims.data()));
}
}
return desc_;
}
......
......@@ -126,7 +126,8 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor);
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
......
......@@ -2871,15 +2871,16 @@ def pool2d(input,
use_cudnn=True,
ceil_mode=False,
name=None,
exclusive=True):
exclusive=True,
data_format="NCHW"):
"""
${comment}
Args:
input (Variable): The input tensor of pooling operator. The format of
input tensor is NCHW, where N is batch size, C is
the number of channels, H is the height of the
feature, and W is the width of the feature.
input tensor is `"NCHW"` or `"NHWC"`, where `N` is batch size, `C` is
the number of channels, `H` is the height of the
feature, and `W` is the width of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be a square of an int.
......@@ -2887,8 +2888,13 @@ def pool2d(input,
pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
it must contain two integers, (pool_stride_Height, pool_stride_Width).
Otherwise, the pool stride size will be a square of an int.
pool_padding (int|list|tuple): The pool padding size. If pool padding size is a tuple,
it must contain two integers, (pool_padding_on_Height, pool_padding_on_Width).
pool_padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, it must be 'VALID' or
'SAME', which selects the padding algorithm. If `pool_padding` is a tuple or list,
it can take one of three forms: `[pad_height, pad_width]`,
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, or a nested form: when `data_format` is `"NCHW"`,
`pool_padding` can be `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`;
when `data_format` is `"NHWC"`, `pool_padding` can be
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Otherwise, the pool padding size will be a square of an int.
global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment}
......@@ -2896,55 +2902,125 @@ def pool2d(input,
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
mode, default is `true`.
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
Returns:
Variable: The pooling result.
Raises:
ValueError: If 'pool_type' is not "max" nor "avg"
ValueError: If 'global_pooling' is False and 'pool_size' is -1
ValueError: If 'use_cudnn' is not a bool value.
ValueError: If `pool_type` is not "max" nor "avg"
ValueError: If `global_pooling` is False and `pool_size` is -1
ValueError: If `use_cudnn` is not a bool value.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
pool2d = fluid.layers.pool2d(
input=data,
pool_size=2,
pool_type='max',
pool_stride=1,
global_pooling=False)
name='data', shape=[10, 3, 32, 32], append_batch_size=False, dtype='float32')
# example 1:
# Attr(pool_padding) is a list with 4 elements, Attr(data_format) is "NCHW".
out_1 = fluid.layers.pool2d(
input = data,
pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = [1, 2, 1, 0],
data_format = "NCHW")
# example 2:
# Attr(pool_padding) is a string, Attr(data_format) is "NCHW".
out_2 = fluid.layers.pool2d(
input = data,
pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = "VALID",
data_format = "NCHW")
"""
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
"Unknown Attr(pool_type): '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if global_pooling is False and pool_size == -1:
raise ValueError(
"When the global_pooling is False, pool_size must be passed "
"and be a valid value. Received pool_size: " + str(pool_size))
"When Attr(global_pooling) is False, Attr(pool_size) must be passed "
"and be a valid value. Received pool_size: %s." % str(pool_size))
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s." % str(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
def update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 4:
if is_list_or_tuple(padding[0]) and (data_format == "NCHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:4]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"):
if not (padding[0] == [0, 0] and padding[3] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:3]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding')
l_type = 'pool2d'
else:
padding = utils.convert_to_list(padding, 2, 'padding')
helper = LayerHelper(l_type, **locals())
return padding
padding_algorithm = "EXPLICIT"
if isinstance(pool_padding, str):
pool_padding = pool_padding.upper()
if pool_padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'."
% str(pool_padding))
if pool_padding == "VALID":
padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0]
if ceil_mode != False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. "
"Received ceil_mode: True.")
elif pool_padding == "SAME":
padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0]
pool_padding = update_padding(pool_padding, data_format)
op_type = 'pool2d'
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=l_type,
type=op_type,
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
......@@ -2953,10 +3029,12 @@ def pool2d(input,
"global_pooling": global_pooling,
"strides": pool_stride,
"paddings": pool_padding,
"padding_algorithm": padding_algorithm,
"use_cudnn": use_cudnn,
"ceil_mode": ceil_mode,
"use_mkldnn": False,
"exclusive": exclusive,
"data_format": data_format,
})
return pool_out
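The docstring examples above cover the 4-element list and string forms of pool_padding; a third, nested-list form is also accepted, as sketched below (tensor names and shapes are illustrative):

import paddle.fluid as fluid

x = fluid.layers.data(
    name='x', shape=[4, 3, 28, 28], append_batch_size=False, dtype='float32')
# NCHW nested form: [[N], [C], [H_top, H_bottom], [W_left, W_right]];
# the batch and channel entries must stay [0, 0].
y = fluid.layers.pool2d(
    input=x,
    pool_size=3,
    pool_type='max',
    pool_stride=2,
    pool_padding=[[0, 0], [0, 0], [1, 2], [0, 3]],
    data_format='NCHW')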
......@@ -2972,30 +3050,43 @@ def pool3d(input,
use_cudnn=True,
ceil_mode=False,
name=None,
exclusive=True):
exclusive=True,
data_format="NCDHW"):
"""
${comment}
Args:
input (Variable): The input tensor of pooling operator. The format of
input tensor is NCDHW, where N is batch size, C is
the number of channels, D is the depth of the feature,
H is the height of the feature, and W is the width
input tensor is `"NCDHW"` or `"NDHWC"`, where `N` is batch size, `C` is
the number of channels, `D` is the depth of the feature,
`H` is the height of the feature, and `W` is the width
of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size
is a tuple or list, it must contain three integers,
(pool_size_Depth, pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be the cube of an int.
pool_type (string): ${pooling_type_comment}
pool_stride (int): stride of the pooling layer.
pool_padding (int): padding size.
pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`.
Otherwise, the pool stride size will be a cube of an int.
pool_padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, it must be 'VALID' or
'SAME', which selects the padding algorithm. If `pool_padding` is a tuple or list,
it can take one of three forms: `[pad_depth, pad_height, pad_width]`,
`[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
or a nested form: when `data_format` is `"NCDHW"`, `pool_padding` can be
`[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`;
when `data_format` is `"NDHWC"`, `pool_padding` can be
`[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer(optional). If set None, the layer
will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
mode, default is true.
data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns:
Variable: output of pool3d layer.
......@@ -3005,39 +3096,114 @@ def pool3d(input,
.. code-block:: python
import paddle.fluid as fluid
data = fluid.layers.data(
name='data', shape=[3, 32, 32, 32], dtype='float32')
pool3d = fluid.layers.pool3d(
input=data,
pool_size=2,
pool_type='max',
pool_stride=1,
global_pooling=False)
name='data', shape=[10, 3, 32, 32, 32], append_batch_size=False, dtype='float32')
# example 1:
# Attr(pool_padding) is a list with 6 elements, Attr(data_format) is "NCDHW".
out_1 = fluid.layers.pool3d(
input = data,
pool_size = 2,
pool_type = "avg",
pool_stride = 1,
pool_padding = [1, 2, 1, 0, 1, 2],
global_pooling = False,
data_format = "NCDHW")
# example 2:
# Attr(pool_padding) is a string, Attr(data_format) is "NCDHW".
out_2 = fluid.layers.pool3d(
input = data,
pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = "VALID",
global_pooling = False,
data_format = "NCDHW")
"""
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
"Unknown Attr(pool_type): '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if global_pooling is False and pool_size == -1:
raise ValueError(
"When the global_pooling is False, pool_size must be passed "
"and be a valid value. Received pool_size: " + str(pool_size))
"When Attr(global_pooling) is False, Attr(pool_size) must be passed "
"and be a valid value. Received Attr(pool_size): %s." %
str(pool_size))
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s. " % str(use_cudnn))
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
"Attr(data_format): %s" % str(data_format))
pool_size = utils.convert_to_list(pool_size, 3, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 3, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 3, 'pool_stride')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
def update_padding(padding, data_format):
def is_list_or_tuple(ele):
if isinstance(ele, (list, tuple)):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 5:
if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:5]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"):
if not (padding[0] == [0, 0] and padding[4] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding')
elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding')
l_type = "pool3d"
helper = LayerHelper(l_type, **locals())
else:
padding = utils.convert_to_list(padding, 3, 'padding')
return padding
padding_algorithm = "EXPLICIT"
if isinstance(pool_padding, str):
pool_padding = pool_padding.upper()
if pool_padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'."
% str(pool_padding))
if pool_padding == "VALID":
padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0, 0, 0]
if ceil_mode != False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", ceil_mode must be False. "
"Received ceil_mode: True.")
elif pool_padding == "SAME":
padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0, 0, 0]
pool_padding = update_padding(pool_padding, data_format)
op_type = "pool3d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=l_type,
type=op_type,
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
......@@ -3046,10 +3212,12 @@ def pool3d(input,
"global_pooling": global_pooling,
"strides": pool_stride,
"paddings": pool_padding,
"padding_algorithm": padding_algorithm,
"use_cudnn": use_cudnn,
"ceil_mode": ceil_mode,
"use_mkldnn": False,
"exclusive": exclusive,
"data_format": data_format,
})
return pool_out
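A short additional sketch combining the two new pool3d options, the channel_last layout and the "SAME" padding algorithm (tensor names and shapes are illustrative):

import paddle.fluid as fluid

vol = fluid.layers.data(
    name='vol', shape=[2, 16, 16, 16, 3], append_batch_size=False, dtype='float32')
out = fluid.layers.pool3d(
    input=vol,
    pool_size=2,
    pool_type='max',
    pool_stride=2,
    pool_padding='SAME',
    data_format='NDHWC')   # output shape: [2, 8, 8, 8, 3]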
......