提交 e1e3859e 编写于 作者: C chengduoZH

remove custom attr checker and fix code format

上级 3c0f0793
......@@ -24,7 +24,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) {
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -54,14 +54,15 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_compute.initial();
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_compute.compute(ele, input_data[h * input_width + w]);
pool_process.compute(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, (static_cast<T>(pool_size)));
pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[ph * output_width + pw] = ele;
}
}
......@@ -80,7 +81,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) {
PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -115,11 +116,12 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_compute.compute(input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(scale));
pool_grad_process.compute(
input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(scale));
}
}
}
......@@ -198,21 +200,21 @@ template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>;
paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>;
paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>;
paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>;
paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
......@@ -220,7 +222,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) {
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -260,11 +262,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_compute.initial();
T ele = pool_process.initial();
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_compute.compute(
pool_process.compute(
ele,
input_data[(d * input_height + h) * input_width + w]);
}
......@@ -272,7 +274,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
}
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, static_cast<T>(pool_size));
pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[output_idx] = ele;
}
}
......@@ -292,7 +294,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) {
PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -343,7 +345,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
pool_compute.compute(
pool_grad_process.compute(
input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx],
input_grad_data[input_idx], static_cast<T>(scale));
......@@ -441,21 +443,21 @@ template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>;
paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>;
paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>;
paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>;
paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -20,14 +20,16 @@ namespace operators {
namespace math {
template <typename PoolProcess, typename T>
__global__ void KernelPool2dForward(
const int nthreads, const T* input_data, T* output_data, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_compute) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
__global__ void KernelPool2D(const int nthreads, const T* input_data,
T* output_data, const int channels,
const int input_height, const int input_width,
const int output_height, const int output_width,
const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width,
const int padding_height, const int padding_width,
PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
......@@ -42,28 +44,28 @@ __global__ void KernelPool2dForward(
wstart = max(wstart, 0);
input_data += (batch_idx * channels + c) * input_height * input_width;
T ele = pool_compute.initial();
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_compute.compute(ele, input_data[h * input_width + w]);
pool_process.compute(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, (static_cast<T>(pool_size)));
pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[index] = ele;
}
}
template <typename PoolProcess, typename T>
__global__ void KernelPool2dBackward(
__global__ void KernelPool2DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_compute) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
const int padding_width, PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
int offsetC = (index / input_width / input_height) % channels;
......@@ -93,7 +95,7 @@ __global__ void KernelPool2dBackward(
wstart = max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int output_sub_idx = ph * output_width + pw;
pool_compute.compute(input, output_data[output_sub_idx],
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx], gradient,
static_cast<T>(1.0 / pool_size));
}
......@@ -103,15 +105,15 @@ __global__ void KernelPool2dBackward(
}
template <typename T>
__global__ void KernelMaxPool2dBackward(
__global__ void KernelMaxPool2DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
......@@ -153,7 +155,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) {
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -176,7 +178,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2dForward<
KernelPool2D<
PoolProcess,
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......@@ -184,7 +186,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
input_height, input_width, output_height,
output_width, ksize_height, ksize_width,
stride_height, stride_width, padding_height,
padding_width, pool_compute);
padding_width, pool_process);
}
};
......@@ -196,7 +198,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) {
PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -220,7 +222,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool2dBackward<
KernelPool2DGrad<
PoolProcess,
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......@@ -228,7 +230,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
nthreads, input_data, output_data, output_grad_data, input_grad_data,
input_channels, input_height, input_width, output_height, output_width,
ksize_height, ksize_width, stride_height, stride_width, padding_height,
padding_width, pool_compute);
padding_width, pool_process);
}
};
......@@ -264,7 +266,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool2dBackward<
KernelMaxPool2DGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
......@@ -276,35 +278,37 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
};
template class MaxPool2dGradFunctor<platform::GPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::GPUPlace, double>;
// template class MaxPool2dGradFunctor<platform::GPUPlace, double>; // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class Pool2dFunctor<platform::GPUPlace,
paddle::operators::math::maxPool<float>, float>;
paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::GPUPlace,
paddle::operators::math::avgPool<float>, float>;
paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor<
platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor<
platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool2dFunctor<platform::GPUPlace,
paddle::operators::math::maxPool<double>, double>;
paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::GPUPlace,
paddle::operators::math::avgPool<double>, double>;
paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor<
platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor<
platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template <typename PoolProcess, typename T>
__global__ void KernelPool3DForward(
__global__ void KernelPool3D(
const int nthreads, const T* input_data, T* output_data, const int channels,
const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_compute) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
......@@ -321,25 +325,25 @@ __global__ void KernelPool3DForward(
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
T ele = pool_compute.initial();
T ele = pool_process.initial();
input_data +=
(batch_idx * channels + c) * input_depth * input_height * input_width;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_compute.compute(
pool_process.compute(
ele, input_data[(d * input_height + h) * input_width + w]);
}
}
}
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, static_cast<T>(pool_size));
pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[index] = ele;
}
}
template <typename PoolProcess, typename T>
__global__ void KernelPool3DBackward(
__global__ void KernelPool3DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_depth, const int input_height, const int input_width,
......@@ -347,8 +351,8 @@ __global__ void KernelPool3DBackward(
const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_compute) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
PoolProcess pool_process) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
int offsetH = (index / input_width) % input_height + padding_height;
......@@ -392,7 +396,7 @@ __global__ void KernelPool3DBackward(
wstart = max(wstart, 0);
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
int output_sub_idx = (pd * output_height + ph) * output_width + pw;
pool_compute.compute(input, output_data[output_sub_idx],
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx], gradient,
static_cast<T>(1.0 / pool_size));
}
......@@ -403,7 +407,7 @@ __global__ void KernelPool3DBackward(
}
template <typename T>
__global__ void KernelMaxPool3DBackward(
__global__ void KernelMaxPool3DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, T* input_grad, const int channels,
const int input_depth, const int input_height, const int input_width,
......@@ -412,7 +416,7 @@ __global__ void KernelMaxPool3DBackward(
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height,
const int padding_width) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
......@@ -460,7 +464,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) {
std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
......@@ -489,7 +493,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DForward<
KernelPool3D<
PoolProcess,
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......@@ -498,7 +502,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
input_height, input_width, output_depth, output_height, output_width,
ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
stride_width, padding_depth, padding_height, padding_width,
pool_compute);
pool_process);
}
};
......@@ -510,7 +514,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) {
PoolProcess pool_process) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
......@@ -541,7 +545,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DBackward<
KernelPool3DGrad<
PoolProcess,
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......@@ -550,7 +554,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
input_channels, input_depth, input_height, input_width, output_depth,
output_height, output_width, ksize_depth, ksize_height, ksize_width,
stride_depth, stride_height, stride_width, padding_depth,
padding_height, padding_width, pool_compute);
padding_height, padding_width, pool_process);
}
};
......@@ -592,7 +596,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelMaxPool3DBackward<
KernelMaxPool3DGrad<
T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
......@@ -605,24 +609,26 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
};
template class MaxPool3dGradFunctor<platform::GPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::GPUPlace, double>;
// template class MaxPool3dGradFunctor<platform::GPUPlace, double>; // The
// 64-bit floating-point version of atomicAdd() is only supported by devices of
// compute capability 6.x and higher.
template class Pool3dFunctor<platform::GPUPlace,
paddle::operators::math::maxPool<float>, float>;
paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::GPUPlace,
paddle::operators::math::avgPool<float>, float>;
paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor<
platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor<
platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool3dFunctor<platform::GPUPlace,
paddle::operators::math::maxPool<double>, double>;
paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::GPUPlace,
paddle::operators::math::avgPool<double>, double>;
paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor<
platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
} // namespace math
} // namespace operators
......
......@@ -22,11 +22,10 @@ namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__
/////////////////////
#define FLT_MAX __FLT_MAX__ //
template <class T>
class maxPool {
class MaxPool {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
......@@ -34,14 +33,14 @@ class maxPool {
};
template <class T>
class avgPool {
class AvgPool {
public:
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
};
template <class T>
class maxPoolGrad {
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
......@@ -50,7 +49,7 @@ class maxPoolGrad {
};
template <class T>
class avgPoolGrad {
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
......
......@@ -51,7 +51,7 @@ class PoolOp : public framework::OperatorWithKernel {
ksize[i] = static_cast<int>(in_x_dims[i + 2]);
}
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2,
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and Pooling size should be consistent.");
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
"Pooling size should be 2 elements. or 3 elements.");
......@@ -79,7 +79,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input@Grad of Pooling should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
......@@ -98,66 +97,36 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"The format of output tensor is also NCHW.");
AddAttr<std::string>("poolingType",
"poolingType of pooling operator."
"str constant equal to 'max' or 'avg'");
"PoolingType of pooling operator."
"Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be specified.");
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator."
"default {1,1}")
.SetDefault({1, 1})
.AddCustomChecker(GreaterThanChecker_pool({0, 0}));
"Strides(height, width) of pooling operator."
"Default {1,1}")
.SetDefault({1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator."
"default {0,0}")
.SetDefault({0, 0})
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0}));
"Paddings(height, width) of pooling operator."
"Default {0,0}.")
.SetDefault({0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
)DOC");
}
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
};
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
......@@ -173,67 +142,36 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"The format of output tensor is also NCDHW.");
AddAttr<std::string>("poolingType",
"poolingType of pooling operator."
"str constant equal to 'max' or 'avg'");
"PoolingType of pooling operator."
"str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>(
"ksize",
"pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be specified.");
"Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>(
"globalPooling",
"whether to use the globalPooling."
"int constant equal to false or true"
"default false"
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false);
AddAttr<std::vector<int>>(
"strides",
"strides(depth, height, width) of pooling operator."
"default {1,1,1}")
.SetDefault({1, 1, 1})
.AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}.")
.SetDefault({1, 1, 1}); // TODO(Add checker)
AddAttr<std::vector<int>>(
"paddings",
"paddings(depth, height, width) of pooling operator."
"default {0,0,0}")
.SetDefault({0, 0, 0})
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0}));
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}.")
.SetDefault({0, 0, 0}); // TODO(Add checker)
AddComment(R"DOC(
The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters.
)DOC");
}
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
};
} // namespace operators
} // namespace paddle
......
......@@ -31,12 +31,11 @@ class PoolKernel : public framework::OpKernel {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
bool global_pooling = context.Attr<bool>("globalPooling");
std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) {
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
......@@ -46,17 +45,17 @@ class PoolKernel : public framework::OpKernel {
case 2: {
if (pooling_type == "max") {
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::maxPool<T>, T>
Place, paddle::operators::math::MaxPool<T>, T>
pool2d_forward;
paddle::operators::math::maxPool<T> pool_process;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::avgPool<T>, T>
Place, paddle::operators::math::AvgPool<T>, T>
pool2d_forward;
paddle::operators::math::avgPool<T> pool_process;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
......@@ -64,16 +63,16 @@ class PoolKernel : public framework::OpKernel {
case 3: {
if (pooling_type == "max") {
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::maxPool<T>, T>
Place, paddle::operators::math::MaxPool<T>, T>
pool3d_forward;
paddle::operators::math::maxPool<T> pool_process;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::avgPool<T>, T>
Place, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::avgPool<T> pool_process;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
......@@ -92,13 +91,12 @@ class PoolGradKernel : public framework::OpKernel {
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
bool global_pooling = context.Attr<bool>("globalPooling");
std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) {
if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
......@@ -118,9 +116,9 @@ class PoolGradKernel : public framework::OpKernel {
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T>
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward;
paddle::operators::math::avgPoolGrad<T> pool_process;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
......@@ -133,9 +131,9 @@ class PoolGradKernel : public framework::OpKernel {
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T>
Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward;
paddle::operators::math::avgPoolGrad<T> pool_process;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册