Unverified commit 86685190, authored by limingshu, committed by GitHub

Optimization of pool2d grad (#35389)

* Optimization of pool2d grad, first commit.

* remove useless print codes

* refine codes

* refine codes

* seal more operations into template specialization

* fix template struct error in MaxPool2dGrad.

* Fix header include error

* refine code with comment

* Seal the param-preparation codes into function for common use.

* Seal the param-preparation codes into function for common use.

* Seal the param-preparation into a function and make it common for other kernels

* polish code and erase useless template specialization

* Rerun trigger

* rerun trigger
Parent 9f88d327
@@ -16,65 +16,140 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#ifdef __HIPCC__
#define POOLING_BLOCK_SIZE 256
#else
#define POOLING_BLOCK_SIZE 512
#endif
namespace paddle {
namespace operators {
namespace math {
struct FastDivModForPooling {
public:
platform::FastDivMod channel;
platform::FastDivMod width;
platform::FastDivMod height;
explicit HOSTDEVICE FastDivModForPooling(const int channels,
const int output_width,
const int output_height) {
channel = platform::FastDivMod(channels);
width = platform::FastDivMod(output_width);
height = platform::FastDivMod(output_height);
}
};
struct FastDivModForPoolingWithMoreStaff {
public:
platform::FastDivMod channel;
platform::FastDivMod width;
platform::FastDivMod height;
platform::FastDivMod ksize_w;
platform::FastDivMod ksize_h;
platform::FastDivMod stride_w;
platform::FastDivMod stride_h;
explicit HOSTDEVICE FastDivModForPoolingWithMoreStaff(
const int channels, const int input_width, const int input_height,
const int ksize_width, const int ksize_height, const int stride_width,
const int stride_height) {
channel = platform::FastDivMod(channels);
width = platform::FastDivMod(input_width);
height = platform::FastDivMod(input_height);
ksize_w = platform::FastDivMod(ksize_width);
ksize_h = platform::FastDivMod(ksize_height);
stride_w = platform::FastDivMod(stride_width);
stride_h = platform::FastDivMod(stride_height);
}
};
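// How these fast divmods are consumed by the kernels below: a minimal sketch,
// assuming illustrative sizes (channels = 64, output 56 x 56). As used
// throughout this file, Div() returns the quotient, Divmod() packs
// {quotient, remainder} into val[0]/val[1], and divisor keeps the original
// divisor for re-multiplication.
//
//   FastDivModForPooling divmods(/*channels=*/64, /*output_width=*/56,
//                                /*output_height=*/56);
//   auto w = divmods.width.Divmod(index);     // w.val[0] = index / 56
//                                             // w.val[1] = index % 56
//   int c = divmods.channel.Div(index);       // index / 64
//   int channels = divmods.channel.divisor;   // 64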
template <typename FastDivModForPooling>
__device__ void OffsetPreparationFor4Dimension(
int index, bool channel_last, FastDivModForPooling divmods,
const int pad_width, const int pad_height, const int aux_width,
const int aux_height, int* w_offset, int* h_offset, int* c_offset,
int* stride) {
if (!channel_last) { /* NCHW */
auto input_width_divmod = divmods.width.Divmod(index);
auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]);
auto channel_divmod = divmods.channel.Divmod(input_height_divmod.val[0]);
*w_offset = input_width_divmod.val[1] + pad_width;
*h_offset = input_height_divmod.val[1] + pad_height;
*c_offset = channel_divmod.val[1];
*stride = (channel_divmod.val[0] * divmods.channel.divisor + *c_offset) *
aux_height * aux_width;
} else { /* NHWC */
auto c_divmod = divmods.channel.Divmod(index);
auto input_width_divmod = divmods.width.Divmod(c_divmod.val[0]);
auto input_height_divmod = divmods.height.Divmod(input_width_divmod.val[0]);
*c_offset = c_divmod.val[1];
*w_offset = input_width_divmod.val[1] + pad_width;
*h_offset = input_height_divmod.val[1] + pad_height;
*stride = input_height_divmod.val[0] * aux_height * aux_width *
divmods.channel.divisor;
}
}
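// Plain-division reference for the NCHW branch of
// OffsetPreparationFor4Dimension: a minimal host-side sketch (the helper name
// is illustrative, not part of this commit) that mirrors the index math the
// old kernels computed with `%` and `/`, useful for checking the FastDivMod
// version above.
inline void OffsetReferenceNCHW(int index, int channels, int width, int height,
                                int pad_width, int pad_height, int aux_width,
                                int aux_height, int* w_offset, int* h_offset,
                                int* c_offset, int* stride) {
  *w_offset = index % width + pad_width;
  *h_offset = (index / width) % height + pad_height;
  *c_offset = (index / (width * height)) % channels;
  int batch_idx = index / (width * height * channels);
  // Offset of the (batch, channel) plane in the auxiliary (aux_*) tensor.
  *stride = (batch_idx * channels + *c_offset) * aux_height * aux_width;
}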
int GetThreadsPerBlock(const platform::CUDADeviceContext& ctx,
int threads_per_block, int64_t numel) {
int sm_count = ctx.GetSMCount();
if (numel / (sm_count << 1) < threads_per_block) {
// Round the thread count up to a power of 2, so that the number of
// active blocks is about twice the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 1));
} else if (numel / (sm_count << 2) < threads_per_block) {
// Round the thread count up to a power of 2, so that the number of
// active blocks is about four times the SM count, for better performance.
threads_per_block = platform::RoundToPowerOfTwo(numel / (sm_count << 2));
}
// Use at least 64 threads per block.
return std::max(64, threads_per_block);
}
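// Worked example of the launch configuration used by the grad functors below,
// assuming RoundToPowerOfTwo rounds up to the next power of two and a device
// with 80 SMs processing nthreads = 40960 elements:
//   nthreads / (80 << 1) = 256 < POOLING_BLOCK_SIZE, so the block size is
//   rounded to 256, and
//     int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
//     int grids = (nthreads + blocks - 1) / blocks;   // = 160, ~2 blocks/SM
//   launches KernelPool2DGrad<<<grids, blocks, 0, context.stream()>>>(...).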
template <typename PoolProcess, typename T>
__global__ void KernelPool2D(const int nthreads, const T* input_data,
const int channels, const int input_height,
const int input_width, const int output_height,
const int output_width, const int ksize_height,
const int ksize_width, const int stride_height,
const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process,
bool exclusive, bool adaptive, T* output_data,
__global__ void KernelPool2D(
const int nthreads, const T* input_data, const int channels,
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, FastDivModForPooling divmods,
PoolProcess pool_process, bool exclusive, bool adaptive, T* output_data,
bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw, ph, c, batch_idx;
if (!channel_last) { /*NCHW*/
pw = index % output_width;
ph = (index / output_width) % output_height;
c = (index / output_width / output_height) % channels;
batch_idx = index / output_width / output_height / channels;
} else { /*NHWC*/
c = index % channels;
pw = (index / channels) % output_width;
ph = (index / channels / output_width) % output_height;
batch_idx = index / channels / output_width / output_height;
}
int hstart, hend, wstart, wend;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, channel_last, divmods, 0, 0, input_width, input_height,
&w_offset, &h_offset, &c_offset, &input_offset);
input_data += input_offset;
int hstart, hend;
int wstart, wend;
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
} else {
hstart = ph * stride_height - padding_height;
hstart = h_offset * stride_height - padding_height;
hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
wstart = pw * stride_width - padding_width;
wstart = w_offset * stride_width - padding_width;
wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
if (!channel_last) {
input_data += (batch_idx * channels + c) * input_height * input_width;
} else {
input_data += batch_idx * input_height * input_width * channels;
}
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
auto input_idx = channel_last ? (h * input_width + w) * channels + c
auto input_idx = channel_last
? (h * input_width + w) * channels + c_offset
: h * input_width + w;
pool_process.compute(input_data[input_idx], &ele);
}
@@ -85,91 +160,109 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
output_data[index] = ele;
}
}
template <typename PoolProcess, typename T>
template <typename T, typename PoolProcess>
__global__ void KernelPool2DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, const int channels, const int input_height,
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
PoolProcess pool_process, bool exclusive, bool adaptive, T* input_grad,
bool channel_last = false) {
const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ output_data, const T* __restrict__ output_grad,
const int output_width, const int output_height, const int input_width,
const int input_height, const int ksize_width, const int ksize_height,
const int stride_width, const int stride_height, const int padding_width,
const int padding_height, FastDivModForPoolingWithMoreStaff divmods,
PoolProcess pool_process, bool exclusive, bool adaptive,
T* __restrict__ input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset, h_offset, offsetC, batch_idx;
if (!channel_last) { /* NCHW */
w_offset = index % input_width + padding_width;
h_offset = (index / input_width) % input_height + padding_height;
offsetC = (index / input_width / input_height) % channels;
batch_idx = index / input_width / input_height / channels;
} else { /* NHWC */
offsetC = index % channels;
w_offset = (index / channels) % input_width + padding_width;
h_offset =
(index / channels / input_width) % input_height + padding_height;
batch_idx = index / channels / input_width / input_height;
}
T input = static_cast<T>(0);
T input_grad_data = static_cast<T>(0);
int phstart, phend, pwstart, pwend;
int w_offset, h_offset, c_offset, output_offset;
OffsetPreparationFor4Dimension<>(index, channel_last, divmods,
padding_width, padding_height,
output_width, output_height, &w_offset,
&h_offset, &c_offset, &output_offset);
if (pool_process.use_x) {
input = input_data[index];
output_data += output_offset;
}
output_grad += output_offset;
int phstart, phend;
int pwstart, pwend;
if (adaptive) {
phstart = AdaptStartIndex(h_offset, output_height, input_height);
phend = AdaptEndIndex(h_offset, output_height, input_height);
auto tmp_phend = divmods.height.Divmod((h_offset + 1) * output_height);
auto tmp_pwend = divmods.width.Divmod((w_offset + 1) * output_width);
phstart = divmods.height.Div(h_offset * output_height);
pwstart = divmods.width.Div(w_offset * output_width);
phend = tmp_phend.val[1] > 0 ? tmp_phend.val[0] + 1 : tmp_phend.val[0];
pwend = tmp_pwend.val[1] > 0 ? tmp_pwend.val[0] + 1 : tmp_pwend.val[0];
pwstart = AdaptStartIndex(w_offset, output_width, input_width);
pwend = AdaptEndIndex(w_offset, output_width, input_width);
} else {
phstart = (h_offset < ksize_height)
? 0
: (h_offset - ksize_height) / stride_height + 1;
pwstart = (w_offset < ksize_width)
? 0
: (w_offset - ksize_width) / stride_width + 1;
phend = min(h_offset / stride_height + 1, output_height);
pwend = min(w_offset / stride_width + 1, output_width);
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
auto ksize_w_divmod = divmods.ksize_w.Divmod(input_width);
auto ksize_h_divmod = divmods.ksize_h.Divmod(input_height);
auto tmp_width = ksize_w_divmod.val[1] > 0 ? ksize_w_divmod.val[0] + 1
: ksize_w_divmod.val[0];
auto tmp_height = ksize_h_divmod.val[1] > 0
? ksize_h_divmod.val[0] + 1
: ksize_h_divmod.val[0];
int pool_size = tmp_height * tmp_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
T ouput_value = pool_process.use_x ? output_data[output_sub_idx]
: static_cast<T>(0);
pool_process.compute(input, ouput_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size),
&input_grad_data);
}
T gradient = static_cast<T>(0.0);
T input = input_data[index];
int output_stride;
if (!channel_last) {
output_stride =
(batch_idx * channels + offsetC) * output_height * output_width;
} else {
output_stride = batch_idx * output_height * output_width * channels;
}
output_data += output_stride;
output_grad += output_stride;
} else {
auto stride_height_div = divmods.stride_h.Div(h_offset - ksize_height);
auto stride_width_div = divmods.stride_w.Div(w_offset - ksize_width);
phstart = (h_offset < ksize_height) ? 0 : stride_height_div + 1;
pwstart = (w_offset < ksize_width) ? 0 : stride_width_div + 1;
phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);
if (exclusive) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size;
if (adaptive) {
pool_size = static_cast<int>(ceil(static_cast<double>(input_height) /
ksize_height)) *
static_cast<int>(
ceil(static_cast<double>(input_width) / ksize_width));
} else {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int pool_size = (hend - hstart) * (wend - wstart);
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
T ouput_value = pool_process.use_x ? output_data[output_sub_idx]
: static_cast<T>(0);
pool_process.compute(
input, ouput_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &input_grad_data);
}
}
} else {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size = ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + c_offset
: tmp_idx;
T ouput_value = pool_process.use_x ? output_data[output_sub_idx]
: static_cast<T>(0);
pool_process.compute(
input, ouput_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &input_grad_data);
}
int output_sub_idx = channel_last
? (ph * output_width + pw) * channels + offsetC
: ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);
}
}
input_grad[index] = gradient;
}
input_grad[index] = input_grad_data;
}
}
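// Plain-division reference for the adaptive branch of KernelPool2DGrad: a
// minimal sketch (helper names are illustrative, not part of this commit).
// Here divmods.height/width wrap input_height/input_width and
// divmods.ksize_h/ksize_w wrap the kernel sizes, so the FastDivMod code above
// is equivalent to:
HOSTDEVICE inline int AdaptiveGradStart(int offset, int output_size,
                                        int input_size) {
  return (offset * output_size) / input_size;  // floor division
}
HOSTDEVICE inline int AdaptiveGradEnd(int offset, int output_size,
                                      int input_size) {
  return ((offset + 1) * output_size + input_size - 1) / input_size;  // ceil
}
HOSTDEVICE inline int AdaptivePoolSize(int input_height, int input_width,
                                       int ksize_height, int ksize_width) {
  // pool_size = ceil(H / kH) * ceil(W / kW), as computed with the ksize divmods.
  return ((input_height + ksize_height - 1) / ksize_height) *
         ((input_width + ksize_width - 1) / ksize_width);
}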
@@ -180,44 +273,31 @@ __global__ void KernelMaxPool2DGrad(
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
T* input_grad, bool channel_last = false) {
T* input_grad, FastDivModForPooling divmods, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw, ph, c, batch_idx;
if (!channel_last) { /* NCHW */
pw = index % output_width;
ph = (index / output_width) % output_height;
c = (index / output_width / output_height) % channels;
batch_idx = index / output_width / output_height / channels;
} else { /* NHWC */
c = index % channels;
pw = (index / channels) % output_width;
ph = (index / channels / output_width) % output_height;
batch_idx = index / channels / output_width / output_height;
}
int hstart = ph * stride_height - padding_height;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, channel_last, divmods, 0, 0, input_width, input_height,
&w_offset, &h_offset, &c_offset, &input_offset);
input_data += input_offset;
input_grad += input_offset;
int hstart = h_offset * stride_height - padding_height;
int hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
int wstart = pw * stride_width - padding_width;
int wstart = w_offset * stride_width - padding_width;
int wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
int input_stride;
if (!channel_last) {
input_stride = (batch_idx * channels + c) * input_height * input_width;
} else {
input_stride = batch_idx * input_height * input_width * channels;
}
input_data += input_stride;
input_grad += input_stride;
T ele = output_data[index];
int maxIndex = -1;
bool stop = false;
for (int h = hstart; h < hend && !stop; ++h) {
for (int w = wstart; w < wend && !stop; ++w) {
int input_data_idx = channel_last ? (h * input_width + w) * channels + c
int input_data_idx = channel_last
? (h * input_width + w) * channels + c_offset
: h * input_width + w;
if (ele == input_data[input_data_idx]) {
maxIndex = input_data_idx;
@@ -264,10 +344,13 @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
nthreads, input, input_channels, input_height, input_width, output_height,
output_width, ksize_height, ksize_width, stride_height, stride_width,
padding_height, padding_width, pool_compute, exclusive, adaptive, output);
padding_height, padding_width, pool_divmods, pool_compute, exclusive,
adaptive, output);
}
/*
@@ -311,11 +394,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, exclusive,
adaptive, output_data);
stride_width, padding_height, padding_width, pool_divmods, pool_process,
exclusive, adaptive, output_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const std::vector<int>& ksize,
@@ -357,11 +443,14 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, exclusive,
adaptive, output_data, channel_last);
stride_width, padding_height, padding_width, pool_divmods, pool_process,
exclusive, adaptive, output_data, channel_last);
}
};
/*
@@ -402,15 +491,18 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;
KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data);
auto pool_divmods = FastDivModForPoolingWithMoreStaff(
input_channels, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height);
KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_width,
output_height, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, padding_width, padding_height,
pool_divmods, pool_process, exclusive, adaptive, input_grad_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
@@ -424,7 +516,6 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
bool channel_last = (data_format == "NHWC");
const int batch_size = input.dims()[0];
const int input_channels = channel_last ? input.dims()[3] : input.dims()[1];
const int input_height = channel_last ? input.dims()[1] : input.dims()[2];
const int input_width = channel_last ? input.dims()[2] : input.dims()[3];
@@ -447,19 +538,22 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const T* input_data = input.data<T>();
const T* output_data = output.data<T>();
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * input_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
int blocks = GetThreadsPerBlock(context, POOLING_BLOCK_SIZE, nthreads);
int grids = (nthreads + blocks - 1) / blocks;
KernelPool2DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, exclusive, adaptive, input_grad_data, channel_last);
auto pool_divmods = FastDivModForPoolingWithMoreStaff(
input_channels, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height);
KernelPool2DGrad<T, PoolProcess><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_width,
output_height, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, padding_width, padding_height,
pool_divmods, pool_process, exclusive, adaptive, input_grad_data,
channel_last);
}
};
@@ -505,11 +599,13 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
input_grad_data);
input_grad_data, pool_divmods);
}
void operator()(
const platform::CUDADeviceContext& context,
@@ -550,11 +646,14 @@ class MaxPool2dGradFunctor<platform::CUDADeviceContext, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2DGrad<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
input_grad_data, channel_last);
input_grad_data, pool_divmods, channel_last);
}
};
@@ -689,35 +788,40 @@ __global__ void KernelPool3D(
}
}
template <typename PoolProcess, typename T>
template <typename T, typename PoolProcess>
__global__ void KernelPool3DGrad(
const int nthreads, const T* input_data, const T* output_data,
const T* output_grad, const int channels, const int input_depth,
const int input_height, const int input_width, const int output_depth,
const int output_height, const int output_width, const int ksize_depth,
const int ksize_height, const int ksize_width, const int stride_depth,
const int stride_height, const int stride_width, const int padding_depth,
const int padding_height, const int padding_width, PoolProcess pool_process,
bool exclusive, bool adaptive, T* input_grad, bool channel_last = false) {
const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ output_data, const T* __restrict__ output_grad,
const int channels, const int input_depth, const int input_height,
const int input_width, const int output_depth, const int output_height,
const int output_width, const int ksize_depth, const int ksize_height,
const int ksize_width, const int stride_depth, const int stride_height,
const int stride_width, const int padding_depth, const int padding_height,
const int padding_width, PoolProcess pool_process, bool exclusive,
bool adaptive, T* input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset, h_offset, d_offset, offsetC, batch_idx;
int w_offset, h_offset, d_offset, c_offset, batch_idx, output_stride;
T input = static_cast<T>(0);
if (!channel_last) { /* "NCDHW" */
w_offset = index % input_width + padding_width;
h_offset = (index / input_width) % input_height + padding_height;
d_offset =
(index / input_width / input_height) % input_depth + padding_depth;
offsetC = (index / input_width / input_height / input_depth) % channels;
c_offset = (index / input_width / input_height / input_depth) % channels;
batch_idx = index / input_width / input_height / input_depth / channels;
output_stride = (batch_idx * channels + c_offset) * output_depth *
output_height * output_width;
} else { /* "NDHWC" */
offsetC = index % channels;
c_offset = index % channels;
w_offset = (index / channels) % input_width + padding_width;
h_offset =
(index / channels / input_width) % input_height + padding_height;
d_offset = (index / channels / input_width / input_height) % input_depth +
padding_depth;
batch_idx = index / channels / input_width / input_height / input_depth;
output_stride =
batch_idx * output_depth * output_height * output_width * channels;
}
int pdstart, pdend;
@@ -746,20 +850,12 @@ __global__ void KernelPool3DGrad(
phend = min((h_offset) / stride_height + 1, output_height);
pwend = min((w_offset) / stride_width + 1, output_width);
}
T gradient = static_cast<T>(0.0);
T input = input_data[index];
int output_stride;
if (!channel_last) {
output_stride = (batch_idx * channels + offsetC) * output_depth *
output_height * output_width;
} else {
output_stride =
batch_idx * output_depth * output_height * output_width * channels;
}
if (pool_process.use_x) {
input = input_data[index];
output_data += output_stride;
}
output_grad += output_stride;
T input_grad_data = static_cast<T>(0.0);
for (int pd = pdstart; pd < pdend; ++pd) {
for (int ph = phstart; ph < phend; ++ph) {
@@ -792,16 +888,17 @@ __global__ void KernelPool3DGrad(
int output_sub_idx =
channel_last
? ((pd * output_height + ph) * output_width + pw) * channels +
offsetC
c_offset
: (pd * output_height + ph) * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size), &gradient);
T ouput_value = pool_process.use_x ? output_data[output_sub_idx]
: static_cast<T>(0);
pool_process.compute(input, ouput_value, output_grad[output_sub_idx],
static_cast<T>(1.0 / pool_size),
&input_grad_data);
}
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
@@ -1045,7 +1142,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
@@ -1099,7 +1196,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelPool3DGrad<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
KernelPool3DGrad<T, PoolProcess><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, input_channels,
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
@@ -1267,33 +1364,33 @@ __global__ void KernelMaxPool2dWithIdx(
const int input_height, const int input_width, const int output_height,
const int output_width, const int ksize_height, const int ksize_width,
const int stride_height, const int stride_width, const int padding_height,
const int padding_width, bool adaptive, T1* output_data, T2* mask_data) {
const int padding_width, bool adaptive, T1* output_data, T2* mask_data,
FastDivModForPooling divmods) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
int ph = (index / output_width) % output_height;
int c = (index / output_width / output_height) % channels;
int batch_idx = index / output_width / output_height / channels;
int hstart, hend, wstart, wend;
int w_offset, h_offset, c_offset, input_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, false, divmods, 0, 0, input_width, input_height, &w_offset,
&h_offset, &c_offset, &input_offset);
input_data += input_offset;
int hstart, hend;
int wstart, wend;
if (adaptive) {
hstart = AdaptStartIndex(ph, input_height, output_height);
hend = AdaptEndIndex(ph, input_height, output_height);
hstart = AdaptStartIndex(h_offset, input_height, output_height);
hend = AdaptEndIndex(h_offset, input_height, output_height);
wstart = AdaptStartIndex(pw, input_width, output_width);
wend = AdaptEndIndex(pw, input_width, output_width);
wstart = AdaptStartIndex(w_offset, input_width, output_width);
wend = AdaptEndIndex(w_offset, input_width, output_width);
} else {
hstart = ph * stride_height - padding_height;
hstart = h_offset * stride_height - padding_height;
hend = min(hstart + ksize_height, input_height);
hstart = max(hstart, 0);
wstart = pw * stride_width - padding_width;
wstart = w_offset * stride_width - padding_width;
wend = min(wstart + ksize_width, input_width);
wstart = max(wstart, 0);
}
input_data += (batch_idx * channels + c) * input_height * input_width;
T1 ele = -FLT_MAX;
int max_index = -1;
for (int h = hstart; h < hend; ++h) {
@@ -1317,16 +1414,17 @@ __global__ void KernelMaxPool2DWithIdxGrad(
const int output_height, const int output_width, const int ksize_height,
const int ksize_width, const int stride_height, const int stride_width,
const int padding_height, const int padding_width, bool adaptive,
T1* input_grad) {
T1* input_grad, FastDivModForPooling divmods) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int offsetC = (index / input_width / input_height) % channels;
int batch_idx = index / input_width / input_height / channels;
int phstart, phend, pwstart, pwend;
int w_offset, h_offset, c_offset, output_offset;
OffsetPreparationFor4Dimension<FastDivModForPooling>(
index, false, divmods, 0, 0, output_width, output_height, &w_offset,
&h_offset, &c_offset, &output_offset);
mask_data += output_offset;
output_grad += output_offset;
int phstart, phend;
int pwstart, pwend;
if (adaptive) {
phstart = h_offset * output_height / input_height;
phend =
@@ -1348,20 +1446,15 @@ __global__ void KernelMaxPool2DWithIdxGrad(
pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
}
T1 gradient = 0;
T1 input_grad_data = 0;
int input_current_featuremap_idx = h_offset * input_width + w_offset;
int output_idx =
(batch_idx * channels + offsetC) * output_height * output_width;
mask_data += output_idx;
output_grad += output_idx;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
if (mask_data[ph * output_width + pw] == input_current_featuremap_idx)
gradient += output_grad[ph * output_width + pw];
input_grad_data += output_grad[ph * output_width + pw];
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
@@ -1405,11 +1498,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
int blocks = (nthreads + thread_num - 1) / thread_num;
dim3 threads(thread_num, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, output_width, output_height);
KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, adaptive, output_data,
mask_data);
mask_data, pool_divmods);
}
};
@@ -1449,11 +1545,13 @@ class MaxPool2dWithIndexGradFunctor<platform::CUDADeviceContext, T1, T2> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
auto pool_divmods =
FastDivModForPooling(input_channels, input_width, input_height);
KernelMaxPool2DWithIdxGrad<T1, T2><<<grid, threads, 0, context.stream()>>>(
nthreads, output_grad_data, mask_data, input_channels, input_height,
input_width, output_height, output_width, ksize_height, ksize_width,
stride_height, stride_width, padding_height, padding_width, adaptive,
input_grad_data);
input_grad_data, pool_divmods);
}
};
@@ -1542,7 +1640,8 @@ __global__ void KernelMaxPool3DWithIdxGrad(
int w_offset = index % input_width;
int h_offset = (index / input_width) % input_height;
int d_offset = (index / input_width / input_height) % input_depth;
int offsetC = (index / input_width / input_height / input_depth) % channels;
int c_offset =
(index / input_width / input_height / input_depth) % channels;
int batch_idx = index / input_width / input_height / input_depth / channels;
int pdstart, pdend;
@@ -1577,10 +1676,10 @@ __global__ void KernelMaxPool3DWithIdxGrad(
pwend = min((w_offset + padding_width) / stride_width + 1, output_width);
}
T1 gradient = 0;
T1 input_grad_data = 0;
int input_current_feature_map_idx =
(d_offset * input_height + h_offset) * input_width + w_offset;
int output_idx = (batch_idx * channels + offsetC) * output_depth *
int output_idx = (batch_idx * channels + c_offset) * output_depth *
output_height * output_width;
mask += output_idx;
output_grad += output_idx;
@@ -1590,12 +1689,12 @@ __global__ void KernelMaxPool3DWithIdxGrad(
for (int pw = pwstart; pw < pwend; ++pw) {
if (mask[(pd * output_height + ph) * output_width + pw] ==
input_current_feature_map_idx)
gradient +=
input_grad_data +=
output_grad[(pd * output_height + ph) * output_width + pw];
}
}
}
input_grad[index] = gradient;
input_grad[index] = input_grad_data;
}
}
......
@@ -68,7 +68,8 @@ class AvgPool {
template <class T>
class MaxPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
static constexpr bool use_x = true;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += dy * static_cast<T>(x == y);
}
@@ -77,7 +78,8 @@ class MaxPoolGrad {
template <class T>
class AvgPoolGrad {
public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
static constexpr bool use_x = false;
HOSTDEVICE inline void compute(const T& x, const T& y, const T& dy, T scale,
T* dx) {
*dx += (scale * dy);
}
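// Why use_x differs between the two grad functors: a minimal sketch of how
// KernelPool2DGrad (and KernelPool3DGrad) consume the flag. MaxPoolGrad needs
// the forward input x and output y to route the gradient (dx += dy * (x == y)),
// so use_x = true makes the grad kernel read input_data[index] and offset
// output_data; AvgPoolGrad only scales dy, so use_x = false skips both reads:
//
//   if (pool_process.use_x) {
//     input = input_data[index];
//     output_data += output_offset;
//   }
//   output_grad += output_offset;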
......