diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc index 50cfb88bb5700dda3785e63e0ccc6457cc928da0..ead89e146f32ef005b06f4f6f04224d691805d74 100644 --- a/paddle/operators/math/pooling.cc +++ b/paddle/operators/math/pooling.cc @@ -27,15 +27,15 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -47,7 +47,7 @@ class Pool2dFunctor { const int output_stride = output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -87,11 +87,12 @@ template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_grad_process) { + PoolProcess pool_grad_process, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -110,7 +111,7 @@ class Pool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -154,10 +155,11 @@ template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; @@ -176,7 +178,7 @@ class MaxPool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -240,17 +242,17 @@ template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -265,7 +267,7 @@ class Pool3dFunctor { const int output_stride = output_depth * output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -315,11 +317,12 @@ template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_grad_process) { + PoolProcess pool_grad_process, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -343,7 +346,7 @@ class Pool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -398,10 +401,11 @@ template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; @@ -425,7 +429,7 @@ class MaxPool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -498,15 +502,15 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor { const int output_stride = output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -563,13 +567,13 @@ template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_height = input_grad.dims()[2]; - const int input_width = input_grad.dims()[3]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_height = input_grad->dims()[2]; + const int input_width = input_grad->dims()[3]; const int output_channels = output_grad.dims()[1]; const int output_height = output_grad.dims()[2]; const int output_width = output_grad.dims()[3]; @@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { @@ -612,17 +616,17 @@ template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor { const int output_stride = output_depth * output_height * output_width; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { for (int c = 0; c < output_channels; ++c) { @@ -691,14 +695,14 @@ template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_depth = input_grad.dims()[2]; - const int input_height = input_grad.dims()[3]; - const int input_width = input_grad.dims()[4]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_depth = input_grad->dims()[2]; + const int input_height = input_grad->dims()[3]; + const int input_width = input_grad->dims()[4]; const int output_channels = output_grad.dims()[1]; const int output_depth = output_grad.dims()[2]; const int output_height = output_grad.dims()[3]; @@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int n = 0; n < batch_size; ++n) { for (int c = 0; c < output_channels; ++c) { diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu index 736327f4b7b9e9df9ce8f7f60b0437fc1d2d373a..6d1138ad50cb095e85b4ceb44fa81731316f10dd 100644 --- a/paddle/operators/math/pooling.cu +++ b/paddle/operators/math/pooling.cu @@ -21,13 +21,13 @@ namespace math { template __global__ void KernelPool2D(const int nthreads, const T* input_data, - T* output_data, const int channels, - const int input_height, const int input_width, - const int output_height, const int output_width, - const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, - const int padding_height, const int padding_width, - PoolProcess pool_process) { + const int channels, const int input_height, + const int input_width, const int output_height, + const int output_width, const int ksize_height, + const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, + const int padding_width, PoolProcess pool_process, + T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -59,11 +59,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data, template __global__ void KernelPool2DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_height, const int input_width, const int output_height, - const int output_width, const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, const int padding_height, - const int padding_width, PoolProcess pool_process) { + const T* output_grad, const int channels, const int input_height, + const int input_width, const int output_height, const int output_width, + const int ksize_height, const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, const int padding_width, + PoolProcess pool_process, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -107,11 +107,11 @@ __global__ void KernelPool2DGrad( template __global__ void KernelMaxPool2DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_height, const int input_width, const int output_height, - const int output_width, const int ksize_height, const int ksize_width, - const int stride_height, const int stride_width, const int padding_height, - const int padding_width) { + const T* output_grad, const int channels, const int input_height, + const int input_width, const int output_height, const int output_width, + const int ksize_height, const int ksize_width, const int stride_height, + const int stride_width, const int padding_height, const int padding_width, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -158,16 +158,16 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -176,7 +176,7 @@ class Pool2dFunctor { const int padding_width = paddings[1]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -187,11 +187,10 @@ class Pool2dFunctor { PoolProcess, T><<(context) - .stream()>>>(nthreads, input_data, output_data, input_channels, - input_height, input_width, output_height, - output_width, ksize_height, ksize_width, - stride_height, stride_width, padding_height, - padding_width, pool_process); + .stream()>>>( + nthreads, input_data, input_channels, input_height, input_width, + output_height, output_width, ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width, pool_process, output_data); } }; @@ -204,11 +203,11 @@ template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_process) { + PoolProcess pool_process, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -225,7 +224,7 @@ class Pool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -237,10 +236,10 @@ class Pool2dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, stride_width, padding_height, - padding_width, pool_process); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + pool_process, input_grad_data); } }; @@ -253,10 +252,11 @@ template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -274,7 +274,7 @@ class MaxPool2dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -285,10 +285,10 @@ class MaxPool2dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_height, input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, stride_width, padding_height, - padding_width); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_height, input_width, output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, padding_height, padding_width, + input_grad_data); } }; @@ -313,14 +313,16 @@ template class Pool2dGradFunctor< platform::GPUPlace, paddle::operators::math::AvgPoolGrad, double>; template -__global__ void KernelPool3D( - const int nthreads, const T* input_data, T* output_data, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, const int padding_width, - PoolProcess pool_process) { +__global__ void KernelPool3D(const int nthreads, const T* input_data, + const int channels, const int input_depth, + const int input_height, const int input_width, + const int output_depth, const int output_height, + const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, + const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, + PoolProcess pool_process, T* output_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -358,13 +360,13 @@ __global__ void KernelPool3D( template __global__ void KernelPool3DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, const int padding_width, - PoolProcess pool_process) { + const T* output_grad, const int channels, const int input_depth, + const int input_height, const int input_width, const int output_depth, + const int output_height, const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, const int stride_depth, + const int stride_height, const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, PoolProcess pool_process, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int offsetW = index % input_width + padding_width; @@ -422,13 +424,12 @@ __global__ void KernelPool3DGrad( template __global__ void KernelMaxPool3DGrad( const int nthreads, const T* input_data, const T* output_data, - const T* output_grad, T* input_grad, const int channels, - const int input_depth, const int input_height, const int input_width, - const int output_depth, const int output_height, const int output_width, - const int ksize_depth, const int ksize_height, const int ksize_width, - const int stride_depth, const int stride_height, const int stride_width, - const int padding_depth, const int padding_height, - const int padding_width) { + const T* output_grad, const int channels, const int input_depth, + const int input_height, const int input_width, const int output_depth, + const int output_height, const int output_width, const int ksize_depth, + const int ksize_height, const int ksize_width, const int stride_depth, + const int stride_height, const int stride_width, const int padding_depth, + const int padding_height, const int padding_width, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -480,18 +481,18 @@ template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_process) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_process, framework::Tensor* output) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -503,7 +504,7 @@ class Pool3dFunctor { const int padding_width = paddings[2]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -516,11 +517,11 @@ class Pool3dFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, input_channels, input_depth, - input_height, input_width, output_depth, output_height, output_width, - ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, - stride_width, padding_depth, padding_height, padding_width, - pool_process); + nthreads, input_data, input_channels, input_depth, input_height, + input_width, output_depth, output_height, output_width, ksize_depth, + ksize_height, ksize_width, stride_depth, stride_height, stride_width, + padding_depth, padding_height, padding_width, pool_process, + output_data); } }; @@ -533,11 +534,11 @@ template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_process) { + PoolProcess pool_process, framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -560,7 +561,7 @@ class Pool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_depth * input_height * input_width; @@ -573,11 +574,11 @@ class Pool3dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_depth, input_height, input_width, output_depth, - output_height, output_width, ksize_depth, ksize_height, ksize_width, - stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width, pool_process); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width, pool_process, input_grad_data); } }; @@ -590,10 +591,11 @@ template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; @@ -616,7 +618,7 @@ class MaxPool3dGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -628,11 +630,11 @@ class MaxPool3dGradFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, output_grad_data, input_grad_data, - input_channels, input_depth, input_height, input_width, output_depth, - output_height, output_width, ksize_depth, ksize_height, ksize_width, - stride_depth, stride_height, stride_width, padding_depth, - padding_height, padding_width); + nthreads, input_data, output_data, output_grad_data, input_channels, + input_depth, input_height, input_width, output_depth, output_height, + output_width, ksize_depth, ksize_height, ksize_width, stride_depth, + stride_height, stride_width, padding_depth, padding_height, + padding_width, input_grad_data); } }; @@ -658,11 +660,11 @@ template class Pool3dGradFunctor< template __global__ void KernelMaxPool2dWithIdx( - const int nthreads, const T* input_data, T* output_data, T* mask_data, - const int channels, const int input_height, const int input_width, - const int output_height, const int output_width, const int ksize_height, - const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width) { + const int nthreads, const T* input_data, const int channels, + const int input_height, const int input_width, const int output_height, + const int output_width, const int ksize_height, const int ksize_width, + const int stride_height, const int stride_width, const int padding_height, + const int padding_width, T* output_data, T* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -697,11 +699,11 @@ __global__ void KernelMaxPool2dWithIdx( template __global__ void KernelMaxPool2DWithIdxGrad( - const int nthreads, T* input_grad, const T* output_grad, const T* mask_data, + const int nthreads, const T* output_grad, const T* mask_data, const int channels, const int input_height, const int input_width, const int output_height, const int output_width, const int ksize_height, const int ksize_width, const int stride_height, const int stride_width, - const int padding_height, const int padding_width) { + const int padding_height, const int padding_width, T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -748,16 +750,16 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = output.dims()[1]; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const int ksize_height = ksize[0]; const int ksize_width = ksize[1]; const int stride_height = strides[0]; @@ -766,8 +768,8 @@ class MaxPool2dWithIndexFunctor { const int padding_width = paddings[1]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_height * output_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -777,11 +779,10 @@ class MaxPool2dWithIndexFunctor { KernelMaxPool2dWithIdx< T><<(context) - .stream()>>>(nthreads, input_data, output_data, mask_data, - input_channels, input_height, input_width, - output_height, output_width, ksize_height, - ksize_width, stride_height, stride_width, - padding_height, padding_width); + .stream()>>>( + nthreads, input_data, input_channels, input_height, input_width, + output_height, output_width, ksize_height, ksize_width, stride_height, + stride_width, padding_height, padding_width, output_data, mask_data); } }; @@ -794,14 +795,14 @@ template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_channels = input_grad.dims()[1]; - const int input_height = input_grad.dims()[2]; - const int input_width = input_grad.dims()[3]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_channels = input_grad->dims()[1]; + const int input_height = input_grad->dims()[2]; + const int input_width = input_grad->dims()[3]; const int output_height = output_grad.dims()[2]; const int output_width = output_grad.dims()[3]; const int ksize_height = ksize[0]; @@ -813,7 +814,7 @@ class MaxPool2dWithIndexGradFunctor { const T* mask_data = mask.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_height * input_width; int blocks = (nthreads + 1024 - 1) / 1024; @@ -823,11 +824,11 @@ class MaxPool2dWithIndexGradFunctor { KernelMaxPool2DWithIdxGrad< T><<(context) - .stream()>>>(nthreads, input_grad_data, output_grad_data, - mask_data, input_channels, input_height, - input_width, output_height, output_width, - ksize_height, ksize_width, stride_height, - stride_width, padding_height, padding_width); + .stream()>>>(nthreads, output_grad_data, mask_data, + input_channels, input_height, input_width, + output_height, output_width, ksize_height, + ksize_width, stride_height, stride_width, + padding_height, padding_width, input_grad_data); } }; @@ -838,13 +839,13 @@ template class MaxPool2dWithIndexGradFunctor; template __global__ void KernelMaxPool3DWithIdx( - const int nthreads, const T* input_data, T* output_data, T* mask_data, - const int channels, const int input_depth, const int input_height, - const int input_width, const int output_depth, const int output_height, - const int output_width, const int ksize_depth, const int ksize_height, - const int ksize_width, const int stride_depth, const int stride_height, - const int stride_width, const int padding_depth, const int padding_height, - const int padding_width) { + const int nthreads, const T* input_data, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + T* output_data, T* mask_data) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int pw = index % output_width; @@ -886,13 +887,13 @@ __global__ void KernelMaxPool3DWithIdx( template __global__ void KernelMaxPool3DWithIdxGrad( - const int nthreads, T* input_grad, const T* output_grad, const T* mask, - const int channels, const int input_depth, const int input_height, - const int input_width, const int output_depth, const int output_height, - const int output_width, const int ksize_depth, const int ksize_height, - const int ksize_width, const int stride_depth, const int stride_height, - const int stride_width, const int padding_depth, const int padding_height, - const int padding_width) { + const int nthreads, const T* output_grad, const T* mask, const int channels, + const int input_depth, const int input_height, const int input_width, + const int output_depth, const int output_height, const int output_width, + const int ksize_depth, const int ksize_height, const int ksize_width, + const int stride_depth, const int stride_height, const int stride_width, + const int padding_depth, const int padding_height, const int padding_width, + T* input_grad) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int w_offset = index % input_width; @@ -952,18 +953,18 @@ template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_depth = input.dims()[2]; const int input_height = input.dims()[3]; const int input_width = input.dims()[4]; - const int output_channels = output.dims()[1]; - const int output_depth = output.dims()[2]; - const int output_height = output.dims()[3]; - const int output_width = output.dims()[4]; + const int output_channels = output->dims()[1]; + const int output_depth = output->dims()[2]; + const int output_height = output->dims()[3]; + const int output_width = output->dims()[4]; const int ksize_depth = ksize[0]; const int ksize_height = ksize[1]; const int ksize_width = ksize[2]; @@ -975,8 +976,8 @@ class MaxPool3dWithIndexFunctor { const int padding_width = paddings[2]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - T* mask_data = mask.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); + T* mask_data = mask->mutable_data(context.GetPlace()); int nthreads = batch_size * output_channels * output_depth * output_height * output_width; @@ -988,11 +989,10 @@ class MaxPool3dWithIndexFunctor { T><<(context) .stream()>>>( - nthreads, input_data, output_data, mask_data, input_channels, - input_depth, input_height, input_width, output_depth, output_height, - output_width, ksize_depth, ksize_height, ksize_width, stride_depth, - stride_height, stride_width, padding_depth, padding_height, - padding_width); + nthreads, input_data, input_channels, input_depth, input_height, + input_width, output_depth, output_height, output_width, ksize_depth, + ksize_height, ksize_width, stride_depth, stride_height, stride_width, + padding_depth, padding_height, padding_width, output_data, mask_data); } }; @@ -1005,15 +1005,15 @@ template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings) { - const int batch_size = input_grad.dims()[0]; - const int input_channels = input_grad.dims()[1]; - const int input_depth = input_grad.dims()[2]; - const int input_height = input_grad.dims()[3]; - const int input_width = input_grad.dims()[4]; + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad) { + const int batch_size = input_grad->dims()[0]; + const int input_channels = input_grad->dims()[1]; + const int input_depth = input_grad->dims()[2]; + const int input_height = input_grad->dims()[3]; + const int input_width = input_grad->dims()[4]; const int output_depth = output_grad.dims()[2]; const int output_height = output_grad.dims()[3]; const int output_width = output_grad.dims()[4]; @@ -1029,7 +1029,7 @@ class MaxPool3dWithIndexGradFunctor { const T* output_grad_data = output_grad.data(); const T* mask_data = mask.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = batch_size * input_channels * input_depth * input_height * input_width; @@ -1041,11 +1041,11 @@ class MaxPool3dWithIndexGradFunctor { T><<(context) .stream()>>>( - nthreads, input_grad_data, output_grad_data, mask_data, input_channels, - input_depth, input_height, input_width, output_depth, output_height, - output_width, ksize_depth, ksize_height, ksize_width, stride_depth, - stride_height, stride_width, padding_depth, padding_height, - padding_width); + nthreads, output_grad_data, mask_data, input_channels, input_depth, + input_height, input_width, output_depth, output_height, output_width, + ksize_depth, ksize_height, ksize_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + input_grad_data); } }; diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h index c50c57b5c52cdc5c12425cb119b80502aef5451e..f6719e1e628cdd2cf7445ec9cd05713bc4f14c84 100644 --- a/paddle/operators/math/pooling.h +++ b/paddle/operators/math/pooling.h @@ -88,60 +88,62 @@ template class Pool2dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_compute); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute, framework::Tensor* output); }; template class Pool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_compute); + PoolProcess pool_compute, framework::Tensor* input_grad); }; template class MaxPool2dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; template class Pool3dFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - std::vector& ksize, std::vector& strides, - std::vector& paddings, PoolProcess pool_compute); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + PoolProcess pool_compute, framework::Tensor* output); }; template class Pool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, std::vector& strides, std::vector& paddings, - PoolProcess pool_compute); + PoolProcess pool_compute, framework::Tensor* input_grad); }; template class MaxPool3dGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, const framework::Tensor& output, const framework::Tensor& output_grad, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; /* @@ -155,38 +157,38 @@ template class MaxPool2dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask); }; template class MaxPool2dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; template class MaxPool3dWithIndexFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + const framework::Tensor& input, std::vector& ksize, + std::vector& strides, std::vector& paddings, + framework::Tensor* output, framework::Tensor* mask); }; template class MaxPool3dWithIndexGradFunctor { public: void operator()(const platform::DeviceContext& context, - framework::Tensor& input_grad, const framework::Tensor& output_grad, const framework::Tensor& mask, std::vector& ksize, - std::vector& strides, std::vector& paddings); + std::vector& strides, std::vector& paddings, + framework::Tensor* input_grad); }; } // namespace math diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index 4da1941ab541483e706257667b14aa5a95e0c3cc..63492a89e8d4e44a036bc3c2b16cc54c7e77b534 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::MaxPool, T> pool2d_forward; paddle::operators::math::MaxPool pool_process; - pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dFunctor< Place, paddle::operators::math::AvgPool, T> pool2d_forward; paddle::operators::math::AvgPool pool_process; - pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } } break; case 3: { @@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel { Place, paddle::operators::math::MaxPool, T> pool3d_forward; paddle::operators::math::MaxPool pool_process; - pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dFunctor< Place, paddle::operators::math::AvgPool, T> pool3d_forward; paddle::operators::math::AvgPool pool_process; - pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, - paddings, pool_process); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, pool_process, out); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } @@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel { if (pooling_type == "max") { paddle::operators::math::MaxPool2dGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings); + pool2d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, in_x_grad); } else if (pooling_type == "avg") { paddle::operators::math::Pool2dGradFunctor< Place, paddle::operators::math::AvgPoolGrad, T> pool2d_backward; paddle::operators::math::AvgPoolGrad pool_process; - pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings, pool_process); + pool2d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, pool_process, in_x_grad); } } break; case 3: { if (pooling_type == "max") { paddle::operators::math::MaxPool3dGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings); + pool3d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, in_x_grad); } else if (pooling_type == "avg") { paddle::operators::math::Pool3dGradFunctor< Place, paddle::operators::math::AvgPoolGrad, T> pool3d_backward; paddle::operators::math::AvgPoolGrad pool_process; - pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, ksize, strides, paddings, pool_process); + pool3d_backward(context.device_context(), *in_x, *out, *out_grad, + ksize, strides, paddings, pool_process, in_x_grad); } } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index ea37de84abeb577461ccd5c1f0eda8bacb4458eb..c0e3b117dc3ea351b9edfed4d1823de0db27d30a 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { case 2: { paddle::operators::math::MaxPool2dWithIndexFunctor pool2d_forward; - pool2d_forward(context.device_context(), *in_x, *out, *mask, ksize, - strides, paddings); + pool2d_forward(context.device_context(), *in_x, ksize, strides, + paddings, out, mask); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexFunctor pool3d_forward; - pool3d_forward(context.device_context(), *in_x, *out, *mask, ksize, - strides, paddings); + pool3d_forward(context.device_context(), *in_x, ksize, strides, + paddings, out, mask); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } } @@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { case 2: { paddle::operators::math::MaxPool2dWithIndexGradFunctor pool2d_backward; - pool2d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool2d_backward(context.device_context(), *out_grad, *mask, ksize, + strides, paddings, in_x_grad); } break; case 3: { paddle::operators::math::MaxPool3dWithIndexGradFunctor pool3d_backward; - pool3d_backward(context.device_context(), *in_x_grad, *out_grad, - *mask, ksize, strides, paddings); + pool3d_backward(context.device_context(), *out_grad, *mask, ksize, + strides, paddings, in_x_grad); } break; default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); } }