提交 9e7c0b5e 编写于 作者: C chengduoZH

Add pooling2d(max, ave) and pooling3d(max, ave) functor

上级 59c48f98
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
namespace math {
template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context->GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[ph * output_width + pw] = ele;
}
}
input_data += input_stride;
output_data += output_stride;
}
}
}
};
template <typename PoolProcess, class T>
class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_stride = input_height * input_width;
const int output_stride = output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context->GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.gradProcess(
input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(pool_size));
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
};
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context->GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_process.initial();
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(
ele,
input_data[(d * input_height + h) * input_width + w]);
}
}
}
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[output_idx] = ele;
}
}
}
input_data += input_stride;
output_data += output_stride;
}
}
}
};
template <typename PoolProcess, class T>
class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const int input_stride = input_depth * input_height * input_width;
const int output_stride = output_depth * output_height * output_width;
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context->GetPlace());
for (int i = 0; i < batch_size; i++) {
for (int c = 0; c < output_channels; ++c) {
for (int pd = 0; pd < output_depth; ++pd) {
int dstart = pd * stride_depth - padding_depth;
int dend = std::min(dstart + ksize_depth, input_depth);
dstart = std::max(dstart, 0);
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
hstart = std::max(hstart, 0);
for (int pw = 0; pw < output_width; ++pw) {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
pool_process.gradProcess(
input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx],
input_grad_data[input_idx], static_cast<T>(pool_size));
}
}
}
}
}
input_data += input_stride;
output_data += output_stride;
input_grad_data += input_stride;
output_grad_data += output_stride;
}
}
}
}
};
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 paddlepaddle Authors. All Rights
Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
namespace paddle {
namespace operators {
namespace math {
template <typename PoolProcess, typename T>
__global__ void KernelPool2dForward(
const int nthreads, const T* input_data, T* output_data, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
int batch_idx = index / output_width / output_height / channels;
int hstart = ph * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
int wstart = pw * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[index] = ele;
}
}
template <typename PoolProcess, typename T>
__global__ void KernelPool2dBackward(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetC = (index / input_width / input_height) % channels;
int batch_idx = index / input_width / input_height / channels;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int phend = min(offsetH / stride_height + 1, output_height);
int pwend = min(offsetW / stride_width + 1, output_width);
T gradient = 0;
T input = input_data[index];
int output_idx =
(batch_idx * channels + offsetC) * output_height * output_width;
output_data += output_idx;
output_grad += output_idx;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int output_sub_idx = ph * output_width + pw;
pool_process.gradProcess(input, output_data[output_sub_idx],
output_grad[output_sub_idx], gradient,
static_cast<T>(pool_size));
}
}
input_grad[index] = gradient;
}
}
template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_channels = output.dims()[1];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context->GetPlace());
int nthreads = batch_size * output_channels * output_height * output_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2dForward<PoolProcess, T><<<grid, threads, 0, 0>>>(
nthreads, input_data, output_data, input_channels, input_height,
input_width, output_height, output_width, ksize_height, ksize_width,
stride_height, stride_width, padding_height, padding_width,
pool_process);
// CHECK_SYNC("Pool2dForwardKernel failed");
}
};
template <typename PoolProcess, typename T>
class Pool2dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context->GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2dBackward<PoolProcess, T><<<grid, threads, 0, 0>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, output_height, output_width,
ksize_height, ksize_width, stride_height, stride_width, padding_height,
padding_width, pool_process);
// CHECK_SYNC("KernelPool2dBackward failed");
}
};
template class Pool2dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template class Pool2dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template <typename PoolProcess, typename T>
__global__ void KernelPool3DForward(
const int nthreads, const T* input_data, T* output_data, const int channels,
const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int pd = (index / output_width / output_height) % output_depth;
int c = (index / output_width / output_height / output_depth) % channels;
int batch_idx =
index / output_width / output_height / output_depth / channels;
int dstart = pd * stride_depth - padding_depth;
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int dend = min(dstart + ksize_depth, input_depth);
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
T ele = pool_process.initial();
input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(
ele, input_data[(d * input_height + h) * input_width + w]);
}
}
}
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[index] = ele;
}
}
template <typename PoolProcess, typename T>
__global__ void KernelPool3DBackward(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetD =
(index / input_width / input_height) % input_depth + padding_depth;
int offsetC = (index / input_width / input_height / input_depth) % channels;
int batch_idx = index / input_width / input_height / input_depth / channels;
int pdstart = (offsetD < ksize_depth)
? 0
: (offsetD + ksize_depth) / stride_depth + 1;
int phstart = (offsetH < ksize_height)
? 0
: (offsetH - ksize_height) / stride_height + 1;
int pwstart = (offsetW < ksize_width)
? 0
: (offsetW - ksize_width) / stride_width + 1;
int pdend = min((offsetD) / stride_depth + 1, output_depth);
int phend = min((offsetH) / stride_height + 1, output_height);
int pwend = min((offsetW) / stride_width + 1, output_width);
T gradient = 0;
T input = input_data[index];
int output_idx = (batch_idx * channels + offsetC) * output_depth *
output_height * output_width;
output_data += output_idx;
output_grad += output_idx;
for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
// figure out the pooling size
int dstart = pd * stride_depth - padding_depth;
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int dend = min(dstart + ksize_depth, input_depth);
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
int output_sub_idx = ph * output_width + pw;
pool_process.gradProcess(input, output_data[output_sub_idx],
output_grad[output_sub_idx], gradient,
static_cast<T>(pool_size));
}
}
}
input_grad[index] = gradient;
}
}
template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::GPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const T* input_data = input.data<T>();
T* output_data = output.mutable_data<T>(context->GetPlace());
int nthreads = batch_size * output_channels * output_depth * output_height *
output_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DForward<PoolProcess, T><<<grid, threads, 0, 0>>>(
nthreads, input_data, output_data, input_channels, input_depth,
input_height, input_width, output_depth, output_height, output_width,
ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
stride_width, padding_depth, padding_height, padding_width,
pool_process);
// CHECK_SYNC("Pool2dForwardKernel failed");
}
};
template <typename PoolProcess, class T>
class Pool3dBackwardFunctor<platform::GPUPlace, PoolProcess, T> {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
const int input_width = input.dims()[4];
const int output_channels = output.dims()[1];
const int output_depth = output.dims()[2];
const int output_height = output.dims()[3];
const int output_width = output.dims()[4];
const int ksize_depth = ksize[0];
const int ksize_height = ksize[1];
const int ksize_width = ksize[2];
const int stride_depth = strides[0];
const int stride_height = strides[1];
const int stride_width = strides[2];
const int padding_depth = paddings[0];
const int padding_height = paddings[1];
const int padding_width = paddings[2];
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad.mutable_data<T>(context->GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DBackward<PoolProcess, T><<<grid, threads, 0, 0>>>(
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_depth, input_height, input_width, output_depth,
output_height, output_width, ksize_depth, ksize_height, ksize_width,
stride_depth, stride_height, stride_width, padding_depth,
padding_height, padding_width, pool_process);
// CHECK_SYNC("KernelPool2dBackward failed");
}
};
template class Pool3dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dForwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<double>, double>;
template class Pool3dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dBackwardFunctor<
platform::GPUPlace, paddle::operators::math::pool::avePool<double>, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#ifdef __NVCC__
#define HL_DEVICE __device__
#else
#define HL_DEVICE
#endif
#define FLT_MAX __FLT_MAX__
/////////////////////
namespace pool {
template <class T>
class maxPool {
public:
HL_DEVICE inline T initial() { return -(T)(FLT_MAX); }
HL_DEVICE inline void process(T& y, const T& x) { y = y > x ? y : x; }
HL_DEVICE inline void finalize(T& y, const T& poo_size) {}
HL_DEVICE inline void gradProcess(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += dy * (x == y);
}
};
template <class T>
class avePool {
public:
HL_DEVICE inline T initial() { return 0; }
HL_DEVICE inline void process(T& y, const T& x) { y += x; }
HL_DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
HL_DEVICE inline void gradProcess(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += (scale * dy);
}
};
} // namespace pool
template <typename Place, typename PoolProcess, typename T>
class Pool2dForwardFunctor {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context);
};
template <typename Place, typename PoolProcess, typename T>
class Pool2dBackwardFunctor {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dForwardFunctor {
public:
void operator()(const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process,
platform::DeviceContext* context);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dBackwardFunctor {
public:
void operator()(const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process, platform::DeviceContext* context);
};
} // namespace math
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册