提交 e967645c 编写于 作者: H hedaoyuan

Refine the gpu code.

上级 f7be9cb9
# Selects how the math_function library is declared: the WITH_GPU branch uses
# nv_library and lists the .cu sources, the other branch uses cc_library with
# the .cc sources only.
# NOTE(review): two nv_library(math_function ...) calls appear back to back and
# differ only in the added im2col.cu source. This looks like the before/after
# pair of a rendered diff -- confirm that only the second one is meant to stay.
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc DEPS cblas device_context)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
endif()
......
......@@ -12,86 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "Im2Col.h"
#include "hl_device_functions.cuh"
#include "paddle/operators/math/im2col.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
// im2col kernel: copies image elements into column form.
//
// Each thread takes one flat index below the element count given as the second
// argument and decodes it into (channel_in, h_out, w_out). It then visits every
// filter tap (i, j): the tap is written as 0 when its padding-adjusted row or
// column lies outside [0, height) x [0, width), otherwise it is copied from
// data_im. After every tap data_col advances by output_height * output_width
// elements.
//
// The flat index is built from blockIdx.x, gridDim.y, blockIdx.y, blockDim.x
// and threadIdx.x, i.e. a two-dimensional grid of one-dimensional blocks.
//
// NOTE(review): every changed line is present twice in this listing, the
// camelCase form first and its snake_case replacement directly after it (two
// signatures, two `if (index < ...)` guards, two loop nests). It reads like a
// rendered diff -- confirm which set is live before compiling.
template <class T>
__global__ void im2col(const T* data_im, int numOuts, int height, int width,
int blockH, int blockW, int strideH, int strideW,
int paddingH, int paddingW, int height_col,
int width_col, T* data_col) {
__global__ void im2col(const T* data_im, int num_outs, int height, int width,
int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width,
int output_height, int output_width, T* data_col) {
// Linear thread index over the (blockIdx.x, blockIdx.y) grid.
int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < numOuts) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * blockH * blockW;
int h_in = h_out * strideH;
int w_in = w_out * strideW;
if (index < num_outs) {
// Decode the flat index: innermost is the output column, then the output
// row, then the input channel.
int w_out = index % output_width;
index /= output_width;
int h_out = index % output_height;
int channel_in = index / output_height;
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height;
int w_in = w_out * stride_width;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) {
data_col += (channel_out * output_height + h_out) * output_width + w_out;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
// Row and column of this tap before the padding offset is removed.
int rIdx = int(h_in + i);
int cIdx = int(w_in + j);
if ((rIdx - (int)paddingH) >= (int)height ||
(rIdx - (int)paddingH) < 0 ||
(cIdx - (int)paddingW) >= (int)width ||
(cIdx - (int)paddingW) < 0) {
if ((rIdx - (int)padding_height) >= (int)height ||
(rIdx - (int)padding_height) < 0 ||
(cIdx - (int)padding_width) >= (int)width ||
(cIdx - (int)padding_width) < 0) {
// Tap falls outside the image: emit zero.
*data_col = 0;
} else {
rIdx = rIdx + channel_in * height - paddingH;
cIdx = cIdx - paddingW;
// rIdx becomes a row index across all channels stacked vertically.
rIdx = rIdx + channel_in * height - padding_height;
cIdx = cIdx - padding_width;
*data_col = data_im[rIdx * width + cIdx];
}
data_col += height_col * width_col;
// Move to the same output position of the next filter tap.
data_col += output_height * output_width;
}
}
}
}
/*
 * Host-side launcher of the im2col kernel for the kCFO layout.
 *
 * Tensor-based form (snake_case lines):
 * im = [input_channels, input_height, input_width]
 * col =
 * [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * Raw-pointer form (camelCase lines):
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
 *
 * The element count handed to the kernel is
 * input_channels * output_height * output_width. Blocks hold 1024 threads and
 * the grid is 512 x ceil(blocks / 512), where blocks = ceil(count / 1024).
 *
 * NOTE(review): this listing contains both forms of the class head, of
 * operator() and of every local. It reads like a rendered diff -- confirm
 * which set is live.
 */
template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> {
class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
public:
void operator()(const T* imData, const TensorShape& imShape, T* colData,
const TensorShape& colShape, int strideHeight,
int strideWidth, int paddingHeight, int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
// The tensor-based form checks the ranks before reading the dimensions.
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int numKernels = inputChannels * outputHeight * outputWidth;
int blocks = (numKernels + 1024 - 1) / 1024;
int blockX = 512;
int blockY = (blocks + 512 - 1) / 512;
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
// One kernel index per (channel, output row, output column).
int num_outputs = input_channels * output_height * output_width;
// Ceil-divide into 1024-thread blocks, then spread the blocks over a
// 512-wide grid.
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
im2col<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
strideHeight, strideWidth, paddingHeight, paddingWidth, outputHeight,
outputWidth, colData);
CHECK_SYNC("Im2ColFunctor GPU failed");
dim3 grid(block_x, block_y);
// NOTE(review): this launch names no stream and is not followed by any error
// check, whereas the camelCase launch above passes STREAM_DEFAULT and calls
// CHECK_SYNC. Confirm that dropping both is intended.
im2col<T><<<grid, threads>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
}
};
// col2im kernel: accumulates column entries back into the image.
//
// Each thread takes one flat index below n and decodes it into (w, h, c) using
// width and height. Only positions with padding <= w < width - padding (and
// likewise for h) are processed, and the destination index is computed with
// (width - 2 * padding_width) and (height - 2 * padding_height), so width and
// height are treated here as padded extents. For such a position the thread
// sums data_col over the range [h_col_start, h_col_end) x
// [w_col_start, w_col_end) and adds the sum to data_im.
//
// NOTE(review): the `......@@` line in the body is a hunk separator; lines are
// missing there (the declaration of `val` is not visible in this listing).
// NOTE(review): changed lines are present twice, camelCase first and the
// snake_case replacement after it -- confirm which set is live.
template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
size_t channels, size_t blockH, size_t blockW,
size_t strideH, size_t strideW, size_t paddingH,
size_t paddingW, size_t height_col, size_t width_col,
T* data_im) {
size_t channels, size_t filter_height,
size_t filter_width, size_t stride_height,
size_t stride_width, size_t padding_height,
size_t padding_width, size_t output_height,
size_t output_width, T* data_im) {
// Linear thread index over the (blockIdx.x, blockIdx.y) grid.
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
......@@ -99,104 +102,112 @@ __global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)paddingW) >= 0 &&
(w - (int)paddingW) < (width - 2 * paddingW) &&
(h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
// NOTE(review): the last comparison below subtracts the size_t padding from
// the int h without the (int) cast used in the other three comparisons.
if ((w - (int)padding_width) >= 0 &&
(w - (int)padding_width) < (width - 2 * padding_width) &&
(h - (int)padding_height) >= 0 &&
(h - padding_height) < (height - 2 * padding_height)) {
// compute the start and end of the output
int w_col_start =
(w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
int h_col_start =
(h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
int h_col_end = min(int(h / strideH + 1), int(height_col));
int w_col_start = (w < (int)filter_width)
? 0
: (w - int(filter_width)) / (int)stride_width + 1;
int w_col_end =
min((int)(w / (int)stride_width + 1), (int)(output_width));
int h_col_start = (h < (int)filter_height)
? 0
: (h - (int)filter_height) / (int)stride_height + 1;
int h_col_end = min(int(h / stride_height + 1), int(output_height));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * blockH * blockW) +
(h - h_col * (int)strideH) * (int)blockW +
(w - w_col * (int)strideW);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
int c_col = int(c * filter_height * filter_width) +
(h - h_col * (int)stride_height) * (int)filter_width +
(w - w_col * (int)stride_width);
val +=
data_col[(c_col * output_height + h_col) * output_width + w_col];
}
}
h -= paddingH;
w -= paddingW;
data_im[c * ((width - 2 * paddingW) * (height - 2 * paddingH)) +
h * (width - 2 * paddingW) + w] += val;
// Shift back to unpadded coordinates before indexing data_im.
h -= padding_height;
w -= padding_width;
data_im[c * ((width - 2 * padding_width) *
(height - 2 * padding_height)) +
h * (width - 2 * padding_width) + w] += val;
}
}
}
/*
 * Host-side launcher of the col2im kernel for the kCFO layout.
 *
 * Tensor-based form (snake_case lines):
 * im = [input_channels, input_height, input_width]
 * col =
 * [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * Raw-pointer form (camelCase lines):
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
 *
 * The element count handed to the kernel covers the padded image:
 * input_channels * (input_height + 2 * padding_height) *
 * (input_width + 2 * padding_width). The kernel also receives the padded
 * height and width. Blocks hold 1024 threads and the grid is
 * 512 x ceil(blocks / 512), where blocks = ceil(count / 1024).
 *
 * NOTE(review): this listing contains both forms of the class head, of
 * operator() and of every local. It reads like a rendered diff -- confirm
 * which set is live.
 */
template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> {
class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
public:
void operator()(T* imData, const TensorShape& imShape, const T* colData,
const TensorShape& colShape, int strideHeight,
int strideWidth, int paddingHeight, int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputHeight = colShape[3];
int outputWidth = colShape[4];
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
// The tensor-based form checks the ranks before reading the dimensions.
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
size_t numKernels = inputChannels * (inputHeight + 2 * paddingHeight) *
(inputWidth + 2 * paddingWidth);
// One kernel index per element of the padded image.
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
(input_width + 2 * padding_width);
size_t blocks = (numKernels + 1024 - 1) / 1024;
size_t blockX = 512;
size_t blockY = (blocks + 512 - 1) / 512;
// Ceil-divide into 1024-thread blocks, then spread the blocks over a
// 512-wide grid.
size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512;
size_t block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
dim3 grid(block_x, block_y);
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
numKernels, colData, inputHeight + 2 * paddingHeight,
inputWidth + 2 * paddingWidth, inputChannels, filterHeight, filterWidth,
strideHeight, strideWidth, paddingHeight, paddingWidth, outputHeight,
outputWidth, imData);
CHECK_SYNC("Col2ImFunctor GPU failed");
// NOTE(review): this launch names no stream and is not followed by any error
// check, whereas the camelCase launch above passes STREAM_DEFAULT and calls
// CHECK_SYNC. Confirm that dropping both is intended.
col2im<T><<<grid, threads>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, im.data<T>());
}
};
// Explicit instantiations of the kCFO functors for float and double.
// NOTE(review): each instantiation is listed twice, once with DEVICE_TYPE_GPU
// and once with platform::GPUPlace. This looks like the before/after pair of a
// rendered diff -- confirm which device tag is live.
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, double>;
template class Im2ColFunctor<kCFO, platform::GPUPlace, float>;
template class Im2ColFunctor<kCFO, platform::GPUPlace, double>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, float>;
template class Col2ImFunctor<kCFO, platform::GPUPlace, double>;
// im2colOCF kernel: fills the column buffer for one output position per block.
//
// blockIdx.x and blockIdx.y are used as the output column (swid) and output
// row (shid). Inside the block, threads step over channels along z, filter
// rows along y and filter columns along x, each with a stride of the matching
// blockDim component. For every (channel, idy, idx) the source coordinate is
// the tap position shifted by stride and padding; a coordinate outside the
// image writes T(0), any other copies the image element.
//
// NOTE(review): the `......@@` line near the end is a hunk separator; the
// lines hidden there are not visible in this listing.
// NOTE(review): changed lines are present twice, camelCase first and the
// snake_case replacement after it -- confirm which set is live.
template <class T>
__global__ void im2colOCF(const T* imData, T* colData, int inputChannels,
int inputHeight, int inputWidth, int filterHeight,
int filterWidth, int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth, int outputHeight,
int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z; channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth;
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
// Source coordinate of this tap in the unpadded image (may be out of
// range, which is checked below).
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int colOffset = idx + idy * filterWidth +
channelId * filterHeight * filterWidth +
(shId * outputWidth + swId) *
(inputChannels * filterHeight * filterWidth);
// Destination: output position first, then channel, filter row, filter
// column.
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
if (heightOffset >= inputHeight || heightOffset < 0 ||
widthOffset >= inputWidth || widthOffset < 0) {
colData[colOffset] = T(0);
if (height_offset >= input_height || height_offset < 0 ||
width_offset >= input_width || width_offset < 0) {
col_data[col_offset] = T(0);
} else {
colData[colOffset] = imData[imOffset];
col_data[col_offset] = im_data[im_offset];
}
}
}
......@@ -204,76 +215,79 @@ __global__ void im2colOCF(const T* imData, T* colData, int inputChannels,
}
/*
 * Host-side launcher of the im2colOCF kernel for the kOCF layout.
 *
 * Tensor-based form (snake_case lines):
 * im = [input_channels, input_height, input_width]
 * col =
 * [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * Raw-pointer form (camelCase lines):
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
 *
 * The x/y block size is the first of 4, 8, 16 that is >= both filter
 * dimensions, or 32 when none is. The z size is 1024 / x / y, capped at the
 * channel count. The grid is (output_width, output_height).
 *
 * NOTE(review): this listing contains both forms of the class head, of
 * operator() and of every local. It reads like a rendered diff -- confirm
 * which set is live.
 */
template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> {
class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
public:
void operator()(const T* imData, const TensorShape& imShape, T* colData,
const TensorShape& colShape, int strideHeight,
int strideWidth, int paddingHeight, int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
// The tensor-based form checks the ranks before reading the dimensions.
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
// Pick a square x/y block size from the filter size.
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
block_dim_x = 4;
block_dim_y = 4;
} else if (filter_height <= 8 && filter_width <= 8) {
block_dim_x = 8;
block_dim_y = 8;
} else if (filter_height <= 16 && filter_width <= 16) {
block_dim_x = 16;
block_dim_y = 16;
} else {
blockDimX = 32;
blockDimY = 32;
block_dim_x = 32;
block_dim_y = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
im2colOCF<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
imData, colData, inputChannels, inputHeight, inputWidth, filterHeight,
filterWidth, strideHeight, strideWidth, paddingHeight, paddingWidth,
outputHeight, outputWidth);
CHECK_SYNC("Im2ColFunctor GPU failed");
// Spend the rest of a 1024-thread budget on channels, capped at the channel
// count.
int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
// One block per output position.
dim3 grid(output_width, output_height);
// NOTE(review): this launch names no stream and is not followed by any error
// check, whereas the camelCase launch above passes STREAM_DEFAULT and calls
// CHECK_SYNC. Confirm that dropping both is intended.
im2colOCF<T><<<grid, threads>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
}
};
// col2imOCF kernel: adds column entries back into the image, one output
// position per block.
//
// blockIdx.x and blockIdx.y are used as the output column (swid) and output
// row (shid). Inside the block, threads step over channels along z, filter
// rows along y and filter columns along x, each with a stride of the matching
// blockDim component. Taps whose source coordinate lies inside the image are
// added to im_data with an atomic add; taps outside are skipped.
// Presumably the atomic is needed because different blocks can target the same
// image element -- TODO confirm.
//
// NOTE(review): the `......@@` line near the end is a hunk separator; the
// lines hidden there are not visible in this listing.
// NOTE(review): changed lines are present twice, camelCase first and the
// snake_case replacement after it -- confirm which set is live.
template <class T>
__global__ void col2imOCF(T* imData, const T* colData, int inputChannels,
int inputHeight, int inputWidth, int filterHeight,
int filterWidth, int strideHeight, int strideWidth,
int paddingHeight, int paddingWidth, int outputHeight,
int outputWidth) {
int swId = blockIdx.x;
int shId = blockIdx.y;
for (int channelId = threadIdx.z; channelId < inputChannels;
channelId += blockDim.z) {
for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
int widthOffset = idx + swId * strideWidth - paddingWidth;
int heightOffset = idy + shId * strideHeight - paddingHeight;
int imOffset = widthOffset + heightOffset * inputWidth +
channelId * inputHeight * inputWidth;
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
// Destination coordinate of this tap in the unpadded image (may be out
// of range, which is checked below).
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int colOffset = idx + idy * filterWidth +
channelId * filterHeight * filterWidth +
(shId * outputWidth + swId) *
(inputChannels * filterHeight * filterWidth);
// Source: output position first, then channel, filter row, filter
// column.
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
if (heightOffset >= 0 && heightOffset < inputHeight &&
widthOffset >= 0 && widthOffset < inputWidth) {
paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
if (height_offset >= 0 && height_offset < input_height &&
width_offset >= 0 && width_offset < input_width) {
paddle::platform::CudaAtomicAdd(im_data + im_offset,
col_data[col_offset]);
}
}
}
......@@ -281,54 +295,56 @@ __global__ void col2imOCF(T* imData, const T* colData, int inputChannels,
}
/*
 * Host-side launcher of the col2imOCF kernel for the kOCF layout.
 *
 * Tensor-based form (snake_case lines):
 * im = [input_channels, input_height, input_width]
 * col =
 * [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * Raw-pointer form (camelCase lines):
 * imShape = [inputChannels, inputHeight, inputWidth]
 * colShape =
 * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
 *
 * The x/y block size is the first of 4, 8, 16 that is >= both filter
 * dimensions, or 32 when none is. The z size is 1024 / x / y, capped at the
 * channel count. The grid is (output_width, output_height).
 *
 * NOTE(review): this listing contains both forms of the class head, of
 * operator() and of every local. It reads like a rendered diff -- confirm
 * which set is live.
 */
template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> {
class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
public:
void operator()(T* imData, const TensorShape& imShape, const T* colData,
const TensorShape& colShape, int strideHeight,
int strideWidth, int paddingHeight, int paddingWidth) {
int inputChannels = imShape[0];
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[3];
int filterWidth = colShape[4];
int outputHeight = colShape[0];
int outputWidth = colShape[1];
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width) {
// The tensor-based form checks the ranks before reading the dimensions.
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int blockDimX = 0;
int blockDimY = 0;
if (filterHeight <= 4 && filterWidth <= 4) {
blockDimX = 4;
blockDimY = 4;
} else if (filterHeight <= 8 && filterWidth <= 8) {
blockDimX = 8;
blockDimY = 8;
} else if (filterHeight <= 16 && filterWidth <= 16) {
blockDimX = 16;
blockDimY = 16;
// Pick a square x/y block size from the filter size.
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
block_dim_x = 4;
block_dim_y = 4;
} else if (filter_height <= 8 && filter_width <= 8) {
block_dim_x = 8;
block_dim_y = 8;
} else if (filter_height <= 16 && filter_width <= 16) {
block_dim_x = 16;
block_dim_y = 16;
} else {
blockDimX = 32;
blockDimY = 32;
block_dim_x = 32;
block_dim_y = 32;
}
int blockDimZ = 1024 / blockDimX / blockDimY;
dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
dim3 grid(outputWidth, outputHeight);
col2imOCF<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
imData, colData, inputChannels, inputHeight, inputWidth, filterHeight,
filterWidth, strideHeight, strideWidth, paddingHeight, paddingWidth,
outputHeight, outputWidth);
CHECK_SYNC("Col2ImFunctor GPU failed");
// Spend the rest of a 1024-thread budget on channels, capped at the channel
// count.
int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
// One block per output position.
dim3 grid(output_width, output_height);
// NOTE(review): this launch passes an explicit third argument of 0 while the
// other tensor-based launches in this listing pass only grid and threads. It
// also names no stream and is not followed by any error check, unlike the
// camelCase launch above. Confirm the intended form.
col2imOCF<T><<<grid, threads, 0>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
}
};
// Explicit instantiations of the kOCF functors for float and double.
// NOTE(review): each instantiation is listed twice, once with DEVICE_TYPE_GPU
// and once with platform::GPUPlace. This looks like the before/after pair of a
// rendered diff -- confirm which device tag is live.
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
template class Im2ColFunctor<kOCF, platform::GPUPlace, float>;
template class Im2ColFunctor<kOCF, platform::GPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, double>;
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册