diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index a0ceb029e3abee2fe591325ffa3100168c3aa8e3..2fd559e90a22d01cbaf89c0fbd0f011bfdf66596 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -3,11 +3,14 @@ if(WITH_GPU)
     nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
     nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
     nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
+    nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
     cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
     cc_library(softmax SRCS softmax.cc DEPS operator)
     cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
+    cc_library(vol2col SRCS vol2col.cc DEPS device_context)
 endif()
 
 cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
+cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e9718a047381596a1570b4b00546622968b70227
--- /dev/null
+++ b/paddle/operators/math/vol2col.cc
@@ -0,0 +1,155 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/vol2col.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *                    output_depth, output_height, output_width]
+ */
+template <class T>
+class Vol2ColFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& vol, framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+    int channels_col =
+        input_channels * filter_depth * filter_height * filter_width;
+
+    const T* vol_data = vol.data<T>();
+    T* col_data = col.data<T>();
+
+    for (int c = 0; c < channels_col; ++c) {
+      int w_offset = c % filter_width;
+      int h_offset = (c / filter_width) % filter_height;
+      int d_offset = (c / filter_width / filter_height) % filter_depth;
+      int c_in = c / filter_width / filter_height / filter_depth;
+      for (int d = 0; d < output_depth; ++d) {
+        int d_pad = d * stride_depth - padding_depth + d_offset;
+        for (int h = 0; h < output_height; ++h) {
+          int h_pad = h * stride_height - padding_height + h_offset;
+          for (int w = 0; w < output_width; ++w) {
+            int w_pad = w * stride_width - padding_width + w_offset;
+
+            int col_idx =
+                ((c * output_depth + d) * output_height + h) * output_width + w;
+            if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
+                w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
+              col_data[col_idx] = static_cast<T>(0);
+            } else {
+              int vol_idx =
+                  ((c_in * input_depth + d_pad) * input_height + h_pad) *
+                      input_width +
+                  w_pad;
+              col_data[col_idx] = vol_data[vol_idx];
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *                    output_depth, output_height, output_width]
+ */
+template <class T>
+class Col2VolFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+    int channels_col =
+        input_channels * filter_depth * filter_height * filter_width;
+
+    T* vol_data = vol.data<T>();
+    const T* col_data = col.data<T>();
+
+    for (int c = 0; c < channels_col; ++c) {
+      int w_offset = c % filter_width;
+      int h_offset = (c / filter_width) % filter_height;
+      int d_offset = (c / filter_width / filter_height) % filter_depth;
+      int c_in = c / filter_width / filter_height / filter_depth;
+      for (int d = 0; d < output_depth; ++d) {
+        int d_pad = d * stride_depth - padding_depth + d_offset;
+        for (int h = 0; h < output_height; ++h) {
+          int h_pad = h * stride_height - padding_height + h_offset;
+          for (int w = 0; w < output_width; ++w) {
+            int w_pad = w * stride_width - padding_width + w_offset;
+
+            if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
+                w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
+              int vol_idx =
+                  ((c_in * input_depth + d_pad) * input_height + h_pad) *
+                      input_width +
+                  w_pad;
+              int col_idx =
+                  ((c * output_depth + d) * output_height + h) * output_width +
+                  w;
+              vol_data[vol_idx] += col_data[col_idx];
+            }
+          }
+        }
+      }
+    }
+  }
+};
+
+template class Vol2ColFunctor<platform::CPUPlace, float>;
+template class Vol2ColFunctor<platform::CPUPlace, double>;
+template class Col2VolFunctor<platform::CPUPlace, float>;
+template class Col2VolFunctor<platform::CPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu
new file mode 100644
index 0000000000000000000000000000000000000000..27b11fb237575fd25a789a5fcc24ed4e30607009
--- /dev/null
+++ b/paddle/operators/math/vol2col.cu
@@ -0,0 +1,204 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/vol2col.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <class T>
+__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
+                        int height, int width, int filter_depth,
+                        int filter_height, int filter_width, int stride_depth,
+                        int stride_height, int stride_width, int padding_depth,
+                        int padding_height, int padding_width,
+                        int output_depth, int output_height, int output_width,
+                        T* data_col) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
+       index += blockDim.x * gridDim.x) {
+    int w_out = index % output_width;
+    int h_out = (index / output_width) % output_height;
+    int d_out = (index / output_width / output_height) % output_depth;
+    int channel_in = index / output_width / output_height / output_depth;
+    int channel_out = channel_in * filter_depth * filter_height * filter_width;
+    int w_in = w_out * stride_width - padding_width;
+    int h_in = h_out * stride_height - padding_height;
+    int d_in = d_out * stride_depth - padding_depth;
+
+    data_col += ((channel_out * output_depth + d_out) * output_height + h_out) *
+                    output_width +
+                w_out;
+    data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
+    for (int k = 0; k < filter_depth; ++k) {
+      for (int i = 0; i < filter_height; ++i) {
+        for (int j = 0; j < filter_width; ++j) {
+          int d = d_in + k;
+          int h = h_in + i;
+          int w = w_in + j;
+          *data_col = (d >= 0 && d < depth && h >= 0 && h < height &&
+                       w >= 0 && w < width)
+                          ? data_vol[(k * height + i) * width + j]
+                          : 0;
+          data_col += output_depth * output_height * output_width;
+        }
+      }
+    }
+  }
+}
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *                    output_depth, output_height, output_width]
+ */
+template <class T>
+class Vol2ColFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& vol, framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+
+    int num_outputs =
+        input_channels * output_depth * output_height * output_width;
+
+    const int threads = 1024;
+    const int blocks = (num_outputs + 1024 - 1) / 1024;
+    vol2col<T><<<blocks, threads, 0,
+                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                     .stream()>>>(
+        num_outputs, vol.data<T>(), input_depth, input_height, input_width,
+        filter_depth, filter_height, filter_width, stride_depth, stride_height,
+        stride_width, padding_depth, padding_height, padding_width,
+        output_depth, output_height, output_width, col.data<T>());
+  }
+};
+
+template <class T>
+__global__ void col2vol(int num_kernels, const T* data_col, int depth,
+                        int height, int width, int filter_depth,
+                        int filter_height, int filter_width, int stride_depth,
+                        int stride_height, int stride_width, int padding_depth,
+                        int padding_height, int padding_width,
+                        int output_depth, int output_height, int output_width,
+                        T* data_vol) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
+       index += blockDim.x * gridDim.x) {
+    T src_val = 0;
+    int w = index % width + padding_width;
+    int h = (index / width) % height + padding_height;
+    int d = (index / width / height) % depth + padding_depth;
+    int c = index / width / height / depth;
+    // compute the start and end of the output
+    int w_col_start =
+        (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1;
+    int w_col_end = min(w / stride_width + 1, output_width);
+    int h_col_start =
+        (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1;
+    int h_col_end = min(h / stride_height + 1, output_height);
+    int d_col_start =
+        (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1;
+    int d_col_end = min(d / stride_depth + 1, output_depth);
+
+    int offset = (c * filter_depth * filter_height * filter_width +
+                  d * filter_width * filter_height + h * filter_width + w) *
+                 output_depth * output_height * output_width;
+
+    int coeff_d_col =
+        (1 - stride_depth * filter_width * filter_height * output_depth) *
+        output_height * output_width;
+    int coeff_h_col =
+        (1 - stride_height * filter_width * output_depth * output_height) *
+        output_width;
+    int coeff_w_col =
+        (1 - stride_width * output_depth * output_height * output_width);
+
+    for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
+      for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+        for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+          src_val += data_col[offset + d_col * coeff_d_col +
+                              h_col * coeff_h_col + w_col * coeff_w_col];
+        }
+      }
+    }
+    data_vol[index] = src_val;
+  }
+}
+
+/*
+ * vol = [input_channels, input_depth, input_height, input_width]
+ * col =
+ *   [input_channels, filter_depth, filter_height, filter_width,
+ *                    output_depth, output_height, output_width]
+ */
+template <class T>
+class Col2VolFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const {
+    PADDLE_ENFORCE(vol.dims().size() == 4);
+    PADDLE_ENFORCE(col.dims().size() == 7);
+
+    int input_channels = vol.dims()[0];
+    int input_depth = vol.dims()[1];
+    int input_height = vol.dims()[2];
+    int input_width = vol.dims()[3];
+    int filter_depth = col.dims()[1];
+    int filter_height = col.dims()[2];
+    int filter_width = col.dims()[3];
+    int output_depth = col.dims()[4];
+    int output_height = col.dims()[5];
+    int output_width = col.dims()[6];
+
+    int num_kernels = input_channels * input_depth * input_height * input_width;
+
+    const int threads = 1024;
+    const int blocks = (num_kernels + 1024 - 1) / 1024;
+
+    col2vol<T><<<blocks, threads, 0,
+                 reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                     .stream()>>>(
+        num_kernels, col.data<T>(), input_depth, input_height, input_width,
+        filter_depth, filter_height, filter_width, stride_depth, stride_height,
+        stride_width, padding_depth, padding_height, padding_width,
+        output_depth, output_height, output_width, vol.data<T>());
+  }
+};
+
+template class Vol2ColFunctor<platform::GPUPlace, float>;
+template class Vol2ColFunctor<platform::GPUPlace, double>;
+template class Col2VolFunctor<platform::GPUPlace, float>;
+template class Col2VolFunctor<platform::GPUPlace, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h
new file mode 100644
index 0000000000000000000000000000000000000000..f022365a16fbf61981e94bedbd8b21a32887b235
--- /dev/null
+++ b/paddle/operators/math/vol2col.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+/*
+ * \brief Converts the feature data of four dimensions (CDHW) into a colData
+ *        of seven dimensions in the Vol2ColFunctor calculation; the
+ *        Col2VolFunctor calculation performs the reverse conversion.
+ *
+ * \param volData   Vol data.
+ * \param volShape  The shape of volData,
+ *                  [input_channels, input_depth, input_height, input_width].
+ * \param colData   Column data.
+ * \param colShape  The shape of colData.
+ *
+ * The shape of colData is:
+ * [input_channels, filter_depth, filter_height, filter_width, output_depth,
+ *  output_height, output_width]
+ * So it is easy to reshape it into a convolution matrix for a convolution
+ * calculation based on matrix multiplication.
+ * The shape of the convolution matrix is [height, width], where the height is
+ * equal to input_channels * filter_depth * filter_height * filter_width, and
+ * the width is equal to output_depth * output_height * output_width.
+ *
+ * Reshape:
+ *     shape of colData           shape of convolution matrix
+ *     [input_channels,
+ *      filter_depth,
+ *      filter_height,
+ *      filter_width,      ======>      [height, width]
+ *      output_depth,
+ *      output_height,
+ *      output_width]
+ *
+ * \note The caller needs to ensure that volShape.inputChannels is equal to
+ *       colShape.inputChannels.
+ */
+template <typename Place, typename T>
+class Vol2ColFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  const framework::Tensor& vol, framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const;
+};
+
+template <typename Place, typename T>
+class Col2VolFunctor {
+ public:
+  void operator()(const platform::DeviceContext& context,
+                  framework::Tensor& vol, const framework::Tensor& col,
+                  int stride_depth, int stride_height, int stride_width,
+                  int padding_depth, int padding_height,
+                  int padding_width) const;
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81225e9a9803ce371d23620876ac22da63a8e2d1
--- /dev/null
+++ b/paddle/operators/math/vol2col_test.cc
@@ -0,0 +1,135 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/math/vol2col.h"
+#include <gtest/gtest.h>
+#include <iostream>
+
+template <typename Place>
+void testVol2col() {
+  paddle::framework::Tensor input;
+  paddle::framework::Tensor input_tmp;
+  paddle::framework::Tensor output;
+  paddle::framework::Tensor output_tmp;
+
+  auto* place = new Place();
+  paddle::platform::DeviceContext* context;
+  if (paddle::platform::is_cpu_place(*place)) {
+    context =
+        new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
+  } else {
+#ifdef PADDLE_WITH_CUDA
+    context =
+        new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+#else
+    PADDLE_THROW("no GPU support");
+#endif  // PADDLE_WITH_CUDA
+  }
+
+  /**
+   * input = [[0, 1, 2,
+   *           3, 4, 5]
+   *          [6, 7, 8,
+   *           9, 10, 11]]
+   *
+   * output = [0, 1
+   *           1, 2
+   *           3, 4
+   *           4, 5
+   *           6, 7
+   *           7, 8
+   *           9, 10
+   *           10, 11]
+   *
+   * col2vol = [[0, 2, 2,
+   *             3, 8, 5]
+   *            [6, 14, 8,
+   *             9, 20, 11]]
+   *
+   */
+  int input_depth = 2;
+  int input_height = 2;
+  int input_width = 3;
+  int filter_size = 2;
+  int stride = 1;
+  int padding = 0;
+  int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1;
+  int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
+  int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
+
+  // Vol2Col test
+  float* input_ptr = input_tmp.mutable_data<float>(
+      {1, input_depth, input_height, input_width},
+      paddle::platform::CPUPlace());
+  float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
+  memcpy(input_ptr, arr, 12 * sizeof(float));
+
+  if (paddle::platform::is_cpu_place(*place)) {
+    input = input_tmp;
+  } else {
+    input.CopyFrom(input_tmp, *place);
+  }
+  output.mutable_data<float>({1, filter_size, filter_size, filter_size,
+                              output_depth, output_height, output_width},
+                             *place);
+
+  paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
+  vol2col(*context, input, output, stride, stride, stride, padding, padding,
+          padding);
+
+  float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
+  float* out_cfo_ptr;
+  if (paddle::platform::is_cpu_place(*place)) {
+    out_cfo_ptr = output.data<float>();
+  } else {
+    output_tmp.CopyFrom(output, paddle::platform::CPUPlace());
+    out_cfo_ptr = output_tmp.data<float>();
+  }
+
+  for (int i = 0; i < 16; ++i) {
+    EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
+  }
+
+  // Col2Vol test
+  float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
+  memset(input_ptr, 0, 12 * sizeof(float));
+  if (paddle::platform::is_cpu_place(*place)) {
+    input = input_tmp;
+  } else {
+    input.CopyFrom(input_tmp, *place);
+  }
+
+  paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
+  col2vol(*context, input, output, stride, stride, stride, padding, padding,
+          padding);
+
+  float* in_ptr;
+  if (paddle::platform::is_cpu_place(*place)) {
+    in_ptr = input.data<float>();
+  } else {
+    input_tmp.CopyFrom(input, paddle::platform::CPUPlace());
+    in_ptr = input_tmp.data<float>();
+  }
+
+  for (int i = 0; i < 12; ++i) {
+    EXPECT_EQ(in_ptr[i], col_2_vol[i]);
+  }
+}
+
+TEST(math, vol2col) {
+  testVol2col<paddle::platform::CPUPlace>();
+#ifdef PADDLE_WITH_CUDA
+  testVol2col<paddle::platform::GPUPlace>();
+#endif  // PADDLE_WITH_CUDA
+}
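
Usage note (editor's sketch, not part of the patch): vol2col.h explains that the seven-dimensional col tensor reshapes into a matrix of [input_channels * filter_depth * filter_height * filter_width, output_depth * output_height * output_width], so a 3D convolution reduces to one matrix multiplication. The snippet below drives the new CPU functor the same way vol2col_test.cc does, using the same shapes; the wrapper function name and the final GEMM step (left as a comment) are illustrative assumptions, not APIs introduced by this diff.

// Sketch: run Vol2ColFunctor on a 1x2x2x3 volume with a 2x2x2 filter,
// stride 1, no padding (the shapes used by the unit test above).
#include "paddle/operators/math/vol2col.h"

void vol2col_usage_sketch() {  // hypothetical helper, not in this patch
  paddle::platform::CPUPlace cpu;
  paddle::platform::CPUDeviceContext context(cpu);

  // Input volume: 1 channel, depth 2, height 2, width 3.
  paddle::framework::Tensor input;
  float* in_ptr = input.mutable_data<float>({1, 2, 2, 3}, cpu);
  for (int i = 0; i < 12; ++i) in_ptr[i] = static_cast<float>(i);

  // Output spatial size is 1x1x2, so col is [1, 2, 2, 2, 1, 1, 2] = 16 floats.
  paddle::framework::Tensor col;
  col.mutable_data<float>({1, 2, 2, 2, 1, 1, 2}, cpu);

  paddle::operators::math::Vol2ColFunctor<paddle::platform::CPUPlace, float>
      vol2col;
  vol2col(context, input, col, /*strides d,h,w=*/1, 1, 1,
          /*paddings d,h,w=*/0, 0, 0);

  // col now reads as an 8 x 2 matrix (rows = C*kD*kH*kW, cols = oD*oH*oW);
  // multiplying a [output_channels, 8] filter matrix by it completes the
  // 3-D convolution described in vol2col.h.
}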