提交 f7be9cb9 编写于 作者: H hedaoyuan

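Refine the CPU code: port the im2col/col2im functors to framework::Tensor arguments, platform::CPUPlace specializations and snake_case naming.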

上级 6efbe2ff
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc DEPS cblas device_context)
else() else()
cc_library(math_function SRCS math_function.cc DEPS cblas device_context) cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
...@@ -12,48 +12,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,48 +12,54 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "Im2Col.h" #include "paddle/operators/math/im2col.h"
namespace paddle { namespace paddle {
/* /*
* imShape = [inputChannels, inputHeight, inputWidth] * im = [input_channels, input_height, input_width]
* colShape = * col =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] * [input_channels, filter_height, filter_width, output_height, output_width]
*/ */
template <class T> template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> { class Im2ColFunctor<kCFO, platform::CPUPlace, T> {
public: public:
void operator()(const T* imData, const TensorShape& imShape, T* colData, void operator()(const framework::Tensor& im, framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth) { int padding_width) {
int inputChannels = imShape[0]; PADDLE_ENFORCE(im.dims().size() == 3);
int inputHeight = imShape[1]; PADDLE_ENFORCE(col.dims().size() == 5);
int inputWidth = imShape[2];
int filterHeight = colShape[1]; int input_channels = im.dims()[0];
int filterWidth = colShape[2]; int input_height = im.dims()[1];
int outputHeight = colShape[3]; int input_width = im.dims()[2];
int outputWidth = colShape[4]; int filter_height = col.dims()[1];
int channelsCol = inputChannels * filterHeight * filterWidth; int filter_width = col.dims()[2];
int output_height = col.dims()[3];
for (int c = 0; c < channelsCol; ++c) { int output_width = col.dims()[4];
int wOffset = c % filterWidth; int channels_col = input_channels * filter_height * filter_width;
int hOffset = (c / filterWidth) % filterHeight;
int c_im = c / filterWidth / filterHeight; const T* im_data = im.data<T>();
for (int h = 0; h < outputHeight; ++h) { T* col_data = col.data<T>();
for (int w = 0; w < outputWidth; ++w) {
int imRowIdx = h * strideHeight + hOffset; for (int c = 0; c < channels_col; ++c) {
int imColIdx = w * strideWidth + wOffset; int w_offset = c % filter_width;
if ((imRowIdx - paddingHeight) < 0 || int h_offset = (c / filter_width) % filter_height;
(imRowIdx - paddingHeight) >= inputHeight || int c_im = c / filter_width / filter_height;
(imColIdx - paddingWidth) < 0 || for (int h = 0; h < output_height; ++h) {
(imColIdx - paddingWidth) >= inputWidth) { for (int w = 0; w < output_width; ++w) {
colData[(c * outputHeight + h) * outputWidth + w] = T(0); int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) < 0 ||
(im_row_idx - padding_height) >= input_height ||
(im_col_idx - padding_width) < 0 ||
(im_col_idx - padding_width) >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else { } else {
imRowIdx += c_im * inputHeight - paddingHeight; im_row_idx += c_im * input_height - padding_height;
imColIdx -= paddingWidth; im_col_idx -= padding_width;
colData[(c * outputHeight + h) * outputWidth + w] = col_data[(c * output_height + h) * output_width + w] =
imData[imRowIdx * inputWidth + imColIdx]; im_data[im_row_idx * input_width + im_col_idx];
} }
} }
} }
...@@ -62,41 +68,46 @@ class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> { ...@@ -62,41 +68,46 @@ class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
}; };
/* /*
* imShape = [inputChannels, inputHeight, inputWidth] * im = [input_channels, input_height, input_width]
* colShape = * col =
* [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth] * [input_channels, filter_height, filter_width, output_height, output_width]
*/ */
template <class T> template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> { class Col2ImFunctor<kCFO, platform::CPUPlace, T> {
public: public:
void operator()(T* imData, const TensorShape& imShape, const T* colData, void operator()(framework::Tensor& im, const framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth) { int padding_width) {
int inputChannels = imShape[0]; PADDLE_ENFORCE(im.dims().size() == 3);
int inputHeight = imShape[1]; PADDLE_ENFORCE(col.dims().size() == 5);
int inputWidth = imShape[2]; int input_channels = im.dims()[0];
int filterHeight = colShape[1]; int input_height = im.dims()[1];
int filterWidth = colShape[2]; int input_width = im.dims()[2];
int outputHeight = colShape[3]; int filter_height = col.dims()[1];
int outputWidth = colShape[4]; int filter_width = col.dims()[2];
int channelsCol = inputChannels * filterHeight * filterWidth; int output_height = col.dims()[3];
int output_width = col.dims()[4];
for (int c = 0; c < channelsCol; ++c) { int channels_col = input_channels * filter_height * filter_width;
int wOffset = c % filterWidth;
int hOffset = (c / filterWidth) % filterHeight; T* im_data = im.data<T>();
int c_im = c / filterWidth / filterHeight; const T* col_data = col.data<T>();
for (int h = 0; h < outputHeight; ++h) {
for (int w = 0; w < outputWidth; ++w) { for (int c = 0; c < channels_col; ++c) {
int imRowIdx = h * strideHeight + hOffset; int w_offset = c % filter_width;
int imColIdx = w * strideWidth + wOffset; int h_offset = (c / filter_width) % filter_height;
if ((imRowIdx - paddingHeight) >= 0 && int c_im = c / filter_width / filter_height;
(imRowIdx - paddingHeight) < inputHeight && for (int h = 0; h < output_height; ++h) {
(imColIdx - paddingWidth) >= 0 && for (int w = 0; w < output_width; ++w) {
(imColIdx - paddingWidth) < inputWidth) { int im_row_idx = h * stride_height + h_offset;
imRowIdx += c_im * inputHeight - paddingHeight; int im_col_idx = w * stride_width + w_offset;
imColIdx -= paddingWidth; if ((im_row_idx - padding_height) >= 0 &&
imData[imRowIdx * inputWidth + imColIdx] += (im_row_idx - padding_height) < input_height &&
colData[(c * outputHeight + h) * outputWidth + w]; (im_col_idx - padding_width) >= 0 &&
(im_col_idx - padding_width) < input_width) {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
} }
} }
} }
...@@ -104,52 +115,61 @@ class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> { ...@@ -104,52 +115,61 @@ class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
} }
}; };
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, float>; template class Im2ColFunctor<kCFO, platform::CPUPlace, float>;
template class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, double>; template class Im2ColFunctor<kCFO, platform::CPUPlace, double>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, float>; template class Col2ImFunctor<kCFO, platform::CPUPlace, float>;
template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, double>; template class Col2ImFunctor<kCFO, platform::CPUPlace, double>;
/* /*
* imShape = [inputChannels, inputHeight, inputWidth] * im = [input_channels, input_height, input_width]
* colShape = * col =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] * [output_height, output_width, input_channels, filter_height, filter_width]
*/ */
template <class T> template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> { class Im2ColFunctor<kOCF, platform::CPUPlace, T> {
public: public:
void operator()(const T* imData, const TensorShape& imShape, T* colData, void operator()(const framework::Tensor& im, framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth) { int padding_width) {
int inputChannels = imShape[0]; PADDLE_ENFORCE(im.dims().size() == 3);
int inputHeight = imShape[1]; PADDLE_ENFORCE(col.dims().size() == 5);
int inputWidth = imShape[2]; int input_channels = im.dims()[0];
int filterHeight = colShape[3]; int input_height = im.dims()[1];
int filterWidth = colShape[4]; int input_width = im.dims()[2];
int outputHeight = colShape[0]; int filter_height = col.dims()[3];
int outputWidth = colShape[1]; int filter_width = col.dims()[4];
for (int outputH = 0; outputH < outputHeight; ++outputH) { int output_height = col.dims()[0];
for (int outputW = 0; outputW < outputWidth; ++outputW) { int output_width = col.dims()[1];
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { const T* im_data = im.data<T>();
for (int filterW = 0; filterW < filterWidth; ++filterW) { T* col_data = col.data<T>();
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight; for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
int imColOffset = outputW * strideWidth + filterW - paddingWidth; for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
int colDataOffset = for (int channel = 0; channel < input_channels; ++channel) {
(((outputH * outputWidth + outputW) * inputChannels + for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) * channel) *
filterHeight + filter_height +
filterH) * filter_row_idx) *
filterWidth + filter_width +
filterW; filter_col_idx;
if (imRowOffset < 0 || imRowOffset >= inputHeight || if (im_row_offset < 0 || im_row_offset >= input_height ||
imColOffset < 0 || imColOffset >= inputWidth) { im_col_offset < 0 || im_col_offset >= input_width) {
colData[colDataOffset] = float(0); col_data[col_offset] = T(0);
} else { } else {
int imDataOffset = int im_offset =
(channel * inputHeight + imRowOffset) * inputWidth + (channel * input_height + im_row_offset) * input_width +
imColOffset; im_col_offset;
colData[colDataOffset] = imData[imDataOffset]; col_data[col_offset] = im_data[im_offset];
} }
} }
} }
...@@ -160,44 +180,53 @@ class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> { ...@@ -160,44 +180,53 @@ class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
}; };
/* /*
* imShape = [inputChannels, inputHeight, inputWidth] * im = [input_channels, input_height, input_width]
* colShape = * col =
* [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth] * [output_height, output_width, input_channels, filter_height, filter_width]
*/ */
template <class T> template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> { class Col2ImFunctor<kOCF, platform::CPUPlace, T> {
public: public:
void operator()(T* imData, const TensorShape& imShape, const T* colData, void operator()(framework::Tensor& im, const framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth) { int padding_width) {
int inputChannels = imShape[0]; PADDLE_ENFORCE(im.dims().size() == 3);
int inputHeight = imShape[1]; PADDLE_ENFORCE(col.dims().size() == 5);
int inputWidth = imShape[2]; int input_channels = im.dims()[0];
int filterHeight = colShape[3]; int input_height = im.dims()[1];
int filterWidth = colShape[4]; int input_width = im.dims()[2];
int outputHeight = colShape[0]; int filter_height = col.dims()[3];
int outputWidth = colShape[1]; int filter_width = col.dims()[4];
for (int outputH = 0; outputH < outputHeight; ++outputH) { int output_height = col.dims()[0];
for (int outputW = 0; outputW < outputWidth; ++outputW) { int output_width = col.dims()[1];
for (int channel = 0; channel < inputChannels; ++channel) {
for (int filterH = 0; filterH < filterHeight; ++filterH) { T* im_data = im.data<T>();
for (int filterW = 0; filterW < filterWidth; ++filterW) { const T* col_data = col.data<T>();
int imRowOffset =
outputH * strideHeight + filterH - paddingHeight; for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
int imColOffset = outputW * strideWidth + filterW - paddingWidth; for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
int colDataOffset = for (int channel = 0; channel < input_channels; ++channel) {
(((outputH * outputWidth + outputW) * inputChannels + for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) * channel) *
filterHeight + filter_height +
filterH) * filter_row_idx) *
filterWidth + filter_width +
filterW; filter_col_idx;
if (imRowOffset >= 0 && imRowOffset < inputHeight && if (im_row_offset >= 0 && im_row_offset < input_height &&
imColOffset >= 0 && imColOffset < inputWidth) { im_col_offset >= 0 && im_col_offset < input_width) {
int imDataOffset = int im_offset =
(channel * inputHeight + imRowOffset) * inputWidth + (channel * input_height + im_row_offset) * input_width +
imColOffset; im_col_offset;
imData[imDataOffset] += colData[colDataOffset]; im_data[im_offset] += col_data[col_offset];
} }
} }
} }
...@@ -207,9 +236,9 @@ class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> { ...@@ -207,9 +236,9 @@ class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
} }
}; };
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, float>; template class Im2ColFunctor<kOCF, platform::CPUPlace, float>;
template class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, double>; template class Im2ColFunctor<kOCF, platform::CPUPlace, double>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, float>; template class Col2ImFunctor<kOCF, platform::CPUPlace, float>;
template class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, double>; template class Col2ImFunctor<kOCF, platform::CPUPlace, double>;
} // namespace paddle } // namespace paddle
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "TensorShape.h" #include "paddle/framework/tensor.h"
#include "TensorType.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -67,20 +67,20 @@ enum ColFormat { kCFO = 0, kOCF = 1 }; ...@@ -67,20 +67,20 @@ enum ColFormat { kCFO = 0, kOCF = 1 };
* \note The caller needs to ensure that imShape.inputChannels is equal to * \note The caller needs to ensure that the input_channels dimension of im is
* colShape.inputChannels. * equal to the input_channels dimension of col.
*/ */
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor { class Im2ColFunctor {
public: public:
void operator()(const T* imData, const TensorShape& imShape, T* colData, void operator()(const framework::Tensor& im, framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth); int padding_width);
}; };
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(T* imData, const TensorShape& imShape, const T* colData, void operator()(framework::Tensor& im, const framework::Tensor& col,
const TensorShape& colShape, int strideHeight, int stride_height, int stride_width, int padding_height,
int strideWidth, int paddingHeight, int paddingWidth); int padding_width);
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册