未验证 提交 feaf1e2d 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #5608 from chengduoZH/fix_pooling_function_parameter_order

fix pooling functor parameter order
...@@ -27,15 +27,15 @@ template <typename PoolProcess, typename T> ...@@ -27,15 +27,15 @@ template <typename PoolProcess, typename T>
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> { class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& paddings, PoolProcess pool_process) { PoolProcess pool_process, framework::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
const int output_channels = output.dims()[1]; const int output_channels = output->dims()[1];
const int output_height = output.dims()[2]; const int output_height = output->dims()[2];
const int output_width = output.dims()[3]; const int output_width = output->dims()[3];
const int ksize_height = ksize[0]; const int ksize_height = ksize[0];
const int ksize_width = ksize[1]; const int ksize_width = ksize[1];
const int stride_height = strides[0]; const int stride_height = strides[0];
...@@ -47,7 +47,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -47,7 +47,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
const int output_stride = output_height * output_width; const int output_stride = output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -87,11 +87,12 @@ template <typename PoolProcess, class T> ...@@ -87,11 +87,12 @@ template <typename PoolProcess, class T>
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> { class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) { PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
...@@ -110,7 +111,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -110,7 +111,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* output_data = output.data<T>(); const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -154,10 +155,11 @@ template <class T> ...@@ -154,10 +155,11 @@ template <class T>
class MaxPool2dGradFunctor<platform::CPUPlace, T> { class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
...@@ -176,7 +178,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> { ...@@ -176,7 +178,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* output_data = output.data<T>(); const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -240,17 +242,17 @@ template <typename PoolProcess, class T> ...@@ -240,17 +242,17 @@ template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> { class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& paddings, PoolProcess pool_process) { PoolProcess pool_process, framework::Tensor* output) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
const int input_width = input.dims()[4]; const int input_width = input.dims()[4];
const int output_channels = output.dims()[1]; const int output_channels = output->dims()[1];
const int output_depth = output.dims()[2]; const int output_depth = output->dims()[2];
const int output_height = output.dims()[3]; const int output_height = output->dims()[3];
const int output_width = output.dims()[4]; const int output_width = output->dims()[4];
const int ksize_depth = ksize[0]; const int ksize_depth = ksize[0];
const int ksize_height = ksize[1]; const int ksize_height = ksize[1];
const int ksize_width = ksize[2]; const int ksize_width = ksize[2];
...@@ -265,7 +267,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -265,7 +267,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
const int output_stride = output_depth * output_height * output_width; const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -315,11 +317,12 @@ template <typename PoolProcess, class T> ...@@ -315,11 +317,12 @@ template <typename PoolProcess, class T>
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> { class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_grad_process) { PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
...@@ -343,7 +346,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> { ...@@ -343,7 +346,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* output_data = output.data<T>(); const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -398,10 +401,11 @@ template <class T> ...@@ -398,10 +401,11 @@ template <class T>
class MaxPool3dGradFunctor<platform::CPUPlace, T> { class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
...@@ -425,7 +429,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> { ...@@ -425,7 +429,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
const T* output_data = output.data<T>(); const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -498,15 +502,15 @@ template <typename T> ...@@ -498,15 +502,15 @@ template <typename T>
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
framework::Tensor& mask, std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& strides, std::vector<int>& paddings) { framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
const int output_channels = output.dims()[1]; const int output_channels = output->dims()[1];
const int output_height = output.dims()[2]; const int output_height = output->dims()[2];
const int output_width = output.dims()[3]; const int output_width = output->dims()[3];
const int ksize_height = ksize[0]; const int ksize_height = ksize[0];
const int ksize_width = ksize[1]; const int ksize_width = ksize[1];
const int stride_height = strides[0]; const int stride_height = strides[0];
...@@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const int output_stride = output_height * output_width; const int output_stride = output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace()); T* mask_data = mask->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -563,13 +567,13 @@ template <typename T> ...@@ -563,13 +567,13 @@ template <typename T>
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize, const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings,
const int batch_size = input_grad.dims()[0]; framework::Tensor* input_grad) {
const int input_height = input_grad.dims()[2]; const int batch_size = input_grad->dims()[0];
const int input_width = input_grad.dims()[3]; const int input_height = input_grad->dims()[2];
const int input_width = input_grad->dims()[3];
const int output_channels = output_grad.dims()[1]; const int output_channels = output_grad.dims()[1];
const int output_height = output_grad.dims()[2]; const int output_height = output_grad.dims()[2];
const int output_width = output_grad.dims()[3]; const int output_width = output_grad.dims()[3];
...@@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const T* mask_data = mask.data<T>(); const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) { for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -612,17 +616,17 @@ template <typename T> ...@@ -612,17 +616,17 @@ template <typename T>
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
framework::Tensor& mask, std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& strides, std::vector<int>& paddings) { framework::Tensor* output, framework::Tensor* mask) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
const int input_width = input.dims()[4]; const int input_width = input.dims()[4];
const int output_channels = output.dims()[1]; const int output_channels = output->dims()[1];
const int output_depth = output.dims()[2]; const int output_depth = output->dims()[2];
const int output_height = output.dims()[3]; const int output_height = output->dims()[3];
const int output_width = output.dims()[4]; const int output_width = output->dims()[4];
const int ksize_depth = ksize[0]; const int ksize_depth = ksize[0];
const int ksize_height = ksize[1]; const int ksize_height = ksize[1];
const int ksize_width = ksize[2]; const int ksize_width = ksize[2];
...@@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> { ...@@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const int output_stride = output_depth * output_height * output_width; const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>(); const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
T* mask_data = mask.mutable_data<T>(context.GetPlace()); T* mask_data = mask->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
...@@ -691,14 +695,14 @@ template <typename T> ...@@ -691,14 +695,14 @@ template <typename T>
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize, const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) { std::vector<int>& strides, std::vector<int>& paddings,
const int batch_size = input_grad.dims()[0]; framework::Tensor* input_grad) {
const int input_depth = input_grad.dims()[2]; const int batch_size = input_grad->dims()[0];
const int input_height = input_grad.dims()[3]; const int input_depth = input_grad->dims()[2];
const int input_width = input_grad.dims()[4]; const int input_height = input_grad->dims()[3];
const int input_width = input_grad->dims()[4];
const int output_channels = output_grad.dims()[1]; const int output_channels = output_grad.dims()[1];
const int output_depth = output_grad.dims()[2]; const int output_depth = output_grad.dims()[2];
const int output_height = output_grad.dims()[3]; const int output_height = output_grad.dims()[3];
...@@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> { ...@@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const T* mask_data = mask.data<T>(); const T* mask_data = mask.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
for (int n = 0; n < batch_size; ++n) { for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
......
此差异已折叠。
...@@ -88,60 +88,62 @@ template <typename Place, typename PoolProcess, typename T> ...@@ -88,60 +88,62 @@ template <typename Place, typename PoolProcess, typename T>
class Pool2dFunctor { class Pool2dFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& paddings, PoolProcess pool_compute); PoolProcess pool_compute, framework::Tensor* output);
}; };
template <typename Place, typename PoolProcess, typename T> template <typename Place, typename PoolProcess, typename T>
class Pool2dGradFunctor { class Pool2dGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute); PoolProcess pool_compute, framework::Tensor* input_grad);
}; };
template <typename Place, class T> template <typename Place, class T>
class MaxPool2dGradFunctor { class MaxPool2dGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings); std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
}; };
template <typename Place, typename PoolProcess, typename T> template <typename Place, typename PoolProcess, typename T>
class Pool3dFunctor { class Pool3dFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& paddings, PoolProcess pool_compute); PoolProcess pool_compute, framework::Tensor* output);
}; };
template <typename Place, typename PoolProcess, typename T> template <typename Place, typename PoolProcess, typename T>
class Pool3dGradFunctor { class Pool3dGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute); PoolProcess pool_compute, framework::Tensor* input_grad);
}; };
template <typename Place, class T> template <typename Place, class T>
class MaxPool3dGradFunctor { class MaxPool3dGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& input,
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings); std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
}; };
/* /*
...@@ -155,38 +157,38 @@ template <typename Place, typename T> ...@@ -155,38 +157,38 @@ template <typename Place, typename T>
class MaxPool2dWithIndexFunctor { class MaxPool2dWithIndexFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
framework::Tensor& mask, std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& strides, std::vector<int>& paddings); framework::Tensor* output, framework::Tensor* mask);
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MaxPool2dWithIndexGradFunctor { class MaxPool2dWithIndexGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize, const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings); std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MaxPool3dWithIndexFunctor { class MaxPool3dWithIndexFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, std::vector<int>& ksize,
framework::Tensor& mask, std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& paddings,
std::vector<int>& strides, std::vector<int>& paddings); framework::Tensor* output, framework::Tensor* mask);
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MaxPool3dWithIndexGradFunctor { class MaxPool3dWithIndexGradFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::Tensor& input_grad,
const framework::Tensor& output_grad, const framework::Tensor& output_grad,
const framework::Tensor& mask, std::vector<int>& ksize, const framework::Tensor& mask, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings); std::vector<int>& strides, std::vector<int>& paddings,
framework::Tensor* input_grad);
}; };
} // namespace math } // namespace math
......
...@@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel<T> {
Place, paddle::operators::math::MaxPool<T>, T> Place, paddle::operators::math::MaxPool<T>, T>
pool2d_forward; pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process; paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process); paddings, pool_process, out);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor< paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::AvgPool<T>, T> Place, paddle::operators::math::AvgPool<T>, T>
pool2d_forward; pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process; paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, pool2d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process); paddings, pool_process, out);
} }
} break; } break;
case 3: { case 3: {
...@@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel<T> { ...@@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel<T> {
Place, paddle::operators::math::MaxPool<T>, T> Place, paddle::operators::math::MaxPool<T>, T>
pool3d_forward; pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process; paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process); paddings, pool_process, out);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor< paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::AvgPool<T>, T> Place, paddle::operators::math::AvgPool<T>, T>
pool3d_forward; pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process; paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, pool3d_forward(context.device_context(), *in_x, ksize, strides,
paddings, pool_process); paddings, pool_process, out);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
...@@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel<T> { ...@@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel<T> {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::MaxPool2dGradFunctor<Place, T> paddle::operators::math::MaxPool2dGradFunctor<Place, T>
pool2d_backward; pool2d_backward;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool2d_backward(context.device_context(), *in_x, *out, *out_grad,
*out_grad, ksize, strides, paddings); ksize, strides, paddings, in_x_grad);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor< paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T> Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward; pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool2d_backward(context.device_context(), *in_x, *out, *out_grad,
*out_grad, ksize, strides, paddings, pool_process); ksize, strides, paddings, pool_process, in_x_grad);
} }
} break; } break;
case 3: { case 3: {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::MaxPool3dGradFunctor<Place, T> paddle::operators::math::MaxPool3dGradFunctor<Place, T>
pool3d_backward; pool3d_backward;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool3d_backward(context.device_context(), *in_x, *out, *out_grad,
*out_grad, ksize, strides, paddings); ksize, strides, paddings, in_x_grad);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor< paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::AvgPoolGrad<T>, T> Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward; pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool3d_backward(context.device_context(), *in_x, *out, *out_grad,
*out_grad, ksize, strides, paddings, pool_process); ksize, strides, paddings, pool_process, in_x_grad);
} }
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
...@@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> { ...@@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
case 2: { case 2: {
paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T> paddle::operators::math::MaxPool2dWithIndexFunctor<Place, T>
pool2d_forward; pool2d_forward;
pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, pool2d_forward(context.device_context(), *in_x, ksize, strides,
strides, paddings); paddings, out, mask);
} break; } break;
case 3: { case 3: {
paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T> paddle::operators::math::MaxPool3dWithIndexFunctor<Place, T>
pool3d_forward; pool3d_forward;
pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, pool3d_forward(context.device_context(), *in_x, ksize, strides,
strides, paddings); paddings, out, mask);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
...@@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> { ...@@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
case 2: { case 2: {
paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T> paddle::operators::math::MaxPool2dWithIndexGradFunctor<Place, T>
pool2d_backward; pool2d_backward;
pool2d_backward(context.device_context(), *in_x_grad, *out_grad, pool2d_backward(context.device_context(), *out_grad, *mask, ksize,
*mask, ksize, strides, paddings); strides, paddings, in_x_grad);
} break; } break;
case 3: { case 3: {
paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T> paddle::operators::math::MaxPool3dWithIndexGradFunctor<Place, T>
pool3d_backward; pool3d_backward;
pool3d_backward(context.device_context(), *in_x_grad, *out_grad, pool3d_backward(context.device_context(), *out_grad, *mask, ksize,
*mask, ksize, strides, paddings); strides, paddings, in_x_grad);
} break; } break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册