diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index 185708cdaab4af29824961260ca04f71048a0978..1233a9ea3242214ca83f9707f54a070778a0837b 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -7,3 +7,5 @@ endif()
 
 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
+cc_test(pool_test_maxPool2d_test SRCS pool_test_maxPool2d.cc DEPS math_function tensor)
+cc_test(pool_test_maxPool3d_test SRCS pool_test_maxPool3d.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/pool_test_maxPool2d.cc b/paddle/operators/math/pool_test_maxPool2d.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8ddf1d79c7a9213064dd4049fd2ded68dece983e
--- /dev/null
+++ b/paddle/operators/math/pool_test_maxPool2d.cc
@@ -0,0 +1,150 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include <time.h>
+#include "paddle/operators/math/pooling.h"
+
+#include "paddle/memory/memcpy.h"
+#include "paddle/platform/enforce.h"
+
+#include <iostream>
+#include <vector>
+
+#ifndef PADDLE_ONLY_CPU
+
+template <typename PoolType>
+void testPool2d(paddle::platform::DeviceContext& context,
+                PoolType pool_process, paddle::framework::Tensor& input,
+                paddle::framework::Tensor& input_grad,
+                paddle::framework::Tensor& output,
+                paddle::framework::Tensor& output_grad,
+                std::vector<int>& ksize, std::vector<int>& strides,
+                std::vector<int>& paddings) {
+  paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace,
+                                                PoolType, float>
+      pool2d_forward;
+  pool2d_forward(context, input, output, ksize, strides, paddings,
+                 pool_process);
+
+  int times = 50;
+  clock_t start, finish;
+  double totaltime;
+
+  // Pool2dBackwardFunctor
+  start = clock();
+  for (int i = 0; i < times; ++i) {
+    paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace,
+                                                   PoolType, float>
+        pool2d_backward;
+    pool2d_backward(context, input, input_grad, output, output_grad, ksize,
+                    strides, paddings, pool_process);
+    PADDLE_ENFORCE(cudaStreamSynchronize(0),
+                   "cudaStreamSynchronize failed in pool2d_backward");
+  }
+  finish = clock();
+  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
+  totaltime /= times;
+  std::cout << "\nPool2dBackwardFunctor: " << totaltime << "s" << std::endl;
+
+  // MaxPool2dBackwardFunctor
+  start = clock();
+  for (int j = 0; j < times; ++j) {
+    paddle::operators::math::MaxPool2dBackwardFunctor<
+        paddle::platform::GPUPlace, float>
+        maxpool2d_backward;
+    maxpool2d_backward(context, input, input_grad, output, output_grad, ksize,
+                       strides, paddings);
+    PADDLE_ENFORCE(cudaStreamSynchronize(0),
+                   "cudaStreamSynchronize failed in maxpool2d_backward");
+  }
+  finish = clock();
+  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
+  totaltime /= times;
+  std::cout << "\nMaxPool2dBackwardFunctor: " << totaltime << "s" << std::endl;
+}
+
+void test2dPool() {
+  using paddle::platform::DeviceContext;
+  using paddle::platform::CUDADeviceContext;
+  using paddle::platform::GPUPlace;
+
+  paddle::framework::Tensor input_tmp;
+  paddle::framework::Tensor output_tmp;
+  paddle::framework::Tensor input;
+  paddle::framework::Tensor input_grad;
+  paddle::framework::Tensor output;
+  paddle::framework::Tensor output_grad;
+
+  int batch = 32;
+  int channel = 32;
+  int input_height = 128;
+  int input_width = 128;
+  int in_len = batch * channel * input_height * input_width;
+  std::vector<int> ksize({3, 3});
+  std::vector<int> strides({1, 1});
+  std::vector<int> paddings({0, 0});
+
+  int output_height =
+      (input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
+  int output_width =
+      (input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
+  int output_len = batch * channel * output_height * output_width;
+
+  input_tmp.mutable_data<float>({batch, channel, input_height, input_width},
+                                paddle::platform::CPUPlace());
+  output_tmp.mutable_data<float>({batch, channel, output_height, output_width},
+                                 paddle::platform::CPUPlace());
+
+  float* arr = new float[in_len];
+  auto* place = new paddle::platform::GPUPlace();
+
+  // input
+  float* input_ptr = input_tmp.data<float>();
+  for (int i = 0; i < in_len; ++i) arr[i] = i;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, in_len * sizeof(float));
+  input.CopyFrom<float>(input_tmp, *place);
+
+  // input_grad
+  input_ptr = input_tmp.data<float>();
+  for (int i = 0; i < in_len; ++i) arr[i] = 0;
+  memcpy(input_ptr, arr, in_len * sizeof(float));
+  input_grad.CopyFrom<float>(input_tmp, *place);
+
+  // output
+  input_ptr = output_tmp.data<float>();
+  for (int i = 0; i < output_len; ++i)
+    arr[i] = 0;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, output_len * sizeof(float));
+  output.CopyFrom<float>(output_tmp, *place);
+
+  // output_grad
+  input_ptr = output_tmp.data<float>();
+  for (int i = 0; i < output_len; ++i)
+    arr[i] = 1;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, output_len * sizeof(float));
+  output_grad.CopyFrom<float>(output_tmp, *place);
+
+  paddle::platform::DeviceContext* context =
+      new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+  paddle::operators::math::pool::maxPool<float> pool_process;
+
+  testPool2d<paddle::operators::math::pool::maxPool<float>>(
+      *context, pool_process, input, input_grad, output, output_grad, ksize,
+      strides, paddings);
+}
+
+int main() { test2dPool(); }
+#endif
\ No newline at end of file
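Both tests size their outputs with the standard pooling formula output = (input - ksize + 2 * paddings) / strides + 1. A minimal standalone check of that arithmetic, using the exact constants from test2dPool above and test3dPool below:

    #include <cassert>

    // Pooled extent along one dimension: floor((in - k + 2p) / s) + 1.
    int pooled_size(int input, int ksize, int stride, int padding) {
      return (input - ksize + 2 * padding) / stride + 1;
    }

    int main() {
      assert(pooled_size(128, 3, 1, 0) == 126);  // test2dPool: 128 -> 126
      assert(pooled_size(4, 3, 2, 1) == 2);      // test3dPool depth: 4 -> 2
      assert(pooled_size(128, 3, 2, 1) == 64);   // test3dPool height/width: 128 -> 64
      return 0;
    }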
diff --git a/paddle/operators/math/pool_test_maxPool3d.cc b/paddle/operators/math/pool_test_maxPool3d.cc
new file mode 100644
index 0000000000000000000000000000000000000000..000b006a1281d99d3427740bc7b0f9ed3da5d1dd
--- /dev/null
+++ b/paddle/operators/math/pool_test_maxPool3d.cc
@@ -0,0 +1,154 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. */
+
+#include <time.h>
+#include "paddle/operators/math/pooling.h"
+
+#include "paddle/memory/memcpy.h"
+#include "paddle/platform/enforce.h"
+
+#include <iostream>
+#include <vector>
+
+#ifndef PADDLE_ONLY_CPU
+
+template <typename PoolType>
+void testPool3d(paddle::platform::DeviceContext& context,
+                PoolType pool_process, paddle::framework::Tensor& input,
+                paddle::framework::Tensor& input_grad,
+                paddle::framework::Tensor& output,
+                paddle::framework::Tensor& output_grad,
+                std::vector<int>& ksize, std::vector<int>& strides,
+                std::vector<int>& paddings) {
+  paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace,
+                                                PoolType, float>
+      pool3d_forward;
+  pool3d_forward(context, input, output, ksize, strides, paddings,
+                 pool_process);
+
+  int times = 50;
+  clock_t start, finish;
+  double totaltime;
+
+  // Pool3dBackwardFunctor
+  start = clock();
+  for (int i = 0; i < times; ++i) {
+    paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace,
+                                                   PoolType, float>
+        pool3d_backward;
+    pool3d_backward(context, input, input_grad, output, output_grad, ksize,
+                    strides, paddings, pool_process);
+    PADDLE_ENFORCE(cudaStreamSynchronize(0),
+                   "cudaStreamSynchronize failed in pool3d_backward");
+  }
+  finish = clock();
+  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
+  totaltime /= times;
+  std::cout << "\nPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
+
+  // MaxPool3dBackwardFunctor
+  start = clock();
+  for (int j = 0; j < times; ++j) {
+    paddle::operators::math::MaxPool3dBackwardFunctor<
+        paddle::platform::GPUPlace, float>
+        maxpool3d_backward;
+    maxpool3d_backward(context, input, input_grad, output, output_grad, ksize,
+                       strides, paddings);
+    PADDLE_ENFORCE(cudaStreamSynchronize(0),
+                   "cudaStreamSynchronize failed in maxpool3d_backward");
+  }
+  finish = clock();
+  totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
+  totaltime /= times;
+  std::cout << "\nMaxPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
+}
+
+void test3dPool() {
+  using paddle::platform::DeviceContext;
+  using paddle::platform::CUDADeviceContext;
+  using paddle::platform::GPUPlace;
+
+  paddle::framework::Tensor input_tmp;
+  paddle::framework::Tensor output_tmp;
+  paddle::framework::Tensor input;
+  paddle::framework::Tensor input_grad;
+  paddle::framework::Tensor output;
+  paddle::framework::Tensor output_grad;
+
+  int batch = 32;
+  int channel = 4;
+  int input_depth = 4;
+  int input_height = 128;
+  int input_width = 128;
+  int in_len = batch * channel * input_depth * input_height * input_width;
+  std::vector<int> ksize({3, 3, 3});
+  std::vector<int> strides({2, 2, 2});
+  std::vector<int> paddings({1, 1, 1});
+
+  int output_depth =
+      (input_depth - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
+  int output_height =
+      (input_height - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
+  int output_width =
+      (input_width - ksize[2] + 2 * paddings[2]) / strides[2] + 1;
+
+  int output_len =
+      batch * channel * output_depth * output_height * output_width;
+
+  input_tmp.mutable_data<float>(
+      {batch, channel, input_depth, input_height, input_width},
+      paddle::platform::CPUPlace());
+  output_tmp.mutable_data<float>(
+      {batch, channel, output_depth, output_height, output_width},
+      paddle::platform::CPUPlace());
+
+  float* arr = new float[in_len];
+  auto* place = new paddle::platform::GPUPlace();
+
+  // input
+  float* input_ptr = input_tmp.data<float>();
+  for (int i = 0; i < in_len; ++i) arr[i] = i;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, in_len * sizeof(float));
+  input.CopyFrom<float>(input_tmp, *place);
+
+  // input_grad
+  input_ptr = input_tmp.data<float>();
+  for (int i = 0; i < in_len; ++i) arr[i] = 0;
+  memcpy(input_ptr, arr, in_len * sizeof(float));
+  input_grad.CopyFrom<float>(input_tmp, *place);
+
+  // output
+  input_ptr = output_tmp.data<float>();
+  for (int i = 0; i < output_len; ++i)
+    arr[i] = 0;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, output_len * sizeof(float));
+  output.CopyFrom<float>(output_tmp, *place);
+
+  // output_grad
+  input_ptr = output_tmp.data<float>();
+  for (int i = 0; i < output_len; ++i)
+    arr[i] = 1;  // rand() / double(RAND_MAX/2);
+  memcpy(input_ptr, arr, output_len * sizeof(float));
+  output_grad.CopyFrom<float>(output_tmp, *place);
+
+  paddle::platform::DeviceContext* context =
+      new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+  paddle::operators::math::pool::maxPool<float> pool_process;
+
+  testPool3d<paddle::operators::math::pool::maxPool<float>>(
+      *context, pool_process, input, input_grad, output, output_grad, ksize,
+      strides, paddings);
+}
+
+int main() { test3dPool(); }
+#endif
\ No newline at end of file
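Both drivers time the GPU with clock(), which reports CPU time; the numbers are meaningful only because every iteration blocks on cudaStreamSynchronize. A wall-clock sketch with std::chrono (mean_seconds is a hypothetical helper, not part of the patch; the callable must synchronize before returning):

    #include <chrono>
    #include <iostream>

    // Mean wall-clock seconds over `iters` runs. For GPU work, the callable
    // must block until the device is idle (e.g. via cudaStreamSynchronize),
    // otherwise only launch overhead is measured.
    template <typename F>
    double mean_seconds(F&& run, int iters) {
      auto start = std::chrono::steady_clock::now();
      for (int i = 0; i < iters; ++i) run();
      std::chrono::duration<double> elapsed =
          std::chrono::steady_clock::now() - start;
      return elapsed.count() / iters;
    }

    int main() {
      double acc = 0;
      double t = mean_seconds(
          [&] { for (int i = 0; i < 1000000; ++i) acc += i; }, 50);
      std::cout << "mean: " << t << "s (acc = " << acc << ")\n";
      return 0;
    }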
diff --git a/paddle/operators/math/pooling.cc b/paddle/operators/math/pooling.cc
index 5ce748ff08b1045011b6efb7497d28331d067b15..7c616023ca7345d7eb21b596f7876ef8be0cb1c5 100644
--- a/paddle/operators/math/pooling.cc
+++ b/paddle/operators/math/pooling.cc
@@ -134,6 +134,70 @@ class Pool2dBackwardFunctor {
   }
 };
 
+template <typename T>
+class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+    const int input_stride = input_height * input_width;
+    const int output_stride = output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; ++i) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int ph = 0; ph < output_height; ++ph) {
+          int hstart = ph * stride_height - padding_height;
+          int hend = std::min(hstart + ksize_height, input_height);
+          hstart = std::max(hstart, 0);
+          for (int pw = 0; pw < output_width; ++pw) {
+            int wstart = pw * stride_width - padding_width;
+            int wend = std::min(wstart + ksize_width, input_width);
+            wstart = std::max(wstart, 0);
+
+            // Route the gradient to the first element in the window whose
+            // value equals the pooled maximum.
+            bool stop = false;
+            for (int h = hstart; h < hend && !stop; ++h) {
+              for (int w = wstart; w < wend && !stop; ++w) {
+                int input_idx = h * input_width + w;
+                int output_idx = ph * output_width + pw;
+                if (input_data[input_idx] == output_data[output_idx]) {
+                  input_grad_data[input_idx] += output_grad_data[output_idx];
+                  stop = true;
+                }
+              }
+            }
+          }
+        }
+        input_data += input_stride;
+        output_data += output_stride;
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>;
+template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;
+
 template class Pool2dForwardFunctor<
     platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
 template class Pool2dForwardFunctor<
@@ -299,6 +363,84 @@ class Pool3dBackwardFunctor {
   }
 };
 
+template <typename T>
+class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_depth = input.dims()[2];
+    const int input_height = input.dims()[3];
+    const int input_width = input.dims()[4];
+    const int output_channels = output.dims()[1];
+    const int output_depth = output.dims()[2];
+    const int output_height = output.dims()[3];
+    const int output_width = output.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+    const int input_stride = input_depth * input_height * input_width;
+    const int output_stride = output_depth * output_height * output_width;
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    for (int i = 0; i < batch_size; ++i) {
+      for (int c = 0; c < output_channels; ++c) {
+        for (int pd = 0; pd < output_depth; ++pd) {
+          int dstart = pd * stride_depth - padding_depth;
+          int dend = std::min(dstart + ksize_depth, input_depth);
+          dstart = std::max(dstart, 0);
+          for (int ph = 0; ph < output_height; ++ph) {
+            int hstart = ph * stride_height - padding_height;
+            int hend = std::min(hstart + ksize_height, input_height);
+            hstart = std::max(hstart, 0);
+            for (int pw = 0; pw < output_width; ++pw) {
+              int wstart = pw * stride_width - padding_width;
+              int wend = std::min(wstart + ksize_width, input_width);
+              wstart = std::max(wstart, 0);
+
+              // Route the gradient to the first element in the window whose
+              // value equals the pooled maximum.
+              bool stop = false;
+              for (int d = dstart; d < dend && !stop; ++d) {
+                for (int h = hstart; h < hend && !stop; ++h) {
+                  for (int w = wstart; w < wend && !stop; ++w) {
+                    int input_idx = (d * input_height + h) * input_width + w;
+                    int output_idx =
+                        (pd * output_height + ph) * output_width + pw;
+                    if (input_data[input_idx] == output_data[output_idx]) {
+                      input_grad_data[input_idx] +=
+                          output_grad_data[output_idx];
+                      stop = true;
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+        input_data += input_stride;
+        output_data += output_stride;
+        input_grad_data += input_stride;
+        output_grad_data += output_stride;
+      }
+    }
+  }
+};
+
+template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>;
+template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;
+
 template class Pool3dForwardFunctor<
     platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
 template class Pool3dForwardFunctor<
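The two CPU functors above share one routing rule: each output gradient goes, in full, to the first element of its window whose value equals the pooled maximum, so ties resolve in scan order. A toy one-dimensional illustration of that rule (a standalone sketch, not the functor API):

    #include <algorithm>
    #include <cassert>

    int main() {
      // One pooling window over four inputs, with a tie for the maximum.
      float input[4] = {1.0f, 5.0f, 5.0f, 3.0f};
      float input_grad[4] = {0.0f, 0.0f, 0.0f, 0.0f};
      float output = *std::max_element(input, input + 4);  // forward pass: 5.0f
      float output_grad = 2.0f;

      // Backward pass: the first element equal to the max gets the whole
      // gradient, mirroring the `stop` flag in MaxPool2dBackwardFunctor.
      for (int i = 0; i < 4; ++i) {
        if (input[i] == output) {
          input_grad[i] += output_grad;
          break;
        }
      }
      assert(input_grad[1] == 2.0f && input_grad[2] == 0.0f);
      return 0;
    }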
diff --git a/paddle/operators/math/pooling.cu b/paddle/operators/math/pooling.cu
index 09e6bd9000a4e76af4aded60a69bd40950177254..347270abc3f43a0c5666e24006c84307317a002b 100644
--- a/paddle/operators/math/pooling.cu
+++ b/paddle/operators/math/pooling.cu
@@ -102,6 +102,51 @@ __global__ void KernelPool2dBackward(
   }
 }
 
+template <typename T>
+__global__ void KernelMaxPool2dBackward(
+    const int nthreads, const T* input_data, const T* output_data,
+    const T* output_grad, T* input_grad, const int channels,
+    const int input_height, const int input_width, const int output_height,
+    const int output_width, const int ksize_height, const int ksize_width,
+    const int stride_height, const int stride_width, const int padding_height,
+    const int padding_width) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < nthreads) {
+    int pw = index % output_width;
+    int ph = (index / output_width) % output_height;
+    int c = (index / output_width / output_height) % channels;
+    int batch_idx = index / output_width / output_height / channels;
+
+    int hstart = ph * stride_height - padding_height;
+    int hend = min(hstart + ksize_height, input_height);
+    hstart = max(hstart, 0);
+
+    int wstart = pw * stride_width - padding_width;
+    int wend = min(wstart + ksize_width, input_width);
+    wstart = max(wstart, 0);
+
+    input_data += (batch_idx * channels + c) * input_height * input_width;
+    input_grad += (batch_idx * channels + c) * input_height * input_width;
+
+    T ele = output_data[index];
+    int maxIndex = -1;
+    bool stop = false;
+    for (int h = hstart; h < hend && !stop; ++h) {
+      for (int w = wstart; w < wend && !stop; ++w) {
+        if (ele == input_data[h * input_width + w]) {
+          maxIndex = h * input_width + w;
+          stop = true;
+        }
+      }
+    }
+
+    if (maxIndex != -1) {
+      // Overlapping windows may select the same input element, so the
+      // accumulation must be atomic.
+      atomicAdd(input_grad + maxIndex, output_grad[index]);
+    }
+  }
+}
+
 template <typename PoolProcess, typename T>
 class Pool2dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
  public:
@@ -187,6 +232,52 @@ class Pool2dBackwardFunctor {
   }
 };
 
+template <typename T>
+class MaxPool2dBackwardFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_height = input.dims()[2];
+    const int input_width = input.dims()[3];
+    const int output_channels = output.dims()[1];
+    const int output_height = output.dims()[2];
+    const int output_width = output.dims()[3];
+    const int ksize_height = ksize[0];
+    const int ksize_width = ksize[1];
+    const int stride_height = strides[0];
+    const int stride_width = strides[1];
+    const int padding_height = paddings[0];
+    const int padding_width = paddings[1];
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool2dBackward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_data, output_data, output_grad_data, input_grad_data,
+        input_channels, input_height, input_width, output_height, output_width,
+        ksize_height, ksize_width, stride_height, stride_width, padding_height,
+        padding_width);
+  }
+};
+
+template class MaxPool2dBackwardFunctor<platform::GPUPlace, float>;
+// template class MaxPool2dBackwardFunctor<platform::GPUPlace, double>;
+
 template class Pool2dForwardFunctor<
     platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
 template class Pool2dForwardFunctor<
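KernelMaxPool2dBackward runs one thread per output element and recovers (batch, channel, ph, pw) from the flat index by repeated division. A small self-contained check that this decomposition inverts the NCHW linearization (the constants are illustrative only):

    #include <cassert>

    int main() {
      // NCHW linearization: index = ((n * C + c) * H + ph) * W + pw.
      const int C = 32, H = 126, W = 126;
      const int n = 3, c = 7, ph = 50, pw = 11;
      int index = ((n * C + c) * H + ph) * W + pw;

      // Decomposition as done at the top of KernelMaxPool2dBackward.
      assert(index % W == pw);
      assert((index / W) % H == ph);
      assert((index / W / H) % C == c);
      assert(index / W / H / C == n);
      return 0;
    }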
@@ -311,6 +402,58 @@ __global__ void KernelPool3DBackward(
   }
 }
 
+template <typename T>
+__global__ void KernelMaxPool3DBackward(
+    const int nthreads, const T* input_data, const T* output_data,
+    const T* output_grad, T* input_grad, const int channels,
+    const int input_depth, const int input_height, const int input_width,
+    const int output_depth, const int output_height, const int output_width,
+    const int ksize_depth, const int ksize_height, const int ksize_width,
+    const int stride_depth, const int stride_height, const int stride_width,
+    const int padding_depth, const int padding_height,
+    const int padding_width) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+       index += blockDim.x * gridDim.x) {
+    int pw = index % output_width;
+    int ph = (index / output_width) % output_height;
+    int pd = (index / output_width / output_height) % output_depth;
+    int c = (index / output_width / output_height / output_depth) % channels;
+    int batch_idx =
+        index / output_width / output_height / output_depth / channels;
+    int dstart = pd * stride_depth - padding_depth;
+    int hstart = ph * stride_height - padding_height;
+    int wstart = pw * stride_width - padding_width;
+    int dend = min(dstart + ksize_depth, input_depth);
+    int hend = min(hstart + ksize_height, input_height);
+    int wend = min(wstart + ksize_width, input_width);
+    dstart = max(dstart, 0);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    T ele = output_data[index];
+    bool stop = false;
+    int maxIdx = -1;
+    input_data +=
+        (batch_idx * channels + c) * input_depth * input_height * input_width;
+    input_grad +=
+        (batch_idx * channels + c) * input_depth * input_height * input_width;
+
+    for (int d = dstart; d < dend && !stop; ++d) {
+      for (int h = hstart; h < hend && !stop; ++h) {
+        for (int w = wstart; w < wend && !stop; ++w) {
+          if (ele == input_data[(d * input_height + h) * input_width + w]) {
+            stop = true;
+            maxIdx = (d * input_height + h) * input_width + w;
+          }
+        }
+      }
+    }
+    if (maxIdx != -1) {
+      // Overlapping windows may select the same input element, so the
+      // accumulation must be atomic.
+      atomicAdd(input_grad + maxIdx, output_grad[index]);
+    }
+  }
+}
+
 template <typename PoolProcess, typename T>
 class Pool3dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
  public:
@@ -411,6 +554,59 @@ class Pool3dBackwardFunctor {
   }
 };
 
+template <typename T>
+class MaxPool3dBackwardFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings) {
+    const int batch_size = input.dims()[0];
+    const int input_channels = input.dims()[1];
+    const int input_depth = input.dims()[2];
+    const int input_height = input.dims()[3];
+    const int input_width = input.dims()[4];
+    const int output_channels = output.dims()[1];
+    const int output_depth = output.dims()[2];
+    const int output_height = output.dims()[3];
+    const int output_width = output.dims()[4];
+    const int ksize_depth = ksize[0];
+    const int ksize_height = ksize[1];
+    const int ksize_width = ksize[2];
+    const int stride_depth = strides[0];
+    const int stride_height = strides[1];
+    const int stride_width = strides[2];
+    const int padding_depth = paddings[0];
+    const int padding_height = paddings[1];
+    const int padding_width = paddings[2];
+
+    const T* input_data = input.data<T>();
+    const T* output_data = output.data<T>();
+    const T* output_grad_data = output_grad.data<T>();
+    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
+
+    int nthreads = batch_size * output_channels * output_depth *
+                   output_height * output_width;
+    int blocks = (nthreads + 1024 - 1) / 1024;
+    dim3 threads(1024, 1);
+    dim3 grid(blocks, 1);
+
+    KernelMaxPool3DBackward<
+        T><<<grid, threads, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                 .stream()>>>(
+        nthreads, input_data, output_data, output_grad_data, input_grad_data,
+        input_channels, input_depth, input_height, input_width, output_depth,
+        output_height, output_width, ksize_depth, ksize_height, ksize_width,
+        stride_depth, stride_height, stride_width, padding_depth,
+        padding_height, padding_width);
+  }
+};
+
+template class MaxPool3dBackwardFunctor<platform::GPUPlace, float>;
+// template class MaxPool3dBackwardFunctor<platform::GPUPlace, double>;
+
 template class Pool3dForwardFunctor<
     platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
 template class Pool3dForwardFunctor<
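Both kernels accumulate with atomicAdd because overlapping windows (stride smaller than ksize) let two output threads match the same input element, and a plain read-modify-write could lose a contribution. This is also why only the float instantiations are enabled: hardware atomicAdd on double requires compute capability 6.0 or a compare-and-swap emulation. A stripped-down sketch of the accumulation pattern (toy kernel, not part of the patch):

    // Each thread i routes its output gradient to a precomputed argmax index.
    // Distinct threads may share the same max_index entry when pooling windows
    // overlap, so the accumulation into input_grad must be atomic.
    __global__ void ScatterMaxGrad(float* input_grad, const float* output_grad,
                                   const int* max_index, int nthreads) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < nthreads) {
        atomicAdd(input_grad + max_index[i], output_grad[i]);
      }
    }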
diff --git a/paddle/operators/math/pooling.h b/paddle/operators/math/pooling.h
index 627ece2ca4023070c63684f48b2546ca8a4f17bb..4b8b02d374c565a942dd78031c43df45492c7595 100644
--- a/paddle/operators/math/pooling.h
+++ b/paddle/operators/math/pooling.h
@@ -72,6 +72,16 @@ class Pool2dBackwardFunctor {
                   PoolProcess pool_process);
 };
 
+template <typename Place, typename T>
+class MaxPool2dBackwardFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings);
+};
+
 template <typename Place, typename PoolProcess, typename T>
 class Pool3dForwardFunctor {
  public:
@@ -92,6 +102,16 @@ class Pool3dBackwardFunctor {
                   PoolProcess pool_process);
 };
 
+template <typename Place, typename T>
+class MaxPool3dBackwardFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& input,
+                  framework::Tensor& input_grad,
+                  const framework::Tensor& output,
+                  const framework::Tensor& output_grad,
+                  std::vector<int>& ksize, std::vector<int>& strides,
+                  std::vector<int>& paddings);
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
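With the declarations above, a CPU-side call would look roughly as follows (a sketch only: tensor construction and initialization are elided, shapes follow the NCHW layout used throughout, and the default CPUDeviceContext constructor is assumed from the platform headers):

    // Assumes `input`, `output`, and `output_grad` are initialized NCHW
    // tensors and `input_grad` has `input`'s shape and is zero-filled.
    paddle::platform::CPUDeviceContext context;
    std::vector<int> ksize{3, 3}, strides{1, 1}, paddings{0, 0};
    paddle::operators::math::MaxPool2dBackwardFunctor<
        paddle::platform::CPUPlace, float>
        maxpool2d_backward;
    maxpool2d_backward(context, input, input_grad, output, output_grad, ksize,
                       strides, paddings);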