提交 4f5491b2 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #4146 from chengduoZH/Add_pool_op

Add pool op
......@@ -55,6 +55,12 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
if ("${TARGET}" STREQUAL "pool_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
endif()
# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
......
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context operator)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator)
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
namespace math {
template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.compute(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[ph * output_width + pw] = ele;
}
}
input_data += input_stride;
output_data += output_stride;
}
}
}
};
template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_grad_process.compute(
input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(scale));
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
};
template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
bool stop = false;
for (int h = hstart; h < hend && !stop; ++h) {
for (int w = wstart; w < wend && !stop; ++w) {
int input_idx = h * input_width + w;
int output_idx = ph * output_width + pw;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] += output_grad_data[output_idx];
stop = true;
}
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
};
template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_process.initial();
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.compute(
ele,
input_data[(d * input_height + h) * input_width + w]);
}
}
}
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[output_idx] = ele;
}
}
}
input_data += input_stride;
output_data += output_stride;
}
}
}
};
template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
float scale = 1.0 / pool_size;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
pool_grad_process.compute(
input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx],
input_grad_data[input_idx], static_cast<T>(scale));
}
}
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
};
template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
bool stop = false;
for (int d = dstart; d < dend && !stop; ++d) {
for (int h = hstart; h < hend && !stop; ++h) {
for (int w = wstart; w < wend && !stop; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
if (input_data[input_idx] == output_data[output_idx]) {
input_grad_data[input_idx] +=
output_grad_data[output_idx];
stop = true;
}
}
}
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
};
template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
} // namespace math
} // namespace operators
} // namespace paddle
此差异已折叠。
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__ //
template <class T>
class MaxPool {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void finalize(T& y, const T& poo_size) {}
};
template <class T>
class AvgPool {
public:
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
};
template <class T>
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += dy * (x == y);
}
};
template <class T>
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += (scale * dy);
}
};
template <typename Place, typename PoolProcess, typename T>
class Pool2dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute);
};
template <typename Place, typename PoolProcess, typename T>
class Pool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute);
};
template <typename Place, class T>
class MaxPool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute);
};
template <typename Place, class T>
class MaxPool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_op.h"
namespace paddle {
namespace operators {
int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
class PoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Out(Output) of Pooling should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
std::string pooling_type = ctx->Attrs().Get<std::string>("poolingType");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg",
"pooling_type should be 'max' or 'avg'");
PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
"Pooling intput should be 4-D or 5-D");
if (ctx->Attrs().Get<bool>("globalPooling")) {
ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and Pooling size should be consistent.");
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
"Pooling size should be 2 elements. or 3 elements.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"paddings size and pooling size should be the same.");
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class PoolOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input@Grad of Pooling should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"Strides(height, width) of pooling operator."
"Default {1,1}")
.SetDefault({1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>("paddings",
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
)DOC");
}
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the "
"number of channels, D, H and W is the depth, height and width of "
"feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW.");
AddAttr<std::string>("poolingType",
"PoolingType of pooling operator."
"str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>(
"paddings",
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(pool2d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pool2d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pool3d,
ops::PoolKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pool3d_grad,
ops::PoolGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class PoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::MaxPool<T>, T>
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
} break;
case 3: {
if (pooling_type == "max") {
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::MaxPool<T>, T>
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
} break;
}
}
};
template <typename Place, typename T>
class PoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
paddle::operators::math::MaxPool2dGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
} break;
case 3: {
if (pooling_type == "max") {
paddle::operators::math::MaxPool3dGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
} break;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -2,8 +2,10 @@
#ifdef __CUDACC__
#define HOSTDEVICE __host__ __device__
#define DEVICE __device__
#define HOST __host__
#else
#define HOSTDEVICE
#define DEVICE
#define HOST
#endif
import unittest
import numpy as np
from op_test import OpTest
def max_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
return out
def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
H_out = (H - ksize[0] + 2 * paddings[0]) / strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1]) / strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / (
(r_end - r_start) * (c_end - c_start))
return out
class TestPool2d_Op(OpTest):
def setUp(self):
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
output = self.pool2D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.inputs = {'X': input}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'poolingType': self.pool_type,
'globalPooling': self.global_pool,
}
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.pool_type != "max":
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def initTestCase(self):
self.global_pool = True
self.op_type = "pool2d"
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase1(TestPool2d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool2d"
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase2(TestPool2d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool2d"
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
class TestCase3(TestPool2d_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "pool2d"
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 5, 5]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase4(TestPool2d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool2d"
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
class TestCase5(TestPool2d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool2d"
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
def max_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in xrange(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
for j in xrange(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
return out
def avg_pool3D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
D_out = (D - ksize[0] + 2 * paddings[0]) / strides[0] + 1
H_out = (H - ksize[1] + 2 * paddings[1]) / strides[1] + 1
W_out = (W - ksize[2] + 2 * paddings[2]) / strides[2] + 1
out = np.zeros((N, C, D_out, H_out, W_out))
for k in xrange(D_out):
d_start = np.max((k * strides[0] - paddings[0], 0))
d_end = np.min((k * strides[0] + ksize[0] - paddings[0], D))
for i in xrange(H_out):
h_start = np.max((i * strides[0] - paddings[0], 0))
h_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
for j in xrange(W_out):
w_start = np.max((j * strides[1] - paddings[1], 0))
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / (
(d_end - d_start) * (h_end - h_start) * (w_end - w_start))
return out
class TestPool3d_Op(OpTest):
def setUp(self):
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
output = self.pool3D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool)
self.inputs = {'X': input}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'poolingType': self.pool_type,
'globalPooling': self.global_pool,
}
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.pool_type != "max":
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def initTestCase(self):
self.global_pool = True
self.op_type = "pool3d"
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase1(TestPool3d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool3d"
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase2(TestPool3d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool3d"
self.pool_type = "avg"
self.pool3D_forward_naive = avg_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
class TestCase3(TestPool3d_Op):
def initTestCase(self):
self.global_pool = True
self.op_type = "pool3d"
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 5, 5, 5]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase4(TestPool3d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool3d"
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [0, 0, 0]
class TestCase5(TestPool3d_Op):
def initTestCase(self):
self.global_pool = False
self.op_type = "pool3d"
self.pool_type = "max"
self.pool3D_forward_naive = max_pool3D_forward_naive
self.shape = [2, 3, 7, 7, 7]
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
self.paddings = [1, 1, 1]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册