Commit 5a957929 authored by chengduo, committed by GitHub

Merge pull request #4636 from chengduoZH/Add_pool_cudnn_op

Add pool2d cudnn op
...@@ -195,6 +195,14 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
   return result;
 }
 
+// NOTE: framework::vectorize converts to type int64_t
+// which does not fit cudnn inputs.
+std::vector<int> vectorize2int(const DDim& ddim) {
+  std::vector<int64_t> temp = vectorize(ddim);
+  std::vector<int> result(temp.begin(), temp.end());
+  return result;
+}
+
 struct ProductVisitor : public boost::static_visitor<int64_t> {
   template <int D>
   int64_t operator()(const Dim<D>& dim) {
...
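A minimal usage sketch of the new helper (shape values hypothetical; make_ddim is assumed from paddle/framework/ddim.h):

// cuDNN descriptor APIs take int dimensions, so the int64_t-based DDim
// is narrowed to int first.
framework::DDim dims = framework::make_ddim({2, 3, 5, 5});   // e.g. NCHW
std::vector<int> int_dims = framework::vectorize2int(dims);  // {2, 3, 5, 5}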
...@@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx);
 void set(DDim& dim, int idx, int val);
 std::vector<int64_t> vectorize(const DDim& ddim);
+std::vector<int> vectorize2int(const DDim& ddim);
 int64_t product(const DDim& ddim);
...
...@@ -69,6 +69,13 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
   endif()
 
+  # pool_cudnn_op contains several operators
+  if ("${TARGET}" STREQUAL "pool_cudnn_op")
+    set(pybind_flag 1)
+    # It's enough to just add one operator to pybind.
+    file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
+  endif()
+
   # save_restore_op contains several operators
   if ("${TARGET}" STREQUAL "save_restore_op")
     set(pybind_flag 1)
...
...@@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext;
 static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
 
-// NOTE: framework::vectorize converts to type int64_t
-// which does not fit cudnn inputs.
-std::vector<int> Dims2Vector(const framework::DDim& dims) {
-  std::vector<int> ret;
-  for (int i = 0; i < dims.size(); i++) {
-    ret.push_back(dims[i]);
-  }
-  return ret;
-}
-
 template <typename T>
 class CudnnConvOpKernel : public framework::OpKernel<T> {
  public:
...@@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_desc =
-        output_desc.descriptor<T>(layout, Dims2Vector(output->dims()), groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc =
-        filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output->dims()), groups);
+    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize2int(filter->dims()), groups);
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
         conv_desc.descriptor<T>(paddings, strides, dilations);
...@@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_grad_desc =
-        output_grad_desc.descriptor<T>(layout, Dims2Vector(output_grad->dims()),
-                                       groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc =
-        filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_output_grad_desc =
+        output_grad_desc.descriptor<T>(
+            layout, framework::vectorize2int(output_grad->dims()), groups);
+    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize2int(filter->dims()), groups);
     cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr;
     cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr;
...@@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     auto handle = ctx.cuda_device_context().cudnn_handle();
     if (input_grad) {
       cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
-          layout, Dims2Vector(input_grad->dims()), groups);
+          layout, framework::vectorize2int(input_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
               handle, cudnn_filter_desc,
...@@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     if (filter_grad) {
       cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>(
-          layout, Dims2Vector(filter_grad->dims()), groups);
+          layout, framework::vectorize2int(filter_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
               handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
...
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_cudnn_op.h"
namespace ops = paddle::operators;
REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_cudnn_op.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode;
template <typename T>
class PoolCudnnOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>(ctx.GetPlace());
std::string pooling_type = ctx.Attr<std::string>("poolingType");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(input->dims()[i + 2]);
}
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
pooling_mode = PoolingMode::kMaximum;
} else {
pooling_mode = PoolingMode::kAverage;
}
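// Note (an assumption, not stated in this patch): PoolingMode::kMaximum and
// kAverage are platform-side enums that ScopedPoolingDescriptor presumably
// maps to the corresponding cudnnPoolingMode_t when building the descriptor.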
cudnnPoolingDescriptor_t cudnn_pool_desc =
pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
T alpha = 1.0f, beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
cudnn_output_desc, output_data));
}
};
template <typename T>
class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
const Tensor *output = ctx.Input<Tensor>("Out");
const Tensor *output_grad =
ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor *input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
std::string pooling_type = ctx.Attr<std::string>("poolingType");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
if (ctx.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(input->dims()[i + 2]);
}
const T *input_data = input->data<T>();
const T *output_data = output->data<T>();
const T *output_grad_data = output_grad->data<T>();
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize2int(output->dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
pooling_mode = PoolingMode::kMaximum;
} else {
pooling_mode = PoolingMode::kAverage;
}
cudnnPoolingDescriptor_t cudnn_pool_desc =
pool_desc.descriptor(pooling_mode, ksize, paddings, strides);
// ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle();
T alpha = 1.0f, beta = 0.0f;
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
&beta, cudnn_input_desc, input_grad_data));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>);
REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/pool_op.h"
namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
...@@ -29,7 +29,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
   auto in_x_dims = ctx->GetInputDim("X");
 
-  std::string pooling_type = ctx->Attrs().Get<std::string>("pooling_type");
+  std::string pooling_type = ctx->Attrs().Get<std::string>("poolingType");
   std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
...@@ -37,7 +37,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
   PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
                  "Pooling input should be 4-D or 5-D tensor.");
 
-  if (ctx->Attrs().Get<bool>("global_pooling")) {
+  if (ctx->Attrs().Get<bool>("globalPooling")) {
     ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
     for (size_t i = 0; i < ksize.size(); ++i)
       ksize[i] = static_cast<int>(in_x_dims[i + 2]);
...@@ -80,32 +80,28 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
            "the number of channels, H and W is the height and "
            "width of feature.");
 
-  AddAttr<std::string>("pooling_type",
-                       "Pooling_type of pooling operator."
-                       "Str constant equal to 'max' or 'avg'.")
+  AddAttr<std::string>("poolingType",
+                       "(string), pooling type, can be \"max\" for max-pooling "
+                       "and \"avg\" for average-pooling.")
       .InEnum({"max", "avg"});
   AddAttr<std::vector<int>>(
       "ksize",
-      "The pooling window size(height, width) of pooling operator."
-      "If global_pooling = true, ksize is ignored and need not be "
+      "(vector), the pooling window size(height, width) of pooling operator."
+      "If globalPooling = true, ksize is ignored and need not be "
       "specified.");  // TODO(Chengduo): Add checker. (Currently,
                       // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "global_pooling",
-      "Whether to use the global_pooling."
-      "Bool constant equal to false or true."
-      "Default false."
-      "If global_pooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool, default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
       .SetDefault(false);
-  AddAttr<std::vector<int>>("strides",
-                            "The strides(height, width) of pooling window."
-                            "Default {1,1}.")
+  AddAttr<std::vector<int>>(
+      "strides",
+      "(vector, default: {1, 1}), strides(height, width) of pooling operator.")
      .SetDefault({1, 1});  // TODO(Chengduo): Add checker. (Currently,
                            // TypedAttrChecker don't support vector type.)
-  AddAttr<std::vector<int>>("paddings",
-                            "The zero padding(height, width) size on both sides"
-                            "Default {0,0}.")
+  AddAttr<std::vector<int>>(
+      "paddings",
+      "(vector, default: {0, 0}), paddings(height, width) of pooling operator.")
      .SetDefault({0, 0});  // TODO(Chengduo): Add checker. (Currently,
                            // TypedAttrChecker don't support vector type.)
...@@ -123,7 +119,6 @@ Example:
       X shape: (N, C, H_in, W_in)
   Output:
       Out shape: (N, C, H_out, W_out)
-      Mask shape: (N, C, H_out, W_out)
   where
       H_out = (H_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
      W_out = (W_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
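As a quick sanity check of the formulas above (hypothetical sizes, integer division as in the C++ code):

// Assume H_in = W_in = 7, ksize = {3, 3}, paddings = {1, 1}, strides = {1, 1}:
int H_out = (7 - 3 + 2 * 1) / 1 + 1;  // = 7
int W_out = (7 - 3 + 2 * 1) / 1 + 1;  // = 7: 3x3 pooling with padding 1 and
                                      // stride 1 preserves the spatial size.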
...@@ -146,33 +141,29 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
            "the number of channels, D, H and W is the depth, height and "
            "width of feature.");
 
-  AddAttr<std::string>("pooling_type",
-                       "PoolingType of pooling operator."
-                       "Str constant equal to 'max' or 'avg'.")
+  AddAttr<std::string>("poolingType",
+                       "(string), pooling type, can be \"max\" for max-pooling "
+                       "and \"avg\" for average-pooling.")
      .InEnum({"max", "avg"});
   AddAttr<std::vector<int>>(
       "ksize",
-      "The pooling window size(depth, height, width) of pooling operator."
-      "If global_pooling = true, ksize is ignored and need not be "
+      "(vector), the pooling window size(depth, height, width) of pooling "
+      "operator."
+      "If globalPooling = true, ksize is ignored and need not be "
      "specified.");  // TODO(Chengduo): Add checker. (Currently,
                      // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "global_pooling",
-      "Whether to use the global_pooling."
-      "Bool constant equal to false or true."
-      "Default false."
-      "If global_pooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool, default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
      .SetDefault(false);
   AddAttr<std::vector<int>>("strides",
-                            "Strides(depth, height, width) of pooling operator."
-                            "Default {1,1,1}.")
+                            "(vector, default: {1,1,1}), strides(depth, height, "
+                            "width) of pooling operator.")
      .SetDefault({1, 1, 1});  // TODO(Chengduo): Add checker. (Currently,
                               // TypedAttrChecker don't support vector type.)
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "Paddings(depth, height, width) of pooling operator."
-      "Default {0,0,0}.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector, default: {0,0,0}), paddings(depth, height, "
+                            "width) of pooling operator.")
      .SetDefault({0, 0, 0});  // TODO(Chengduo): Add checker. (Currently,
                               // TypedAttrChecker don't support vector type.)
...@@ -190,7 +181,6 @@ Example:
       X shape: (N, C, D_in, H_in, W_in)
   Output:
       Out shape: (N, C, D_out, H_out, W_out)
-      Mask shape: (N, C, D_out, H_out, W_out)
   where
       D_out = (D_in - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
      H_out = (H_in - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
...
...@@ -57,11 +57,11 @@ class PoolKernel : public framework::OpKernel<T> {
     const Tensor* in_x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");
 
-    std::string pooling_type = context.Attr<std::string>("pooling_type");
+    std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (context.Attr<bool>("global_pooling")) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
...@@ -117,12 +117,12 @@ class PoolGradKernel : public framework::OpKernel<T> {
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
 
-    std::string pooling_type = context.Attr<std::string>("pooling_type");
+    std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (context.Attr<bool>("global_pooling")) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i)
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
     }
...
...@@ -44,7 +44,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
   PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
                  "Pooling input should be 4-D or 5-D tensor.");
 
-  if (ctx->Attrs().Get<bool>("global_pooling")) {
+  if (ctx->Attrs().Get<bool>("globalPooling")) {
     ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
     for (size_t i = 0; i < ksize.size(); ++i)
       ksize[i] = static_cast<int>(in_x_dims[i + 2]);
...@@ -105,26 +105,22 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
   AddAttr<std::vector<int>>(
       "ksize",
-      "The pooling window size(height, width) of pooling operator."
-      "If global_pooling = true, ksize is ignored and need not be "
+      "(vector), the pooling window size(height, width) of pooling operator."
+      "If globalPooling = true, ksize is ignored and need not be "
      "specified.");  // TODO(Chengduo): Add checker. (Currently,
                      // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "global_pooling",
-      "Whether to use the global_pooling."
-      "Bool constant equal to false or true."
-      "Default false."
-      "If global_pooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool, default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
      .SetDefault(false);
-  AddAttr<std::vector<int>>("strides",
-                            "The strides(height, width) of pooling window."
-                            "Default {1,1}.")
+  AddAttr<std::vector<int>>(
+      "strides",
+      "(vector, default: {1, 1}), strides(height, width) of pooling operator.")
      .SetDefault({1, 1});  // TODO(Chengduo): Add checker. (Currently,
                            // TypedAttrChecker don't support vector type.)
   AddAttr<std::vector<int>>(
       "paddings",
-      "The zero padding(height, width) size on both sides"
-      "Default {0,0}.")
+      "(vector, default: {0, 0}), paddings(height, width) of pooling operator.")
      .SetDefault({0, 0});  // TODO(Chengduo): Add checker. (Currently,
                            // TypedAttrChecker don't support vector type.)
...@@ -176,27 +172,23 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
   AddAttr<std::vector<int>>(
       "ksize",
-      "The pooling window size(depth, height, width) of pooling operator."
-      "If global_pooling = true, ksize is ignored and need not be "
+      "(vector), the pooling window size(depth, height, width) of pooling "
+      "operator."
+      "If globalPooling = true, ksize is ignored and need not be "
      "specified.");  // TODO(Chengduo): Add checker. (Currently,
                      // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "global_pooling",
-      "Whether to use the global_pooling."
-      "Bool constant equal to false or true."
-      "Default false."
-      "If global_pooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool, default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
      .SetDefault(false);
-  AddAttr<std::vector<int>>(
-      "strides",
-      "Strides(depth, height, width) of pooling operator."
-      "Default {1,1,1}.")
+  AddAttr<std::vector<int>>("strides",
+                            "(vector, default: {1,1,1}), strides(depth, "
+                            "height, width) of pooling operator.")
      .SetDefault({1, 1, 1});  // TODO(Chengduo): Add checker. (Currently,
                               // TypedAttrChecker don't support vector type.)
-  AddAttr<std::vector<int>>(
-      "paddings",
-      "Paddings(depth, height, width) of pooling operator."
-      "Default {0,0,0}.")
+  AddAttr<std::vector<int>>("paddings",
+                            "(vector, default: {0,0,0}), paddings(depth, "
+                            "height, width) of pooling operator.")
      .SetDefault({0, 0, 0});  // TODO(Chengduo): Add checker. (Currently,
                               // TypedAttrChecker don't support vector type.)
...
...@@ -35,7 +35,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (context.Attr<bool>("global_pooling")) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
...@@ -70,7 +70,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (context.Attr<bool>("global_pooling")) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x_grad->dims()[i + 2]);
       }
...
...@@ -284,9 +284,9 @@ def pool2d(input,
         inputs={"X": input},
         outputs={"Out": pool_out},
         attrs={
-            "pooling_type": pool_type,
+            "poolingType": pool_type,
             "ksize": pool_size,
-            "global_pooling": global_pooling,
+            "globalPooling": global_pooling,
             "strides": pool_stride,
             "paddings": pool_padding
         })
...
...@@ -46,7 +46,9 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
 class TestPool2d_Op(OpTest):
     def setUp(self):
-        self.initTestCase()
+        self.init_test_case()
+        self.init_op_type()
+        self.init_pool_type()
         input = np.random.random(self.shape).astype("float32")
         output = self.pool2D_forward_naive(input, self.ksize, self.strides,
                                            self.paddings, self.global_pool)
...@@ -56,8 +58,8 @@ class TestPool2d_Op(OpTest):
             'strides': self.strides,
             'paddings': self.paddings,
             'ksize': self.ksize,
-            'pooling_type': self.pool_type,
-            'global_pooling': self.global_pool,
+            'poolingType': self.pool_type,
+            'globalPooling': self.global_pool,
         }
         self.outputs = {'Out': output.astype('float32')}
...@@ -69,76 +71,197 @@ class TestPool2d_Op(OpTest):
         if self.pool_type != "max":
             self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
 
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = True
-        self.op_type = "pool2d"
-        self.pool_type = "avg"
         self.pool2D_forward_naive = avg_pool2D_forward_naive
         self.shape = [2, 3, 5, 5]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [0, 0]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
 
 class TestCase1(TestPool2d_Op):
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = False
-        self.op_type = "pool2d"
-        self.pool_type = "avg"
         self.pool2D_forward_naive = avg_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [0, 0]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
 
 class TestCase2(TestPool2d_Op):
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = False
-        self.op_type = "pool2d"
-        self.pool_type = "avg"
         self.pool2D_forward_naive = avg_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [1, 1]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
 
 class TestCase3(TestPool2d_Op):
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = True
-        self.op_type = "pool2d"
-        self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 5, 5]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [0, 0]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
 
 class TestCase4(TestPool2d_Op):
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = False
-        self.op_type = "pool2d"
-        self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [0, 0]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
 
 class TestCase5(TestPool2d_Op):
-    def initTestCase(self):
+    def init_test_case(self):
         self.global_pool = False
-        self.op_type = "pool2d"
-        self.pool_type = "max"
         self.pool2D_forward_naive = max_pool2D_forward_naive
         self.shape = [2, 3, 7, 7]
         self.ksize = [3, 3]
         self.strides = [1, 1]
         self.paddings = [1, 1]
 
+    def init_op_type(self):
+        self.op_type = "pool2d"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
+
+#--------------------test pool2d_cudnn--------------------
+class TestCaseCudnn1(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = True
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+        self.shape = [2, 3, 5, 5]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
+
+class TestCaseCudnn2(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = False
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
+
+class TestCaseCudnn3(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = False
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "avg"
+
+
+class TestCaseCudnn4(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = True
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 5, 5]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
+
+class TestCaseCudnn5(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = False
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [0, 0]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
+
+class TestCaseCudnn6(TestPool2d_Op):
+    def init_test_case(self):
+        self.global_pool = False
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+        self.shape = [2, 3, 7, 7]
+        self.ksize = [3, 3]
+        self.strides = [1, 1]
+        self.paddings = [1, 1]
+
+    def init_op_type(self):
+        self.op_type = "pool2d_cudnn"
+
+    def init_pool_type(self):
+        self.pool_type = "max"
+
 
 if __name__ == '__main__':
     unittest.main()
...@@ -64,8 +64,8 @@ class TestPool3d_Op(OpTest):
             'strides': self.strides,
             'paddings': self.paddings,
             'ksize': self.ksize,
-            'pooling_type': self.pool_type,
-            'global_pooling': self.global_pool,
+            'poolingType': self.pool_type,
+            'globalPooling': self.global_pool,
         }
         self.outputs = {'Out': output.astype('float32')}
...
...@@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
             'strides': self.strides,
             'paddings': self.paddings,
             'ksize': self.ksize,
-            'global_pooling': self.global_pool,
+            'globalPooling': self.global_pool,
         }
         self.inputs = {'X': input}
...