提交 c0f7ecb4 编写于 作者: H hjchen2

Optimize general col2im to speed up transpose conv

上级 33de575e
......@@ -22,10 +22,13 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void ExtractToImg(const float *im_data, float *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
template <>
void ExtractToImg<float>(const float *im_data, float *col_data,
const int im_height, const int im_width,
const int col_height, const int col_width,
const int padding_h, const int padding_w,
const int stride_h, const int stride_w, const int kh,
const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
......@@ -41,48 +44,43 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
if (stride_w == 1) {
// memcpy(col_data, im_data, extract * sizeof(float));
int s = 0;
if (stride_w == 1) {
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img);
float32x4_t _img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, _img);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x2_t img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, img.val[0]);
float32x4x2_t _img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, _img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x3_t img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, img.val[0]);
float32x4x3_t _img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, _img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 3; s += 4) {
float32x4x4_t img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, img.val[0]);
float32x4x4_t _img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, _img.val[0]);
}
#endif
for (; s < extract; ++s) {
......@@ -96,77 +94,13 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height,
* output_width]
*/
template <>
void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride, const std::vector<int> &padding,
framework::Tensor *col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const float *im_data = im.data<float>();
float *col_data = col->data<float>();
#if __ARM_NEON
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const float *local_im_data = im_data + ic * im_spatial_size;
float *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
}
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
}
#if __ARM_NEON
}
#endif
}
void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
void ExtractToImg<int8_t>(const int8_t *im_data, int8_t *col_data,
const int im_height, const int im_width,
const int col_height, const int col_width,
const int padding_h, const int padding_w,
const int stride_h, const int stride_w, const int kh,
const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
......@@ -183,21 +117,26 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(int8_t));
for (; s < extract - 15; s += 16) {
int8x16_t _img = vld1q_s8(im_data + s);
vst1q_s8(col_data + s, _img);
}
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x2_t img = vld2q_s8(im_data + s * 2);
vst1q_s8(col_data + s, img.val[0]);
int8x16x2_t _img = vld2q_s8(im_data + s * 2);
vst1q_s8(col_data + s, _img.val[0]);
}
#endif
for (; s < extract; ++s) {
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x3_t img = vld3q_s8(im_data + s * 3);
......@@ -208,7 +147,6 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
int8x16x4_t img = vld4q_s8(im_data + s * 4);
......@@ -232,11 +170,12 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
* [input_channels, filter_height, filter_width, output_height,
* output_width]
*/
template <>
void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride, const std::vector<int> &padding,
framework::Tensor *col) {
template <class T>
class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
public:
void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *col) {
int im_channels = im.dims()[0];
int im_height = im.dims()[1];
int im_width = im.dims()[2];
......@@ -246,24 +185,25 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>();
int8_t *col_data = col->mutable_data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
#if __ARM_NEON
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(int8_t));
memset(col_data, 0, col->numel() * sizeof(T));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const int8_t *local_im_data = im_data + ic * im_spatial_size;
int8_t *local_col_data =
const T *local_im_data = im_data + ic * im_spatial_size;
T *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
ExtractToImg<T>(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
......@@ -277,20 +217,81 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_col_idx =
w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
int im_idx =
(im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<int8_t>(0)
? static_cast<T>(0)
: im_data[im_idx];
}
}
}
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#if __ARM_NEON
}
#endif
}
};
template <>
void ExtendToImg<float>(const float *col_data, float *im_data,
const int im_height, const int im_width,
const int col_height, const int col_width,
const int padding_h, const int padding_w,
const int stride_h, const int stride_w, const int kh,
const int kw) {
int h = padding_h - kh;
int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
// int extract = (end_width - start_width + stride_w - 1) / stride_w;
int extend = end_width - start_width;
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) {
#if __ARM_NEON
for (; s < extend - 3; s += 4) {
float32x4_t _col = vld1q_f32(col_data + s);
float32x4_t _img = vld1q_f32(im_data + s);
_img = vaddq_f32(_img, _col);
vst1q_f32(im_data + s, _img);
}
#endif
for (; s < extend; ++s) {
im_data[s] += col_data[s];
}
} else if (stride_w == 2) {
#if __ARM_NEON
for (; s < extend - 7; s += 8) {
float32x4_t _col = vld1q_f32(col_data + s / 2);
float32x4x2_t _img = vld2q_f32(im_data + s);
_img.val[0] = vaddq_f32(_img.val[0], _col);
vst2q_f32(im_data + s, _img);
}
#endif
for (; s < extend; s += 2) {
im_data[s] += col_data[s / 2];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
}
/*
......@@ -306,8 +307,6 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
const std::vector<int> &dilation,
const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *im) {
// PADDLE_ENFORCE(im->dims().size() == 3);
// PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im->dims()[0];
int im_height = im->dims()[1];
int im_width = im->dims()[2];
......@@ -317,11 +316,31 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
int col_width = col.dims()[4];
int channels_col = im_channels * filter_height * filter_width;
T *im_data = im->data<T>();
const T *col_data = col.data<T>();
T *im_data = im->data<T>();
memset(static_cast<void *>(im_data), 0, sizeof(T) * im->numel());
#if __ARM_NEON
if (stride[0] <= 2 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
T *local_im_data = im_data + ic * im_spatial_size;
const T *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
}
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
......@@ -329,22 +348,27 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_col_idx =
w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
im_data[(im_row_idx + c_im * im_height) * im_width +
im_col_idx] +=
col_data[(c * col_height + h) * col_width + w];
}
}
}
}
#if __ARM_NEON
}
#endif
}
};
template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
template class Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
// template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
/*
* im = [input_channels, input_height, input_width]
......
......@@ -25,6 +25,18 @@ namespace math {
* Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 };
template <class T>
void ExtractToImg(const T *im_data, T *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw);
template <class T>
void ExtendToImg(const T *col_data, T *im_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw);
/*
* \brief Converts the image data of three dimensions(CHW) into a
* colData of
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册