Commit 24010472 authored by liym27, committed by Tao Luo

fix pool2d pool3d, support asymmetric padding and channel_last (#19739)

* fix pool2d pool3d:
1. support asymmetric padding;
2. support padding algorithm: "SAME" and "VALID";
3. support channel_last: data_format NHWC and NDHWC;
4. support inferring the output shape when the input has negative dims at compile time;
5. update the docs of the Python API and the C++ op;
6. fix a bug in the CUDA kernel when Attr(adaptive) is true.

test=develop,test=document_preview

* fix 'tensors' to 'Tensors'. test=develop,test=document_preview

* add test for coverage of ValueError. test=develop,test=document_preview

* resolve conflict in test_pool2d. test=develop
Parent fe581b0e
@@ -144,8 +144,8 @@ paddle.fluid.layers.conv3d (ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool (ArgSpec(args=['input', 'pool_type', 'is_test', 'pad_value'], varargs=None, keywords=None, defaults=(False, 0.0)), ('document', 'e90a93251c52dc4e6fb34fb3991b3f82'))
paddle.fluid.layers.sequence_softmax (ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', 'eaa9d0bbd3d4e017c8bc4ecdac483711'))
paddle.fluid.layers.softmax (ArgSpec(args=['input', 'use_cudnn', 'name', 'axis'], varargs=None, keywords=None, defaults=(False, None, -1)), ('document', 'cee673c79e3ff4582656a24e04f841e5'))
paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive', 'data_format'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True, 'NCHW')), ('document', '630cae697d46b4b575b15d56cf8be25a'))
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive', 'data_format'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True, 'NCDHW')), ('document', 'db0035a3132b1dfb12e53c57591fb9f6'))
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '52343203de40afe29607397e13aaf0d2'))
paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '55db6ae7275fb9678a6814aebab81a9c'))
paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '9e5a9f4f6d82d34a33d9ca632379cbcc'))
......
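As a quick reference for the new signature above, here is a minimal usage sketch of the updated Python API. It is only illustrative: the layer name and the new `data_format` argument come from the spec above, while the exact accepted forms of `pool_padding` (an explicit `[top, bottom, left, right]` list versus the "SAME"/"VALID" strings) are assumptions based on the commit message.

```python
import paddle.fluid as fluid

# NHWC ("channel_last") input: shape is [H, W, C] per sample.
x = fluid.layers.data(name='x', shape=[32, 32, 3], dtype='float32')

# Assumed: pool_padding takes an explicit asymmetric padding list
# [pad_top, pad_bottom, pad_left, pad_right].
y_explicit = fluid.layers.pool2d(
    input=x,
    pool_size=3,
    pool_type='avg',
    pool_stride=2,
    pool_padding=[1, 0, 1, 0],
    data_format='NHWC')

# Assumed: pool_padding also accepts the padding algorithm strings.
y_same = fluid.layers.pool2d(
    input=x,
    pool_size=3,
    pool_type='max',
    pool_stride=2,
    pool_padding='SAME',
    data_format='NHWC')
```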
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/tensor.h"
@@ -83,10 +84,11 @@ HOSTDEVICE inline int AdaptEndIndex(int ph, int input_size, int output_size) {
/*
* \brief Getting pooling results, and calculating gradient.
*
* In pool2d, all Tensors are in NCHW or NHWC format. Where N is batch size, C
* is the number of channels, H and W is the height and width of feature.
* In pool3d, all Tensors are in NCDHW or NDHWC format. Where N is batch size, C
* is the number of channels, D, H and W is the depth, height and width of
* feature.
*
* In max pooling, it is possible that the pooling region has multiple maximum
* elements. In this case, we should compute the gradient of the first maximum
@@ -115,6 +117,14 @@ class Pool2dFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
@@ -127,6 +137,15 @@ class Pool2dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
@@ -139,6 +158,14 @@ class MaxPool2dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, framework::Tensor* input_grad);
};
template <typename DeviceContext, typename PoolProcess, typename T>
@@ -149,6 +176,13 @@ class Pool3dFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
@@ -161,6 +195,15 @@ class Pool3dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, PoolProcess pool_compute,
bool exclusive, bool adaptive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
@@ -173,6 +216,14 @@ class MaxPool3dGradFunctor {
const std::vector<int>& strides,
const std::vector<int>& paddings,
framework::Tensor* input_grad);
// overload operator() to support argument data_format
void operator()(const DeviceContext& context, const framework::Tensor& input,
const framework::Tensor& output,
const framework::Tensor& output_grad,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string data_format, framework::Tensor* input_grad);
};
/*
......
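The overloads above only add a `data_format` argument, but that argument changes how every element is addressed. A small layout sketch in plain Python (illustrative helper names, not part of the patch) shows why the pooling functors need to know whether the data is NCHW or NHWC:

```python
def offset_nchw(n, c, h, w, C, H, W):
    # NCHW: all H*W values of one channel are contiguous.
    return ((n * C + c) * H + h) * W + w

def offset_nhwc(n, c, h, w, C, H, W):
    # NHWC (channel_last): the C values of one pixel are contiguous.
    return ((n * H + h) * W + w) * C + c

# The same logical element (n=0, c=1, h=2, w=3) sits at different
# linear offsets, so the pooling loops must index differently.
print(offset_nchw(0, 1, 2, 3, C=4, H=8, W=8))  # 83
print(offset_nhwc(0, 1, 2, 3, C=4, H=8, W=8))  # 77
```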
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
@@ -27,47 +29,117 @@ using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
DataLayout getLayoutFromStr(std::string data_format) {
if (data_format == "NHWC") {
return DataLayout::kNHWC;
} else if (data_format == "NCHW") {
return DataLayout::kNCHW;
} else if (data_format == "NCDHW") {
return DataLayout::kNCDHW;
} else {
return DataLayout::kNCDHW;
}
}
template <typename T>
class PoolCUDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");
output->mutable_data<T>(ctx.GetPlace());
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
bool adaptive = ctx.Attr<bool>("adaptive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::string data_format = ctx.Attr<std::string>("data_format");
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = input->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
const std::string str_NCHW = "NCHW", str_NHWC = "NHWC";
const std::string str_NCDHW = "NCDHW", str_NDHWC = "NDHWC";
// -----------------transformed tensor ------------------------
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
DataLayout layout;
if (data_format == str_NDHWC) {
layout = DataLayout::kNCDHW;
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
// input
transformed_input.Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input.Resize(framework::make_ddim(in_dims_vec));
transformed_input.mutable_data(ctx.GetPlace(), input->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, &transformed_input, axis);
// output
transformed_output.Resize(output->dims());
auto out_dims_vec = framework::vectorize(output->dims());
out_dims_vec[1] = output->dims()[4];
out_dims_vec[2] = output->dims()[1];
out_dims_vec[3] = output->dims()[2];
out_dims_vec[4] = output->dims()[3];
transformed_output.Resize(framework::make_ddim(out_dims_vec));
} else {
layout = getLayoutFromStr(data_format);
transformed_input = *input;
transformed_output = *output;
}
const T *transformed_input_data = transformed_input.data<T>();
T *transformed_output_data = transformed_output.mutable_data<T>(
transformed_output.dims(), ctx.GetPlace());
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_output.dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
@@ -83,9 +155,19 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
CUDNN_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc,
transformed_input_data, &beta, cudnn_output_desc,
transformed_output_data));
// If the input was NDHWC, transpose the pooled result back from NCDHW to NDHWC.
if (data_format == str_NDHWC) {
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v2;
trans5_v2(dev_ctx, transformed_output, output, axis);
}
}
};
@@ -93,8 +175,8 @@ template <typename T>
class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
const Tensor *output = ctx.Input<Tensor>("Out");
@@ -104,37 +186,109 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
bool adaptive = ctx.Attr<bool>("adaptive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::string data_format = ctx.Attr<std::string>("data_format");
bool global_pooling = ctx.Attr<bool>("global_pooling");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = input->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
// ------- tensor grad --------------
Tensor transformed_input(input->type());
Tensor transformed_output(output->type());
Tensor transformed_output_grad(output_grad->type());
input_grad->mutable_data<T>(ctx.GetPlace());
Tensor transformed_input_grad(input_grad->type());
DataLayout layout;
const std::string str_NCHW = "NCHW", str_NHWC = "NHWC";
const std::string str_NCDHW = "NCDHW", str_NDHWC = "NDHWC";
if (data_format == str_NDHWC) {
layout = DataLayout::kNCDHW;
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 4, 1, 2, 3};
// input
transformed_input.Resize(input->dims());
auto in_dims_vec = framework::vectorize(input->dims());
in_dims_vec[1] = input->dims()[4];
in_dims_vec[2] = input->dims()[1];
in_dims_vec[3] = input->dims()[2];
in_dims_vec[4] = input->dims()[3];
transformed_input.Resize(framework::make_ddim(in_dims_vec));
transformed_input.mutable_data(ctx.GetPlace(), input->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5;
trans5(dev_ctx, *input, &transformed_input, axis);
// output
transformed_output.Resize(output->dims());
auto out_dims_vec = framework::vectorize(output->dims());
out_dims_vec[1] = output->dims()[4];
out_dims_vec[2] = output->dims()[1];
out_dims_vec[3] = output->dims()[2];
out_dims_vec[4] = output->dims()[3];
transformed_output.Resize(framework::make_ddim(out_dims_vec));
transformed_output.mutable_data(ctx.GetPlace(), output->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v2;
trans5_v2(dev_ctx, *output, &transformed_output, axis);
// output grad
transformed_output_grad.Resize(framework::make_ddim(out_dims_vec));
transformed_output_grad.mutable_data(ctx.GetPlace(), output_grad->type());
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v3;
trans5_v3(dev_ctx, *output_grad, &transformed_output_grad, axis);
// input grad
transformed_input_grad.Resize(framework::make_ddim(in_dims_vec));
} else {
layout = getLayoutFromStr(data_format);
transformed_input = *input;
transformed_output = *output;
transformed_output_grad = *output_grad;
transformed_input_grad = *input_grad;
}
const T *input_data = transformed_input.data<T>();
const T *output_data = transformed_output.data<T>();
const T *output_grad_data = transformed_output_grad.data<T>();
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_output.dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
@@ -155,13 +309,21 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
auto handle = ctx.cuda_device_context().cudnn_handle();
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T *input_grad_data = transformed_input_grad.mutable_data<T>(
transformed_input_grad.dims(), ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
CUDNN_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
&beta, cudnn_input_desc, input_grad_data));
if (data_format == str_NDHWC) {
auto &dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
std::vector<int> axis{0, 2, 3, 4, 1};
math::Transpose<paddle::platform::CUDADeviceContext, T, 5> trans5_v4;
trans5_v4(dev_ctx, transformed_input_grad, input_grad, axis);
}
}
}
};
......
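The cuDNN path above handles NDHWC input by transposing it to NCDHW with axis order {0, 4, 1, 2, 3}, running the pooling there, and transposing the result back with {0, 2, 3, 4, 1}. A short plain-Python sketch (our own helper, no framework calls) of what those two permutations do to the dimension order:

```python
def permute(dims, axis):
    # Reorder a shape the same way math::Transpose reorders the data.
    return [dims[a] for a in axis]

ndhwc = [8, 4, 16, 16, 3]               # N, D, H, W, C
ncdhw = permute(ndhwc, [0, 4, 1, 2, 3])
print(ncdhw)                            # [8, 3, 4, 16, 16] -> N, C, D, H, W

# Transposing back restores the channel_last order.
back = permute(ncdhw, [0, 2, 3, 4, 1])
print(back)                             # [8, 4, 16, 16, 3]
```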
@@ -24,29 +24,32 @@ limitations under the License. */
namespace paddle {
namespace operators {
int PoolOutputSize(int input_size, int filter_size, int padding_1,
int padding_2, int stride, bool ceil_mode) {
int output_size;
if (!ceil_mode) {
output_size =
(input_size - filter_size + padding_1 + padding_2) / stride + 1;
} else {
output_size =
(input_size - filter_size + padding_1 + padding_2 + stride - 1) /
stride +
1;
}
PADDLE_ENFORCE_GT(
output_size, 0,
"Due to the settings of padding(%d,%d), filter_size(%d) and "
"stride(%d), the output size is less than 0, please check "
"again. Input_size:%d",
padding_1, padding_2, filter_size, stride, input_size);
return output_size;
}
void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Out(Output) of Pooling should not be null.");
std::string pooling_type = ctx->Attrs().Get<std::string>("pooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
@@ -54,38 +57,60 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
bool ceil_mode = ctx->Attrs().Get<bool>("ceil_mode");
bool adaptive = ctx->Attrs().Get<bool>("adaptive");
bool global_pooling = ctx->Attrs().Get<bool>("global_pooling");
std::string data_format = ctx->Attrs().Get<std::string>("data_format");
std::string padding_algorithm =
ctx->Attrs().Get<std::string>("padding_algorithm");
auto in_x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4 || in_x_dims.size() == 5, true,
"Pooling input should be 4-D or 5-D tensor.");
PADDLE_ENFORCE_EQ(in_x_dims.size() - ksize.size(), 2U,
"Input size and pooling size should be consistent.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"Strides size and pooling size should be the same.");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings if "SAME" or global_pooling
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
std::vector<int64_t> output_shape;
if (adaptive) {
output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
} else {
for (size_t i = 0; i < data_dims.size(); ++i) {
if ((!ctx->IsRuntime()) && (data_dims[i] < 0)) {
output_shape.push_back(in_x_dims[i]);
} else {
output_shape.push_back(
PoolOutputSize(data_dims[i], ksize[i], paddings[2 * i],
paddings[2 * i + 1], strides[i], ceil_mode));
}
}
}
// output_N = input_N
output_shape.insert(output_shape.begin(), in_x_dims[0]);
// output_C = input_C
if (channel_last) {
output_shape.push_back(in_x_dims[in_x_dims.size() - 1]);
} else {
output_shape.insert(output_shape.begin() + 1, in_x_dims[1]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
}
@@ -93,7 +118,9 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
// std::string data_format = ctx.Attr<std::string>("data_format"); // change:
// delete
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
@@ -114,16 +141,18 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
}
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
// std::string data_format = ctx.Attr<std::string>("data_format"); //
// change:delete
std::string data_format = "AnyLayout";
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
@@ -186,8 +215,8 @@ void Pool2dOpMaker::Make() {
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default {0,0}), paddings(height_top, height_bottom, "
"width_left, width_right) of pooling operator."
"If global_pooling = true, paddings and kernel size will be ignored.")
.SetDefault({0, 0});
AddAttr<bool>(
@@ -206,7 +235,7 @@ void Pool2dOpMaker::Make() {
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn.")
.SetDefault(false);
AddAttr<bool>(
"ceil_mode",
@@ -215,7 +244,7 @@ void Pool2dOpMaker::Make() {
"the floor function will be used.")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel.")
.SetDefault(false);
AddAttr<bool>("use_quantizer",
"(bool, default false) "
@@ -229,18 +258,24 @@ void Pool2dOpMaker::Make() {
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCHW");
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
The pooling2d operation calculates the output based on
the input, pooling_type and ksize, strides, paddings parameters.
Input(X) and output(Out) are in NCHW or NHWC format, where N is batch size, C is the
number of channels, H is the height of the feature, and W is the width of the feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
@@ -256,30 +291,47 @@ Example:
Out shape: $(N, C, H_{out}, W_{out})$
For pool_padding = "SAME":
$$
H_{out} = \\frac{(H_{in} + strides[0] - 1)}{strides[0]}
$$
$$
W_{out} = \\frac{(W_{in} + strides[1] - 1)}{strides[1]}
$$
For pool_padding = "VALID":
$$
H_{out} = \\frac{(H_{in} - ksize[0] + strides[0])}{strides[0]}
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + strides[1])}{strides[1]}
$$
For ceil_mode = false:
$$
H_{out} = \\frac{(H_{in} - ksize[0] + pad_height_top + pad_height_bottom)}{strides[0]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + pad_width_left + pad_width_right)}{strides[1]} + 1
$$
For ceil_mode = true:
$$
H_{out} = \\frac{(H_{in} - ksize[0] + pad_height_top + pad_height_bottom + strides[0] - 1)}{strides[0]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[1] + pad_width_left + pad_width_right + strides[1] - 1)}{strides[1]} + 1
$$
For exclusive = false:
$$
hstart = i * strides[0] - pad_height_top
$$
$$
hend = hstart + ksize[0]
$$
$$
wstart = j * strides[1] - pad_width_left
$$
$$
wend = wstart + ksize[1]
@@ -290,13 +342,13 @@ Example:
For exclusive = true:
$$
hstart = max(0, i * strides[0] - pad_height_top)
$$
$$
hend = min(H, hstart + ksize[0])
$$
$$
wstart = max(0, j * strides[1] - pad_width_left)
$$
$$
wend = min(W, wstart + ksize[1])
@@ -319,13 +371,14 @@ class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
void Pool3dOpMaker::Make() {
AddInput("X",
"(Tensor) The input tensor of pooling operator. "
"The format of input tensor is NCDHW or NDHWC, where N is batch "
"size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of "
"the feature, respectively.");
AddOutput("Out",
"(Tensor) The output tensor of pooling operator."
"The format of output tensor is also NCDHW or NDHWC, "
"where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively.");
@@ -355,8 +408,10 @@ void Pool3dOpMaker::Make() {
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default {0,0,0}), paddings(pad_depth_front, "
"pad_depth_back, "
"pad_height_top, pad_height_bottom, pad_width_left, pad_width_right"
") of pooling operator. "
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
@@ -376,7 +431,7 @@ void Pool3dOpMaker::Make() {
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn.")
.SetDefault(false);
AddAttr<bool>(
"ceil_mode",
@@ -389,11 +444,17 @@ void Pool3dOpMaker::Make() {
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCDHW) Only used in "
"An optional string from: \"NDHWC\", \"NCDHW\". "
"Defaults to \"NDHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCDHW");
AddAttr<std::string>(
"padding_algorithm",
"(string, default \"EXPLICIT\") An optional string from: \"EXPLICIT\","
"\"SAME\",\"VALID\". Set to \"EXPLICIT\" for explicit padding. "
"Set to \"SAME\" or \"VALID\" for algorithm of padding. ")
.SetDefault("EXPLICIT");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
@@ -401,7 +462,7 @@ Pool3d Operator.
The pooling3d operation calculates the output based on
the input, pooling_type, ksize, strides, and paddings parameters.
Input(X) and output(Out) are in NCDHW or NDHWC format, where N is batch
size, C is the number of channels, and D, H and W are the depth, height and
width of the feature, respectively. Parameters(ksize, strides, paddings)
are three elements. These three elements represent depth, height and
@@ -412,42 +473,65 @@ Example:
X shape: $(N, C, D_{in}, H_{in}, W_{in})$
Output:
Out shape: $(N, C, D_{out}, H_{out}, W_{out})$
For pool_padding = "SAME":
$$
D_{out} = \\frac{(D_{in} + strides[0] - 1)}{strides[0]}
$$
$$
H_{out} = \\frac{(H_{in} + strides[1] - 1)}{strides[1]}
$$
$$
W_{out} = \\frac{(W_{in} + strides[2] - 1)}{strides[2]}
$$
For pool_padding = "VALID":
$$
D_{out} = \\frac{(D_{in} - ksize[0] + strides[0])}{strides[0]}
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + strides[1])}{strides[1]}
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + strides[2])}{strides[2]}
$$
For ceil_mode = false:
$$
D_{out} = \\frac{(D_{in} - ksize[0] + pad_depth_front + pad_depth_back)}{strides[0]} + 1
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + pad_height_top + pad_height_bottom)}{strides[1]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + pad_width_left + pad_width_right)}{strides[2]} + 1
$$
For ceil_mode = true:
$$
D_{out} = \\frac{(D_{in} - ksize[0] + pad_depth_front + pad_depth_back + strides[0] -1)}{strides[0]} + 1
$$
$$
H_{out} = \\frac{(H_{in} - ksize[1] + pad_height_top + pad_height_bottom + strides[1] -1)}{strides[1]} + 1
$$
$$
W_{out} = \\frac{(W_{in} - ksize[2] + pad_width_left + pad_width_right + strides[2] -1)}{strides[2]} + 1
$$
For exclusive = false:
$$
dstart = i * strides[0] - pad_depth_front
$$
$$
dend = dstart + ksize[0]
$$
$$
hstart = j * strides[1] - pad_height_top
$$
$$
hend = hstart + ksize[1]
$$
$$
wstart = k * strides[2] - pad_width_left
$$
$$
wend = wstart + ksize[2]
@@ -458,16 +542,19 @@ Example:
For exclusive = true:
$$
dstart = max(0, i * strides[0] - pad_depth_front)
$$
$$
dend = min(D, dstart + ksize[0])
$$
$$
hstart = max(0, j * strides[1] - pad_height_top)
$$
$$
hend = min(H, hstart + ksize[1])
$$
$$
wstart = max(0, k * strides[2] - pad_width_left)
$$
$$
wend = min(W, wstart + ksize[2])
......
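To make the formulas in the updated operator docs concrete, here is a small worked example in plain Python that mirrors PoolOutputSize and the "SAME"/"VALID" rules above (floor division, ceil_mode = false); the helper name is ours, not part of the operator.

```python
def pool_output_size(input_size, ksize, pad_1, pad_2, stride, ceil_mode=False):
    # Mirrors PoolOutputSize above: asymmetric padding enters as pad_1 + pad_2.
    if not ceil_mode:
        return (input_size - ksize + pad_1 + pad_2) // stride + 1
    return (input_size - ksize + pad_1 + pad_2 + stride - 1) // stride + 1

H_in, ksize, stride = 7, 3, 2

# Explicit asymmetric padding: pad_top=0, pad_bottom=1.
print(pool_output_size(H_in, ksize, 0, 1, stride))           # (7-3+0+1)//2+1 = 3

# "VALID": no padding at all.
print(pool_output_size(H_in, ksize, 0, 0, stride))           # (7-3)//2+1 = 3

# "SAME": output is ceil(H_in / stride); padding is chosen to match.
H_out_same = (H_in + stride - 1) // stride                   # 4
pad_sum = max((H_out_same - 1) * stride + ksize - H_in, 0)   # 2 -> split as 1 and 1
print(H_out_same, pad_sum)                                   # 4 2
```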
@@ -14,13 +14,13 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace operators {
@@ -57,6 +57,57 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
const bool adaptive,
const std::string padding_algorithm,
const framework::DDim data_dims,
const std::vector<int>& strides,
const std::vector<int>& ksize) {
// set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims);
if (paddings->size() == data_dims.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
}
} else {
PADDLE_ENFORCE_EQ(
data_dims.size() * 2, paddings->size(),
"Paddings size should be the same or twice as the pooling size.");
}
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
int pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
}
} else if (padding_algorithm == "VALID") {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
// if global_pooling == true or adaptive == true, padding will be ignored
if (global_pooling || adaptive) {
for (auto it = paddings->begin(); it != paddings->end(); it++) {
*it = 0;
}
}
}
inline void UpdateKsize(std::vector<int>* ksize,
const framework::DDim data_dims) {
ksize->resize(static_cast<size_t>(data_dims.size()));
for (size_t i = 0; i < ksize->size(); ++i) {
*(ksize->begin() + i) = static_cast<int>(data_dims[i]);
}
}
template <typename DeviceContext, typename T>
class PoolKernel : public framework::OpKernel<T> {
@@ -69,14 +120,36 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::string data_format = context.Attr<std::string>("data_format");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
bool global_pooling = context.Attr<bool>("global_pooling");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = in_x->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>();
switch (ksize.size()) {
case 2: {
@@ -85,16 +158,16 @@ class PoolKernel : public framework::OpKernel<T> {
DeviceContext, paddle::operators::math::MaxPool<T>, T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, exclusive, adaptive, out);
}
} break;
case 3: {
@@ -103,15 +176,16 @@ class PoolKernel : public framework::OpKernel<T> {
DeviceContext, paddle::operators::math::MaxPool<T>, T>
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, true, false, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
pool_process, exclusive, adaptive, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
@@ -135,13 +209,33 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
std::string data_format = context.Attr<std::string>("data_format");
bool global_pooling = context.Attr<bool>("global_pooling");
std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// update paddings
auto in_x_dims = in_x->dims();
framework::DDim data_dims;
if (channel_last) {
data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
} else {
data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
}
UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
data_dims, strides, ksize);
if (data_dims.size() * 2 == paddings.size()) {
for (size_t i = 0; i < data_dims.size(); ++i) {
paddings.erase(paddings.begin() + i + 1);
} }
} }
if (global_pooling) {
UpdateKsize(&ksize, data_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
if (in_x_grad) { if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
...@@ -154,15 +248,15 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -154,15 +248,15 @@ class PoolGradKernel : public framework::OpKernel<T> {
paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T> paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
pool2d_backward; pool2d_backward;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, in_x_grad); paddings, data_format, in_x_grad);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor< paddle::operators::math::Pool2dGradFunctor<
DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T> DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward; pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, adaptive, paddings, data_format, pool_process, exclusive,
in_x_grad); adaptive, in_x_grad);
} }
} break; } break;
case 3: { case 3: {
...@@ -170,15 +264,15 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -170,15 +264,15 @@ class PoolGradKernel : public framework::OpKernel<T> {
paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T> paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
pool3d_backward; pool3d_backward;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, in_x_grad); paddings, data_format, in_x_grad);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor< paddle::operators::math::Pool3dGradFunctor<
DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T> DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward; pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides, pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, exclusive, adaptive, paddings, data_format, pool_process, exclusive,
in_x_grad); adaptive, in_x_grad);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
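
Both the forward and the backward kernels now pass Attr(data_format) through to the math functors, slice out the spatial dims according to channel_last, and normalize paddings/ksize via UpdatePadding/UpdateKsize before pooling. The sketch below is a rough Python illustration of that normalization convention only, not the C++ helpers themselves: the function name and signature are invented for the example, and it assumes the usual "SAME"/"VALID" rules plus the new asymmetric `[begin, end]`-per-axis form.

.. code-block:: python

    # Hypothetical sketch of the padding normalization the updated kernels rely on
    # (UpdatePadding/UpdateKsize are C++ helpers in pool_op.h; this only
    # illustrates the convention and is not taken from the source).
    import math

    def update_padding_sketch(paddings, ksize, strides, data_dims,
                              padding_algorithm="EXPLICIT", global_pooling=False):
        """Return one (pad_begin, pad_end) pair per spatial dimension."""
        if padding_algorithm == "SAME":
            pads = []
            for dim, k, s in zip(data_dims, ksize, strides):
                out = math.ceil(dim / s)                   # SAME keeps out = ceil(in / stride)
                pad_sum = max((out - 1) * s + k - dim, 0)  # total padding needed on this axis
                pads.append((pad_sum // 2, pad_sum - pad_sum // 2))
            return pads
        if padding_algorithm == "VALID" or global_pooling:
            return [(0, 0)] * len(data_dims)
        if len(paddings) == len(data_dims):                # symmetric form, e.g. [p_h, p_w]
            return [(p, p) for p in paddings]
        it = iter(paddings)                                # asymmetric form [top, bottom, left, right]
        return list(zip(it, it))

    # 32x32 input, 3x3 kernel, stride 2: "SAME" pads so the output stays 16x16.
    print(update_padding_sketch([], [3, 3], [2, 2], [32, 32], "SAME"))  # [(0, 1), (0, 1)]
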
...@@ -72,6 +72,7 @@ enum class DataLayout { // Not use ...@@ -72,6 +72,7 @@ enum class DataLayout { // Not use
kNHWC, kNHWC,
kNCHW, kNCHW,
kNCDHW, kNCDHW,
kNDHWC, // add, liyamei
kNCHW_VECT_C, kNCHW_VECT_C,
}; };
...@@ -212,6 +213,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat( ...@@ -212,6 +213,8 @@ inline cudnnTensorFormat_t GetCudnnTensorFormat(
return CUDNN_TENSOR_NCHW; return CUDNN_TENSOR_NCHW;
case DataLayout::kNCDHW: case DataLayout::kNCDHW:
return CUDNN_TENSOR_NCHW; // NOTE: cudnn treat NdTensor as the same return CUDNN_TENSOR_NCHW; // NOTE: cudnn treat NdTensor as the same
case DataLayout::kNDHWC:
return CUDNN_TENSOR_NHWC; // add, liyamei
default: default:
PADDLE_THROW("Unknown cudnn equivalent for order"); PADDLE_THROW("Unknown cudnn equivalent for order");
} }
...@@ -238,14 +241,31 @@ class ScopedTensorDescriptor { ...@@ -238,14 +241,31 @@ class ScopedTensorDescriptor {
strides[i] = dims[i + 1] * strides[i + 1]; strides[i] = dims[i + 1] * strides[i + 1];
} }
// Update tensor descriptor dims setting if groups > 1 // Update tensor descriptor dims setting if groups > 1
        // NOTE: Assume using NCHW or NCDHW order        // NOTE: here we assume NCHW or NCDHW order
std::vector<int> dims_with_group(dims.begin(), dims.end()); // copy std::vector<int> dims_with_group(dims.begin(), dims.end());
if (groups > 1) { if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups; dims_with_group[1] = dims_with_group[1] / groups;
} }
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(), if (dims.size() == 4) {
strides.data())); if (format == CUDNN_TENSOR_NCHW) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(),
strides.data()));
} else { // CUDNN_TENSOR_NHWC
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensor4dDescriptor(
desc_, format, type, dims[0], dims[3], dims[1], dims[2]));
}
} else if (dims.size() == 5) {
if (format == CUDNN_TENSOR_NCHW) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptor(
desc_, type, dims_with_group.size(), dims_with_group.data(),
strides.data()));
} else { // CUDNN_TENSOR_NHWC
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnSetTensorNdDescriptorEx(
desc_, format, type, dims.size(), dims.data()));
}
}
return desc_; return desc_;
} }
......
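
For channel-last layouts the packed strides no longer follow the NCHW pattern that the old single cudnnSetTensorNdDescriptor call assumed, which is why the descriptor setup above branches on the cuDNN tensor format (cudnnSetTensor4dDescriptor for 4-D NHWC, cudnnSetTensorNdDescriptorEx for 5-D NDHWC). A small, Paddle-independent Python illustration of the stride difference:

.. code-block:: python

    # Illustration (not Paddle code): packed row-major strides for given dims.
    # For NCHW [N, C, H, W] the stride-1 axis is W; for NHWC [N, H, W, C] it is C,
    # so the descriptor has to carry the layout instead of NCHW-style strides.
    def packed_strides(dims):
        strides = [1] * len(dims)
        for i in reversed(range(len(dims) - 1)):
            strides[i] = strides[i + 1] * dims[i + 1]
        return strides

    print(packed_strides([8, 3, 32, 32]))   # NCHW: [3072, 1024, 32, 1]
    print(packed_strides([8, 32, 32, 3]))   # NHWC: [3072, 96, 3, 1]
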
...@@ -126,7 +126,8 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -126,7 +126,8 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnRNNBackwardWeights); \ __macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \ __macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \ __macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); __macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
......
...@@ -2871,15 +2871,16 @@ def pool2d(input, ...@@ -2871,15 +2871,16 @@ def pool2d(input,
use_cudnn=True, use_cudnn=True,
ceil_mode=False, ceil_mode=False,
name=None, name=None,
exclusive=True): exclusive=True,
data_format="NCHW"):
""" """
${comment} ${comment}
Args: Args:
input (Variable): The input tensor of pooling operator. The format of input (Variable): The input tensor of pooling operator. The format of
input tensor is NCHW, where N is batch size, C is input tensor is `"NCHW"` or `"NHWC"`, where `N` is batch size, `C` is
the number of channels, H is the height of the the number of channels, `H` is the height of the
feature, and W is the width of the feature. feature, and `W` is the width of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (pool_size_Height, pool_size_Width). it must contain two integers, (pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be a square of an int. Otherwise, the pool kernel size will be a square of an int.
...@@ -2887,8 +2888,13 @@ def pool2d(input, ...@@ -2887,8 +2888,13 @@ def pool2d(input,
pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
it must contain two integers, (pool_stride_Height, pool_stride_Width). it must contain two integers, (pool_stride_Height, pool_stride_Width).
Otherwise, the pool stride size will be a square of an int. Otherwise, the pool stride size will be a square of an int.
pool_padding (int|list|tuple): The pool padding size. If pool padding size is a tuple, pool_padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or
it must contain two integers, (pool_padding_on_Height, pool_padding_on_Width). 'SAME' which is the padding algorithm. If pool padding size is a tuple or list,
it could be in three forms: `[pad_height, pad_width]` or
`[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`,
`pool_padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
            When `data_format` is `"NHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
Otherwise, the pool padding size will be a square of an int. Otherwise, the pool padding size will be a square of an int.
global_pooling (bool): ${global_pooling_comment} global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment} use_cudnn (bool): ${use_cudnn_comment}
...@@ -2896,55 +2902,125 @@ def pool2d(input, ...@@ -2896,55 +2902,125 @@ def pool2d(input,
name (str|None): A name for this layer(optional). If set None, the name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically. layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling exclusive (bool): Whether to exclude padding points in average pooling
                              mode, default is true                              mode, default is `True`.
        data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`.
Returns: Returns:
Variable: The pooling result. Variable: The pooling result.
Raises: Raises:
        ValueError: If 'pool_type' is not "max" nor "avg"        ValueError: If `pool_type` is neither "max" nor "avg".
ValueError: If 'global_pooling' is False and 'pool_size' is -1 ValueError: If `global_pooling` is False and `pool_size` is -1
ValueError: If 'use_cudnn' is not a bool value. ValueError: If `use_cudnn` is not a bool value.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32') name='data', shape=[10, 3, 32, 32], append_batch_size=False, dtype='float32')
pool2d = fluid.layers.pool2d(
input=data, # example 1:
pool_size=2, # Attr(pool_padding) is a list with 4 elements, Attr(data_format) is "NCHW".
pool_type='max', out_1 = fluid.layers.pool2d(
pool_stride=1, input = data,
global_pooling=False) pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = [1, 2, 1, 0],
data_format = "NCHW")
# example 2:
# Attr(pool_padding) is a string, Attr(data_format) is "NCHW".
out_2 = fluid.layers.pool2d(
input = data,
pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = "VALID",
data_format = "NCHW")
""" """
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown Attr(pool_type): '%s'. It can only be 'max' or 'avg'.",
str(pool_type)) str(pool_type))
if global_pooling is False and pool_size == -1: if global_pooling is False and pool_size == -1:
raise ValueError( raise ValueError(
"When the global_pooling is False, pool_size must be passed " "When Attr(global_pooling) is False, Attr(pool_size) must be passed "
"and be a valid value. Received pool_size: " + str(pool_size)) "and be a valid value. Received pool_size: %s." % str(pool_size))
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s." % str(use_cudnn))
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
pool_size = utils.convert_to_list(pool_size, 2, 'pool_size') pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 2, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride') pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
if not isinstance(use_cudnn, bool): def update_padding(padding, data_format):
raise ValueError("use_cudnn should be True or False") def is_list_or_tuple(ele):
if isinstance(ele, list) or isinstance(ele, tuple):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 4:
if is_list_or_tuple(padding[0]) and (data_format == "NCHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:4]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"):
if not (padding[0] == [0, 0] and padding[3] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:3]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 4, 'padding')
l_type = 'pool2d' else:
padding = utils.convert_to_list(padding, 2, 'padding')
helper = LayerHelper(l_type, **locals()) return padding
padding_algorithm = "EXPLICIT"
if isinstance(pool_padding, str):
pool_padding = pool_padding.upper()
if pool_padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'."
% str(pool_padding))
if pool_padding == "VALID":
padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0]
if ceil_mode != False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. "
"Received ceil_mode: True.")
elif pool_padding == "SAME":
padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0]
pool_padding = update_padding(pool_padding, data_format)
op_type = 'pool2d'
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=op_type,
inputs={"X": input}, inputs={"X": input},
outputs={"Out": pool_out}, outputs={"Out": pool_out},
attrs={ attrs={
...@@ -2953,10 +3029,12 @@ def pool2d(input, ...@@ -2953,10 +3029,12 @@ def pool2d(input,
"global_pooling": global_pooling, "global_pooling": global_pooling,
"strides": pool_stride, "strides": pool_stride,
"paddings": pool_padding, "paddings": pool_padding,
"padding_algorithm": padding_algorithm,
"use_cudnn": use_cudnn, "use_cudnn": use_cudnn,
"ceil_mode": ceil_mode, "ceil_mode": ceil_mode,
"use_mkldnn": False, "use_mkldnn": False,
"exclusive": exclusive, "exclusive": exclusive,
"data_format": data_format,
}) })
return pool_out return pool_out
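
With the new padding options, the two docstring examples above produce different spatial sizes. Assuming the standard pooling output-size formula with ceil_mode=False (stated here as an illustration, not quoted from the kernel source), out = floor((in + pad_begin + pad_end - ksize) / stride) + 1, so out_1 would have shape [10, 3, 33, 31] and out_2 (with "VALID", i.e. zero padding) shape [10, 3, 30, 30]. A quick check:

.. code-block:: python

    # Worked check under the standard formula (assumption, for illustration only):
    # input 32x32, kernel 3, stride 1, as in the pool2d docstring examples.
    import math

    def pool_out_size(in_size, k, s, pad_begin, pad_end, ceil_mode=False):
        size = (in_size + pad_begin + pad_end - k) / s + 1
        return math.ceil(size) if ceil_mode else math.floor(size)

    # example 1: pool_padding = [1, 2, 1, 0]  ->  H uses (1, 2), W uses (1, 0)
    print(pool_out_size(32, 3, 1, 1, 2), pool_out_size(32, 3, 1, 1, 0))  # 33 31
    # example 2: pool_padding = "VALID" is equivalent to zero padding here
    print(pool_out_size(32, 3, 1, 0, 0))                                  # 30
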
...@@ -2972,30 +3050,43 @@ def pool3d(input, ...@@ -2972,30 +3050,43 @@ def pool3d(input,
use_cudnn=True, use_cudnn=True,
ceil_mode=False, ceil_mode=False,
name=None, name=None,
exclusive=True): exclusive=True,
data_format="NCDHW"):
""" """
${comment} ${comment}
Args: Args:
input (Variable): The input tensor of pooling operator. The format of input (Variable): The input tensor of pooling operator. The format of
input tensor is NCDHW, where N is batch size, C is input tensor is `"NCDHW"` or `"NDHWC"`, where `N` is batch size, `C` is
the number of channels, D is the depth of the feature, the number of channels, `D` is the depth of the feature,
H is the height of the feature, and W is the width `H` is the height of the feature, and `W` is the width
of the feature. of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size pool_size (int|list|tuple): The pool kernel size. If pool kernel size
is a tuple or list, it must contain three integers, is a tuple or list, it must contain three integers,
(pool_size_Depth, pool_size_Height, pool_size_Width). (pool_size_Depth, pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be the cube of an int. Otherwise, the pool kernel size will be the cube of an int.
pool_type (string): ${pooling_type_comment} pool_type (string): ${pooling_type_comment}
        pool_stride (int): stride of the pooling layer.        pool_stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list,
        pool_padding (int): padding size.        it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`.
            Otherwise, the pool stride size will be a cube of an int.
        pool_padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or
            'SAME' which is the padding algorithm. If pool padding size is a tuple or list,
it could be in three forms: `[pad_depth, pad_height, pad_width]` or
`[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`,
and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form
`[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`.
            When `data_format` is `"NDHWC"`, `pool_padding` can be in the form
`[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`.
global_pooling (bool): ${global_pooling_comment} global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment} use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment} ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer(optional). If set None, the layer name (str): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling exclusive (bool): Whether to exclude padding points in average pooling
                              mode, default is true                              mode, default is `True`.
data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`.
Returns: Returns:
Variable: output of pool3d layer. Variable: output of pool3d layer.
...@@ -3005,39 +3096,114 @@ def pool3d(input, ...@@ -3005,39 +3096,114 @@ def pool3d(input,
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[3, 32, 32, 32], dtype='float32') name='data', shape=[10, 3, 32, 32, 32], append_batch_size=False, dtype='float32')
pool3d = fluid.layers.pool3d(
input=data, # example 1:
pool_size=2, # Attr(pool_padding) is a list with 6 elements, Attr(data_format) is "NCDHW".
pool_type='max', out_1 = fluid.layers.pool3d(
pool_stride=1, input = data,
global_pooling=False) pool_size = 2,
pool_type = "avg",
pool_stride = 1,
pool_padding = [1, 2, 1, 0, 1, 2],
global_pooling = False,
data_format = "NCDHW")
# example 2:
# Attr(pool_padding) is a string, Attr(data_format) is "NCDHW".
out_2 = fluid.layers.pool3d(
input = data,
pool_size = 3,
pool_type = "avg",
pool_stride = 1,
pool_padding = "VALID",
global_pooling = False,
data_format = "NCDHW")
""" """
if pool_type not in ["max", "avg"]: if pool_type not in ["max", "avg"]:
raise ValueError( raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.", "Unknown Attr(pool_type): '%s'. It can only be 'max' or 'avg'.",
str(pool_type)) str(pool_type))
if global_pooling is False and pool_size == -1: if global_pooling is False and pool_size == -1:
raise ValueError( raise ValueError(
"When the global_pooling is False, pool_size must be passed " "When Attr(global_pooling) is False, Attr(pool_size) must be passed "
"and be a valid value. Received pool_size: " + str(pool_size)) "and be a valid value. Received Attr(pool_size): %s." %
str(pool_size))
if not isinstance(use_cudnn, bool):
raise ValueError("Attr(use_cudnn) should be True or False. Received "
"Attr(use_cudnn): %s. " % str(use_cudnn))
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
"Attr(data_format): %s" % str(data_format))
pool_size = utils.convert_to_list(pool_size, 3, 'pool_size') pool_size = utils.convert_to_list(pool_size, 3, 'pool_size')
pool_padding = utils.convert_to_list(pool_padding, 3, 'pool_padding')
pool_stride = utils.convert_to_list(pool_stride, 3, 'pool_stride') pool_stride = utils.convert_to_list(pool_stride, 3, 'pool_stride')
if not isinstance(use_cudnn, bool): def update_padding(padding, data_format):
raise ValueError("use_cudnn should be True or False") def is_list_or_tuple(ele):
if isinstance(ele, (list, tuple)):
return True
return False
if is_list_or_tuple(padding) and len(padding) == 5:
if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"):
if not (padding[0] == [0, 0] and padding[1] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[2:5]
padding = [ele for a_list in padding for ele in a_list]
elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"):
if not (padding[0] == [0, 0] and padding[4] == [0, 0]):
raise ValueError(
"Non-zero pool_padding(%s) in the batch or channel dimensions "
"is not supported." % str(padding))
padding = padding[1:4]
padding = [ele for a_list in padding for ele in a_list]
padding = utils.convert_to_list(padding, 6, 'padding')
elif is_list_or_tuple(padding) and len(padding) == 6:
padding = utils.convert_to_list(padding, 6, 'padding')
l_type = "pool3d" else:
helper = LayerHelper(l_type, **locals()) padding = utils.convert_to_list(padding, 3, 'padding')
return padding
padding_algorithm = "EXPLICIT"
if isinstance(pool_padding, str):
pool_padding = pool_padding.upper()
if pool_padding not in ["SAME", "VALID"]:
raise ValueError(
"Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'."
% str(pool_padding))
if pool_padding == "VALID":
padding_algorithm = "VALID"
pool_padding = [0, 0, 0, 0, 0, 0]
if ceil_mode != False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", ceil_mode must be False. "
"Received ceil_mode: True.")
elif pool_padding == "SAME":
padding_algorithm = "SAME"
pool_padding = [0, 0, 0, 0, 0, 0]
pool_padding = update_padding(pool_padding, data_format)
op_type = "pool3d"
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=op_type,
inputs={"X": input}, inputs={"X": input},
outputs={"Out": pool_out}, outputs={"Out": pool_out},
attrs={ attrs={
...@@ -3046,10 +3212,12 @@ def pool3d(input, ...@@ -3046,10 +3212,12 @@ def pool3d(input,
"global_pooling": global_pooling, "global_pooling": global_pooling,
"strides": pool_stride, "strides": pool_stride,
"paddings": pool_padding, "paddings": pool_padding,
"padding_algorithm": padding_algorithm,
"use_cudnn": use_cudnn, "use_cudnn": use_cudnn,
"ceil_mode": ceil_mode, "ceil_mode": ceil_mode,
"use_mkldnn": False, "use_mkldnn": False,
"exclusive": exclusive, "exclusive": exclusive,
"data_format": data_format,
}) })
return pool_out return pool_out
......
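
The nested `pool_padding` form is validated and then flattened to the six-element `[front, back, top, bottom, left, right]` list that the operator consumes; the batch and channel pairs must stay `[0, 0]` or the layer raises ValueError, as in the code above. A standalone sketch of that collapse (an illustrative re-implementation, not the fluid helper itself):

.. code-block:: python

    # Illustrative, standalone version of the nested-padding collapse performed by
    # update_padding in pool3d above (not imported from fluid).
    def collapse_pool3d_padding(padding, data_format="NCDHW"):
        # nested [[begin, end], ...] form with one pair per axis; keep spatial pairs only
        spatial = padding[2:5] if data_format == "NCDHW" else padding[1:4]
        return [p for pair in spatial for p in pair]

    # NCDHW: batch and channel pairs are [0, 0]; the three spatial pairs remain.
    pads = [[0, 0], [0, 0], [1, 1], [2, 2], [0, 1]]
    print(collapse_pool3d_padding(pads))   # [1, 1, 2, 2, 0, 1]
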