提交 c0f7ecb4 编写于 作者: H hjchen2

Optimize general col2im to speed up transpose conv

上级 33de575e
...@@ -22,10 +22,13 @@ namespace paddle_mobile { ...@@ -22,10 +22,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
void ExtractToImg(const float *im_data, float *col_data, const int im_height, template <>
const int im_width, const int col_height, const int col_width, void ExtractToImg<float>(const float *im_data, float *col_data,
const int padding_h, const int padding_w, const int stride_h, const int im_height, const int im_width,
const int stride_w, const int kh, const int kw) { const int col_height, const int col_width,
const int padding_h, const int padding_w,
const int stride_h, const int stride_w, const int kh,
const int kw) {
int h = padding_h - kh; int h = padding_h - kh;
int w = padding_w - kw; int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
...@@ -41,48 +44,43 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, ...@@ -41,48 +44,43 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
im_data += start_height * im_width + start_width; im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width; col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) { for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) { if (stride_w == 1) {
// memcpy(col_data, im_data, extract * sizeof(float));
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 3; s += 4) { for (; s < extract - 3; s += 4) {
float32x4_t img = vld1q_f32(im_data + s); float32x4_t _img = vld1q_f32(im_data + s);
vst1q_f32(col_data + s, img); vst1q_f32(col_data + s, _img);
} }
#endif #endif
for (; s < extract; ++s) { for (; s < extract; ++s) {
col_data[s] = im_data[s]; col_data[s] = im_data[s];
} }
} else if (stride_w == 2) { } else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 3; s += 4) { for (; s < extract - 3; s += 4) {
float32x4x2_t img = vld2q_f32(im_data + s * 2); float32x4x2_t _img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, img.val[0]); vst1q_f32(col_data + s, _img.val[0]);
} }
#endif #endif
for (; s < extract; ++s) { for (; s < extract; ++s) {
col_data[s] = im_data[s * 2]; col_data[s] = im_data[s * 2];
} }
} else if (stride_w == 3) { } else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 3; s += 4) { for (; s < extract - 3; s += 4) {
float32x4x3_t img = vld3q_f32(im_data + s * 3); float32x4x3_t _img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, img.val[0]); vst1q_f32(col_data + s, _img.val[0]);
} }
#endif #endif
for (; s < extract; ++s) { for (; s < extract; ++s) {
col_data[s] = im_data[s * 3]; col_data[s] = im_data[s * 3];
} }
} else if (stride_w == 4) { } else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 3; s += 4) { for (; s < extract - 3; s += 4) {
float32x4x4_t img = vld4q_f32(im_data + s * 4); float32x4x4_t _img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, img.val[0]); vst1q_f32(col_data + s, _img.val[0]);
} }
#endif #endif
for (; s < extract; ++s) { for (; s < extract; ++s) {
...@@ -96,77 +94,13 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, ...@@ -96,77 +94,13 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
} }
} }
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height,
* output_width]
*/
template <> template <>
void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()( void ExtractToImg<int8_t>(const int8_t *im_data, int8_t *col_data,
const framework::Tensor &im, const std::vector<int> &dilation, const int im_height, const int im_width,
const std::vector<int> &stride, const std::vector<int> &padding, const int col_height, const int col_width,
framework::Tensor *col) { const int padding_h, const int padding_w,
int im_channels = im.dims()[0]; const int stride_h, const int stride_w, const int kh,
int im_height = im.dims()[1]; const int kw) {
int im_width = im.dims()[2];
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const float *im_data = im.data<float>();
float *col_data = col->data<float>();
#if __ARM_NEON
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const float *local_im_data = im_data + ic * im_spatial_size;
float *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
}
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<float>(0)
: im_data[im_idx];
}
}
}
#if __ARM_NEON
}
#endif
}
void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
const int im_width, const int col_height, const int col_width,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw) {
int h = padding_h - kh; int h = padding_h - kh;
int w = padding_w - kw; int w = padding_w - kw;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0; int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
...@@ -183,21 +117,26 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, ...@@ -183,21 +117,26 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
im_data += start_height * im_width + start_width; im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width; col_data += col_start_height * col_width + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) { for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) { if (stride_w == 1) {
memcpy(col_data, im_data, extract * sizeof(int8_t)); for (; s < extract - 15; s += 16) {
int8x16_t _img = vld1q_s8(im_data + s);
vst1q_s8(col_data + s, _img);
}
for (; s < extract; ++s) {
col_data[s] = im_data[s];
}
} else if (stride_w == 2) { } else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 15; s += 16) { for (; s < extract - 15; s += 16) {
int8x16x2_t img = vld2q_s8(im_data + s * 2); int8x16x2_t _img = vld2q_s8(im_data + s * 2);
vst1q_s8(col_data + s, img.val[0]); vst1q_s8(col_data + s, _img.val[0]);
} }
#endif #endif
for (; s < extract; ++s) { for (; s < extract; ++s) {
col_data[s] = im_data[s * 2]; col_data[s] = im_data[s * 2];
} }
} else if (stride_w == 3) { } else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 15; s += 16) { for (; s < extract - 15; s += 16) {
int8x16x3_t img = vld3q_s8(im_data + s * 3); int8x16x3_t img = vld3q_s8(im_data + s * 3);
...@@ -208,7 +147,6 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, ...@@ -208,7 +147,6 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
col_data[s] = im_data[s * 3]; col_data[s] = im_data[s * 3];
} }
} else if (stride_w == 4) { } else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON #if __ARM_NEON
for (; s < extract - 15; s += 16) { for (; s < extract - 15; s += 16) {
int8x16x4_t img = vld4q_s8(im_data + s * 4); int8x16x4_t img = vld4q_s8(im_data + s * 4);
...@@ -232,65 +170,128 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height, ...@@ -232,65 +170,128 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
* [input_channels, filter_height, filter_width, output_height, * [input_channels, filter_height, filter_width, output_height,
* output_width] * output_width]
*/ */
template <> template <class T>
void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()( class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
const framework::Tensor &im, const std::vector<int> &dilation, public:
const std::vector<int> &stride, const std::vector<int> &padding, void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
framework::Tensor *col) { const std::vector<int> &stride,
int im_channels = im.dims()[0]; const std::vector<int> &padding, framework::Tensor *col) {
int im_height = im.dims()[1]; int im_channels = im.dims()[0];
int im_width = im.dims()[2]; int im_height = im.dims()[1];
int filter_height = col->dims()[1]; int im_width = im.dims()[2];
int filter_width = col->dims()[2]; int filter_height = col->dims()[1];
int col_height = col->dims()[3]; int filter_width = col->dims()[2];
int col_width = col->dims()[4]; int col_height = col->dims()[3];
int col_width = col->dims()[4];
int channels_col = im_channels * filter_height * filter_width;
const int8_t *im_data = im.data<int8_t>(); int channels_col = im_channels * filter_height * filter_width;
int8_t *col_data = col->mutable_data<int8_t>(); const T *im_data = im.data<T>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON) T *col_data = col->data<T>();
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { #if __ARM_NEON
int im_spatial_size = im_height * im_width; if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int col_spatial_size = col_height * col_width; int im_spatial_size = im_height * im_width;
// pad 0 int col_spatial_size = col_height * col_width;
memset(col_data, 0, col->numel() * sizeof(int8_t)); // pad 0
#pragma omp parallel for memset(col_data, 0, col->numel() * sizeof(T));
for (int ic = 0; ic < im_channels; ++ic) {
const int8_t *local_im_data = im_data + ic * im_spatial_size; #pragma omp parallel for
int8_t *local_col_data = for (int ic = 0; ic < im_channels; ++ic) {
col_data + ic * filter_height * filter_width * col_spatial_size; const T *local_im_data = im_data + ic * im_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) { T *local_col_data =
for (int kw = 0; kw < filter_width; ++kw) { col_data + ic * filter_height * filter_width * col_spatial_size;
ExtractToImg(local_im_data, local_col_data, im_height, im_width, for (int kh = 0; kh < filter_height; ++kh) {
col_height, col_width, padding[0], padding[1], stride[0], for (int kw = 0; kw < filter_width; ++kw) {
stride[1], kh, kw); ExtractToImg<T>(local_im_data, local_col_data, im_height, im_width,
local_col_data += col_spatial_size; col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size;
}
} }
} }
} } else {
} else {
#endif #endif
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height); int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) { for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) { for (int w = 0; w < col_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_col_idx =
int col_idx = (c * col_height + h) * col_width + w; w * stride[1] - padding[1] + w_offset * dilation[1];
int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; int col_idx = (c * col_height + h) * col_width + w;
int im_idx =
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height || (im_row_idx + c_im * im_height) * im_width + im_col_idx;
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<int8_t>(0) col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
: im_data[im_idx]; im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
: im_data[im_idx];
}
} }
} }
#if __ARM_NEON
} }
#if defined(__ARM_NEON__) || defined(__ARM_NEON) #endif
} }
};
// Scatter-add one (kh, kw) kernel-tap slice of the col buffer back into a
// single image channel (the accumulation step of col2im):
//   im[row, col] += col[h, w] for every output position (h, w) whose strided
//   mapping through this kernel tap lands inside the image.
//
// col_data:  pointer to one [col_height, col_width] slice of the col tensor
// im_data:   pointer to one [im_height, im_width] image channel (accumulated
//            into; caller is expected to have initialized it)
// padding_h / padding_w, stride_h / stride_w: convolution padding and strides
// kh / kw:   kernel-tap offsets selecting which col slice this call handles
//
// Only stride_w == 1 and stride_w == 2 are supported; other values throw.
template <>
void ExtendToImg<float>(const float *col_data, float *im_data,
                        const int im_height, const int im_width,
                        const int col_height, const int col_width,
                        const int padding_h, const int padding_w,
                        const int stride_h, const int stride_w, const int kh,
                        const int kw) {
  // First col row/column whose mapped image coordinate is inside the image,
  // i.e. skip output positions that fall into the top/left padding.
  int h = padding_h - kh;
  int w = padding_w - kw;
  int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
  int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
  // Matching first image coordinates, with the end clamped to the image size
  // so the bottom/right padding region is excluded as well.
  int start_height = kh + col_start_height * stride_h - padding_h;
  int start_width = kw + col_start_width * stride_w - padding_w;
  int end_height = (col_height - col_start_height) * stride_h + start_height;
  end_height = end_height > im_height ? im_height : end_height;
  int end_width = (col_width - col_start_width) * stride_w + start_width;
  end_width = end_width > im_width ? im_width : end_width;
  // int extract = (end_width - start_width + stride_w - 1) / stride_w;
  // Width of the image span touched per row; each col element covers
  // stride_w image columns, so `extend` counts image columns, not col ones.
  int extend = end_width - start_width;
  // Advance both pointers to the first valid (row, column) of their buffers.
  im_data += start_height * im_width + start_width;
  col_data += col_start_height * col_width + col_start_width;
  for (int i = start_height; i < end_height; i += stride_h) {
    int s = 0;  // offset in image columns within the current row
    if (stride_w == 1) {
#if __ARM_NEON
      // Vectorized im += col, 4 floats per iteration.
      for (; s < extend - 3; s += 4) {
        float32x4_t _col = vld1q_f32(col_data + s);
        float32x4_t _img = vld1q_f32(im_data + s);
        _img = vaddq_f32(_img, _col);
        vst1q_f32(im_data + s, _img);
      }
#endif
      // Scalar tail (or full loop when NEON is unavailable).
      for (; s < extend; ++s) {
        im_data[s] += col_data[s];
      }
    } else if (stride_w == 2) {
#if __ARM_NEON
      // De-interleave 8 image floats into even/odd lanes, add 4 col values
      // to the even lanes (the stride-2 targets), store back interleaved.
      for (; s < extend - 7; s += 8) {
        float32x4_t _col = vld1q_f32(col_data + s / 2);
        float32x4x2_t _img = vld2q_f32(im_data + s);
        _img.val[0] = vaddq_f32(_img.val[0], _col);
        vst2q_f32(im_data + s, _img);
      }
#endif
      // Scalar tail: every second image column receives one col value.
      for (; s < extend; s += 2) {
        im_data[s] += col_data[s / 2];
      }
    } else {
      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
    }
    // Next contributing image row (stride_h apart) / next col row.
    im_data += im_width * stride_h;
    col_data += col_width;
  }
} }
/* /*
...@@ -306,8 +307,6 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> { ...@@ -306,8 +307,6 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
const std::vector<int> &dilation, const std::vector<int> &dilation,
const std::vector<int> &stride, const std::vector<int> &stride,
const std::vector<int> &padding, framework::Tensor *im) { const std::vector<int> &padding, framework::Tensor *im) {
// PADDLE_ENFORCE(im->dims().size() == 3);
// PADDLE_ENFORCE(col.dims().size() == 5);
int im_channels = im->dims()[0]; int im_channels = im->dims()[0];
int im_height = im->dims()[1]; int im_height = im->dims()[1];
int im_width = im->dims()[2]; int im_width = im->dims()[2];
...@@ -317,34 +316,59 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> { ...@@ -317,34 +316,59 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
int col_width = col.dims()[4]; int col_width = col.dims()[4];
int channels_col = im_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
T *im_data = im->data<T>();
const T *col_data = col.data<T>(); const T *col_data = col.data<T>();
T *im_data = im->data<T>();
memset(static_cast<void *>(im_data), 0, sizeof(T) * im->numel()); memset(static_cast<void *>(im_data), 0, sizeof(T) * im->numel());
for (int c = 0; c < channels_col; ++c) { #if __ARM_NEON
int w_offset = c % filter_width; if (stride[0] <= 2 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int h_offset = (c / filter_width) % filter_height; int im_spatial_size = im_height * im_width;
int c_im = c / (filter_width * filter_height); int col_spatial_size = col_height * col_width;
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0]; #pragma omp parallel for
for (int w = 0; w < col_width; ++w) { for (int ic = 0; ic < im_channels; ++ic) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; T *local_im_data = im_data + ic * im_spatial_size;
if ((im_row_idx) >= 0 && (im_row_idx) < im_height && const T *local_col_data =
(im_col_idx) >= 0 && (im_col_idx) < im_width) { col_data + ic * filter_height * filter_width * col_spatial_size;
im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] += for (int kh = 0; kh < filter_height; ++kh) {
col_data[(c * col_height + h) * col_width + w]; for (int kw = 0; kw < filter_width; ++kw) {
ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size;
} }
} }
} }
} else {
#endif
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx =
w * stride[1] - padding[1] + w_offset * dilation[1];
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
im_data[(im_row_idx + c_im * im_height) * im_width +
im_col_idx] +=
col_data[(c * col_height + h) * col_width + w];
}
}
}
}
#if __ARM_NEON
} }
#endif
} }
}; };
template class Im2ColFunctor<ColFormat::kCFO, CPU, float>; template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
template class Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>; template class Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, float>; template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>; // template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
/* /*
* im = [input_channels, input_height, input_width] * im = [input_channels, input_height, input_width]
......
...@@ -25,6 +25,18 @@ namespace math { ...@@ -25,6 +25,18 @@ namespace math {
* Col2ImFunctor. */ * Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 }; enum class ColFormat { kCFO = 0, kOCF = 1 };
// Copies the (kh, kw) kernel-tap slice of one image channel into the col
// buffer (the strided gather step of im2col), honoring padding and strides.
// NOTE(review): implemented via explicit specializations in the .cpp file;
// only float and int8_t specializations are visible — confirm before
// instantiating with other element types.
template <class T>
void ExtractToImg(const T *im_data, T *col_data, const int im_height,
                  const int im_width, const int col_height, const int col_width,
                  const int padding_h, const int padding_w, const int stride_h,
                  const int stride_w, const int kh, const int kw);

// Inverse of ExtractToImg: scatter-adds the (kh, kw) col slice back into the
// image channel (the accumulation step of col2im). Note the reversed
// argument roles: col_data is the source, im_data the accumulated target.
template <class T>
void ExtendToImg(const T *col_data, T *im_data, const int im_height,
                 const int im_width, const int col_height, const int col_width,
                 const int padding_h, const int padding_w, const int stride_h,
                 const int stride_w, const int kh, const int kw);
/* /*
* \brief Converts the image data of three dimensions(CHW) into a * \brief Converts the image data of three dimensions(CHW) into a
* colData of * colData of
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册