From f7be9cb97aa4b90c0ccd6f954a71d3caada4dac7 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Wed, 30 Aug 2017 13:55:58 +0800 Subject: [PATCH] Refine the cpu code. --- paddle/operators/math/CMakeLists.txt | 4 +- paddle/operators/math/im2col.cc | 319 +++++++++++++++------------ paddle/operators/math/im2col.h | 20 +- 3 files changed, 186 insertions(+), 157 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index ed51d416ed9..f31281ebac0 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,8 +1,8 @@ if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc DEPS cblas device_context) else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index dafb21b335f..8124e322cbf 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -12,48 +12,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "Im2Col.h" +#include "paddle/operators/math/im2col.h" namespace paddle { /* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, output_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: - void operator()(const T* imData, const TensorShape& imShape, T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[1]; - int filterWidth = colShape[2]; - int outputHeight = colShape[3]; - int outputWidth = colShape[4]; - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) < 0 || - (imRowIdx - paddingHeight) >= inputHeight || - (imColIdx - paddingWidth) < 0 || - (imColIdx - paddingWidth) >= inputWidth) { - colData[(c * outputHeight + h) * outputWidth + w] = T(0); + void operator()(const framework::Tensor& im, framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width) { + PADDLE_ENFORCE(im.dims().size() == 3); + PADDLE_ENFORCE(col.dims().size() == 5); + + int input_channels = im.dims()[0]; + int input_height = im.dims()[1]; + int input_width = im.dims()[2]; + int filter_height = col.dims()[1]; + int filter_width = col.dims()[2]; + int output_height = col.dims()[3]; + int output_width = col.dims()[4]; + int channels_col = input_channels * filter_height * filter_width; + + const T* im_data = im.data(); + T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / filter_width / filter_height; + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + int im_row_idx = h * stride_height + h_offset; + int im_col_idx = w * stride_width + w_offset; + if ((im_row_idx - padding_height) < 0 || + (im_row_idx - padding_height) >= input_height || + (im_col_idx - padding_width) < 0 || + (im_col_idx - padding_width) >= input_width) { + col_data[(c * output_height + h) * output_width + w] = T(0); } else { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - colData[(c * outputHeight + h) * outputWidth + w] = - imData[imRowIdx * inputWidth + imColIdx]; + im_row_idx += c_im * input_height - padding_height; + im_col_idx -= padding_width; + col_data[(c * output_height + h) * output_width + w] = + im_data[im_row_idx * input_width + im_col_idx]; } } } @@ -62,41 +68,46 @@ class Im2ColFunctor { }; /* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] + * im = [input_channels, input_height, input_width] + * col = + * [input_channels, filter_height, filter_width, output_height, output_width] */ template -class Col2ImFunctor { +class Col2ImFunctor { public: - void operator()(T* imData, const TensorShape& imShape, const T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[1]; - int filterWidth = colShape[2]; - int outputHeight = colShape[3]; - int outputWidth = colShape[4]; - int channelsCol = inputChannels * filterHeight * filterWidth; - - for (int c = 0; c < channelsCol; ++c) { - int wOffset = c % filterWidth; - int hOffset = (c / filterWidth) % filterHeight; - int c_im = c / filterWidth / filterHeight; - for (int h = 0; h < outputHeight; ++h) { - for (int w = 0; w < outputWidth; ++w) { - int imRowIdx = h * strideHeight + hOffset; - int imColIdx = w * strideWidth + wOffset; - if ((imRowIdx - paddingHeight) >= 0 && - (imRowIdx - paddingHeight) < inputHeight && - (imColIdx - paddingWidth) >= 0 && - (imColIdx - paddingWidth) < inputWidth) { - imRowIdx += c_im * inputHeight - paddingHeight; - imColIdx -= paddingWidth; - imData[imRowIdx * inputWidth + imColIdx] += - colData[(c * outputHeight + h) * outputWidth + w]; + void operator()(framework::Tensor& im, const framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width) { + PADDLE_ENFORCE(im.dims().size() == 3); + PADDLE_ENFORCE(col.dims().size() == 5); + int input_channels = im.dims()[0]; + int input_height = im.dims()[1]; + int input_width = im.dims()[2]; + int filter_height = col.dims()[1]; + int filter_width = col.dims()[2]; + int output_height = col.dims()[3]; + int output_width = col.dims()[4]; + int channels_col = input_channels * filter_height * filter_width; + + T* im_data = im.data(); + const T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int c_im = c / filter_width / filter_height; + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + int im_row_idx = h * stride_height + h_offset; + int im_col_idx = w * stride_width + w_offset; + if ((im_row_idx - padding_height) >= 0 && + (im_row_idx - padding_height) < input_height && + (im_col_idx - padding_width) >= 0 && + (im_col_idx - padding_width) < input_width) { + im_row_idx += c_im * input_height - padding_height; + im_col_idx -= padding_width; + im_data[im_row_idx * input_width + im_col_idx] += + col_data[(c * output_height + h) * output_width + w]; } } } @@ -104,52 +115,61 @@ class Col2ImFunctor { } }; -template class Im2ColFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; /* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + * im = [input_channels, input_height, input_width] + * col = + * [output_height, output_width, input_channels, filter_height, filter_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: - void operator()(const T* imData, const TensorShape& imShape, T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; - for (int outputH = 0; outputH < outputHeight; ++outputH) { - for (int outputW = 0; outputW < outputWidth; ++outputW) { - for (int channel = 0; channel < inputChannels; ++channel) { - for (int filterH = 0; filterH < filterHeight; ++filterH) { - for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; - int colDataOffset = - (((outputH * outputWidth + outputW) * inputChannels + - channel) * - filterHeight + - filterH) * - filterWidth + - filterW; - if (imRowOffset < 0 || imRowOffset >= inputHeight || - imColOffset < 0 || imColOffset >= inputWidth) { - colData[colDataOffset] = float(0); + void operator()(const framework::Tensor& im, framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width) { + PADDLE_ENFORCE(im.dims().size() == 3); + PADDLE_ENFORCE(col.dims().size() == 5); + int input_channels = im.dims()[0]; + int input_height = im.dims()[1]; + int input_width = im.dims()[2]; + int filter_height = col.dims()[3]; + int filter_width = col.dims()[4]; + int output_height = col.dims()[0]; + int output_width = col.dims()[1]; + + const T* im_data = im.data(); + T* col_data = col.data(); + + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { + for (int channel = 0; channel < input_channels; ++channel) { + for (int filter_row_idx = 0; filter_row_idx < filter_height; + ++filter_row_idx) { + for (int filter_col_idx = 0; filter_col_idx < filter_width; + ++filter_col_idx) { + int im_row_offset = + col_row_idx * stride_height + filter_row_idx - padding_height; + int im_col_offset = + col_col_idx * stride_width + filter_col_idx - padding_width; + int col_offset = (((col_row_idx * output_width + col_col_idx) * + input_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + if (im_row_offset < 0 || im_row_offset >= input_height || + im_col_offset < 0 || im_col_offset >= input_width) { + col_data[col_offset] = T(0); } else { - int imDataOffset = - (channel * inputHeight + imRowOffset) * inputWidth + - imColOffset; - colData[colDataOffset] = imData[imDataOffset]; + int im_offset = + (channel * input_height + im_row_offset) * input_width + + im_col_offset; + col_data[col_offset] = im_data[im_offset]; } } } @@ -160,44 +180,53 @@ class Im2ColFunctor { }; /* - * imShape = [inputChannels, inputHeight, inputWidth] - * colShape = - * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] + * im = [input_channels, input_height, input_width] + * col = + * [output_height, output_width, input_channels, filter_height, filter_width] */ template -class Col2ImFunctor { +class Col2ImFunctor { public: - void operator()(T* imData, const TensorShape& imShape, const T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth) { - int inputChannels = imShape[0]; - int inputHeight = imShape[1]; - int inputWidth = imShape[2]; - int filterHeight = colShape[3]; - int filterWidth = colShape[4]; - int outputHeight = colShape[0]; - int outputWidth = colShape[1]; - for (int outputH = 0; outputH < outputHeight; ++outputH) { - for (int outputW = 0; outputW < outputWidth; ++outputW) { - for (int channel = 0; channel < inputChannels; ++channel) { - for (int filterH = 0; filterH < filterHeight; ++filterH) { - for (int filterW = 0; filterW < filterWidth; ++filterW) { - int imRowOffset = - outputH * strideHeight + filterH - paddingHeight; - int imColOffset = outputW * strideWidth + filterW - paddingWidth; - int colDataOffset = - (((outputH * outputWidth + outputW) * inputChannels + - channel) * - filterHeight + - filterH) * - filterWidth + - filterW; - if (imRowOffset >= 0 && imRowOffset < inputHeight && - imColOffset >= 0 && imColOffset < inputWidth) { - int imDataOffset = - (channel * inputHeight + imRowOffset) * inputWidth + - imColOffset; - imData[imDataOffset] += colData[colDataOffset]; + void operator()(framework::Tensor& im, const framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width) { + PADDLE_ENFORCE(im.dims().size() == 3); + PADDLE_ENFORCE(col.dims().size() == 5); + int input_channels = im.dims()[0]; + int input_height = im.dims()[1]; + int input_width = im.dims()[2]; + int filter_height = col.dims()[3]; + int filter_width = col.dims()[4]; + int output_height = col.dims()[0]; + int output_width = col.dims()[1]; + + T* im_data = im.data(); + const T* col_data = col.data(); + + for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) { + for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) { + for (int channel = 0; channel < input_channels; ++channel) { + for (int filter_row_idx = 0; filter_row_idx < filter_height; + ++filter_row_idx) { + for (int filter_col_idx = 0; filter_col_idx < filter_width; + ++filter_col_idx) { + int im_row_offset = + col_row_idx * stride_height + filter_row_idx - padding_height; + int im_col_offset = + col_col_idx * stride_width + filter_col_idx - padding_width; + int col_offset = (((col_row_idx * output_width + col_col_idx) * + input_channels + + channel) * + filter_height + + filter_row_idx) * + filter_width + + filter_col_idx; + if (im_row_offset >= 0 && im_row_offset < input_height && + im_col_offset >= 0 && im_col_offset < input_width) { + int im_offset = + (channel * input_height + im_row_offset) * input_width + + im_col_offset; + im_data[im_offset] += col_data[col_offset]; } } } @@ -207,9 +236,9 @@ class Col2ImFunctor { } }; -template class Im2ColFunctor; -template class Im2ColFunctor; -template class Col2ImFunctor; -template class Col2ImFunctor; +template class Im2ColFunctor; +template class Im2ColFunctor; +template class Col2ImFunctor; +template class Col2ImFunctor; } // namespace paddle diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 4568ca2fd11..f2f982b6875 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -14,8 +14,8 @@ limitations under the License. */ #pragma once -#include "TensorShape.h" -#include "TensorType.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { @@ -67,20 +67,20 @@ enum ColFormat { kCFO = 0, kOCF = 1 }; * \note The caller needs to ensure that imShape.inputChannels is equal to * colShape.inputChannels. */ -template +template class Im2ColFunctor { public: - void operator()(const T* imData, const TensorShape& imShape, T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth); + void operator()(const framework::Tensor& im, framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width); }; -template +template class Col2ImFunctor { public: - void operator()(T* imData, const TensorShape& imShape, const T* colData, - const TensorShape& colShape, int strideHeight, - int strideWidth, int paddingHeight, int paddingWidth); + void operator()(framework::Tensor& im, const framework::Tensor& col, + int stride_height, int stride_width, int padding_height, + int padding_width); }; } // namespace paddle -- GitLab