/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/im2col.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * im2col kernel for the kCFO column layout.
 *
 * One thread per (input channel, output row, output column) triple; the
 * flat thread index is (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x +
 * threadIdx.x, matching the 2-D grid / 1-D block launch in the kCFO
 * Im2ColFunctor. Threads with index >= num_outs do nothing.
 *
 * Each active thread writes the filter_height * filter_width taps of its
 * receptive field into data_col, storing 0 for taps that land outside the
 * image (i.e. in the padding).
 */
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int col_height, int col_width, T* data_col) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < num_outs) {
    int w_out = index % col_width;
    int h_out = (index / col_width) % col_height;
    int channel_in = index / col_width / col_height;
    int channel_out = channel_in * filter_height * filter_width;
    // Top-left tap of the receptive field in unpadded image coordinates;
    // negative when it falls into the top/left padding.
    int h_in = h_out * stride_height - padding_height;
    int w_in = w_out * stride_width - padding_width;

    data_col += (channel_out * col_height + h_out) * col_width + w_out;
    data_im += (channel_in * im_height + h_in) * im_width + w_in;
    for (int i = 0; i < filter_height; ++i) {
      for (int j = 0; j < filter_width; ++j) {
        int rIdx = h_in + i * dilation_h;
        int cIdx = w_in + j * dilation_w;
        *data_col =
            (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
                ? 0
                : data_im[i * dilation_h * im_width + j * dilation_w];
        // Successive (i, j) taps are one [output_height, output_width]
        // plane apart in the kCFO layout.
        data_col += col_height * col_width;
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * padding is read as {up, left, down, right} (see the enforce messages);
 * only up (padding[0]) and left (padding[1]) are passed to the kernel,
 * down/right only take part in the output-size consistency check.
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);

    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // One thread per (channel, output row, output column); 1024-thread
    // blocks spread over a 512 x block_y grid (the kernel flattens it back).
    int num_outputs = im_channels * col_height * col_width;
    int blocks = (num_outputs + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);
    im2col<T><<<grid, threads, 0,
                reinterpret_cast<const platform::CUDADeviceContext&>(context)
                    .stream()>>>(
        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, col->data<T>());
  }
};

/*
 * col2im kernel for the kCFO column layout (adjoint of im2col above).
 *
 * One thread per image element (channel, row, column), using the same flat
 * index scheme as im2col. Each thread gathers and sums every column entry
 * that im2col would have copied from its pixel and OVERWRITES data_im[index]
 * with that sum, so no atomics are needed.
 */
template <class T>
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
                       int col_width, T* data_im) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  // Extent of the dilated filter along each axis.
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  if (index < n) {
    T val = 0;
    // (h, w) are expressed in padded-image coordinates, the same frame im2col
    // uses (it maps output (h_out, w_out) to h_out * stride - padding).
    // Without this offset the gather is misaligned whenever padding > 0.
    int w = index % im_width + padding_width;
    int h = (index / im_width) % im_height + padding_height;
    int c = index / (im_width * im_height);

    // compute the start and end of the output
    int w_col_start =
        (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
    int w_col_end = min(w / stride_width + 1, col_width);
    int h_col_start =
        (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
    int h_col_end = min(h / stride_height + 1, col_height);

    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
        int h_off = (h - h_col * stride_height);
        int w_off = (w - w_col * stride_width);
        // Only offsets that are a whole number of dilation steps correspond
        // to an actual filter tap.
        if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
          h_off /= dilation_h;
          w_off /= dilation_w;
          int data_col_index =
              (((c * filter_height + h_off) * filter_width + w_off) *
                   col_height +
               h_col) *
                  col_width +
              w_col;
          val += data_col[data_col_index];
        }
      }
    }
    data_im[index] = val;
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * Overwrites every element of im (see col2im). padding is read as
 * {up, left, down, right}; up and left are forwarded to the kernel.
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // int, not size_t: the kernel's element count parameter is an int.
    int num_kernels = im_channels * im_height * im_width;
    int blocks = (num_kernels + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    // padding[1] (left) is the width padding, matching Im2ColFunctor;
    // padding[2] is the bottom padding and must not be used here.
    col2im<T><<<grid, threads, 0,
                reinterpret_cast<const platform::CUDADeviceContext&>(context)
                    .stream()>>>(
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, im->data<T>());
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::GPUPlace, double>;

/*
 * im2col kernel for the kOCF column layout.
 *
 * Launched with grid = (col_width, col_height): each block owns one output
 * position (swid, shid). Its threads stride over (channel, filter row,
 * filter column) via threadIdx.{z,y,x} / blockDim.{z,y,x}, writing 0 for taps
 * that fall in the padding.
 *
 * NOTE(review): this kernel has no dilation parameter, so it always
 * samples with dilation 1, even though the functor's shape check accounts
 * for dilation. Confirm that callers never pass dilation != 1 for kOCF.
 */
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* col_data) {
  int swid = blockIdx.x;
  int shid = blockIdx.y;
  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
        int height_offset = idy + shid * stride_height - padding_height;
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;

        // [output_height, output_width, channels, filter_h, filter_w] layout.
        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);

        col_data[col_offset] =
            (height_offset >= im_height || height_offset < 0 ||
             width_offset >= im_width || width_offset < 0)
                ? T(0)
                : im_data[im_offset];
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col) {
    PADDLE_ENFORCE(im.dims().size() == 3);
    PADDLE_ENFORCE(col->dims().size() == 5);

    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Square x/y block just large enough for the filter (capped at 32x32);
    // z fills the rest of a 1024-thread block with channels.
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y,
                 std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    im2colOCF<T><<<grid, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
  }
};

/*
 * col2im kernel for the kOCF column layout.
 *
 * Same block/thread mapping as im2colOCF. Several output positions can
 * map to the same image pixel, so contributions are ADDED into im_data with
 * CudaAtomicAdd. The existing contents of im_data are not cleared here.
 *
 * NOTE(review): like im2colOCF, dilation is not applied (always 1).
 */
template <class T>
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* im_data) {
  int swid = blockIdx.x;
  int shid = blockIdx.y;
  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
        int height_offset = idy + shid * stride_height - padding_height;
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;

        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);

        // Taps that fell in the padding have no image pixel to update.
        if (height_offset >= 0 && height_offset < im_height &&
            width_offset >= 0 && width_offset < im_width) {
          paddle::platform::CudaAtomicAdd(im_data + im_offset,
                                          col_data[col_offset]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * Accumulates into *im (atomic adds, no prior clear). Presumably callers
 * zero im first; confirm this against the call sites.
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im) {
    PADDLE_ENFORCE(im->dims().size() == 3);
    PADDLE_ENFORCE(col.dims().size() == 5);

    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height,
                      "Output_height and padding(padding_up, padding_down) are "
                      "inconsistent.");
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width,
                      "col_width and padding(padding_left, padding_right) are "
                      "inconsistent.");

    // Same block shape heuristic as the kOCF Im2ColFunctor.
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y,
                 std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    col2imOCF<T><<<grid, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle