Unverified commit 73209b72, authored by Houjiang Chen, committed by GitHub

Merge pull request #1505 from hjchen2/backup

Optimize general col2im to speed up transpose conv
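The general col2im path is the bottleneck this PR targets for transposed convolution: the previous implementation accumulated one element at a time, while this change adds NEON-vectorized ExtendToImg / ExtendToImgV2 helpers that process whole output rows for strides 1 and 2 (other strides and dilations fall back to the scalar loop). For reference, a minimal scalar sketch of the CFO-layout col2im being optimized (dilation fixed at 1; hypothetical free-standing helper, not the paddle-mobile API):

// Minimal scalar sketch of CFO col2im (dilation fixed at 1); hypothetical
// free-standing helper, not the paddle-mobile API.
//   im  : [channels, im_height, im_width]
//   col : [channels, kernel_h, kernel_w, col_height, col_width]
#include <cstring>

void Col2ImReference(const float *col, float *im, int channels, int im_height,
                     int im_width, int kernel_h, int kernel_w, int col_height,
                     int col_width, int pad_h, int pad_w, int stride_h,
                     int stride_w) {
  std::memset(im, 0, sizeof(float) * channels * im_height * im_width);
  // Each "column channel" c encodes (input channel, kernel row, kernel col).
  for (int c = 0; c < channels * kernel_h * kernel_w; ++c) {
    int kw = c % kernel_w;
    int kh = (c / kernel_w) % kernel_h;
    int ic = c / (kernel_w * kernel_h);
    for (int oh = 0; oh < col_height; ++oh) {
      int ih = oh * stride_h - pad_h + kh;
      for (int ow = 0; ow < col_width; ++ow) {
        int iw = ow * stride_w - pad_w + kw;
        if (ih >= 0 && ih < im_height && iw >= 0 && iw < im_width) {
          im[(ic * im_height + ih) * im_width + iw] +=
              col[(c * col_height + oh) * col_width + ow];
        }
      }
    }
  }
}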
@@ -56,10 +56,9 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
              param->Strides()[0] == param->Strides()[1] &&
              param->Dilations()[0] == param->Dilations()[1] &&
              param->Strides()[0] == 1 && param->Dilations()[0] == 1
-#if 0
-             && param->Output()->dims()[1] >= 16 &&
-             param->Input()->dims()[1] >= 16 &&
-             param->Input()->dims()[2] <= 140 */ /* refered from ncnn */
+#if 1
+             && (param->Input()->dims()[1] >= 4 ||
+                 param->Output()->dims()[1] >= 16)
 #endif
              ) {
     param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
...
@@ -12,20 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "operators/math/im2col.h"
 #include <vector>
 #ifdef __ARM_NEON
 #include <arm_neon.h>
 #endif
+#include <algorithm>
 #include "common/types.h"
+#include "operators/math/im2col.h"
 
 namespace paddle_mobile {
 namespace operators {
 namespace math {
 
-void ExtractToImg(const float *im_data, float *col_data, const int im_height,
-                  const int im_width, const int col_height, const int col_width,
-                  const int padding_h, const int padding_w, const int stride_h,
-                  const int stride_w, const int kh, const int kw) {
+template <>
+void ExtractToImg<float>(const float *im_data, float *col_data,
+                         const int im_height, const int im_width,
+                         const int col_height, const int col_width,
+                         const int padding_h, const int padding_w,
+                         const int stride_h, const int stride_w, const int kh,
+                         const int kw) {
   int h = padding_h - kh;
   int w = padding_w - kw;
   int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
@@ -41,48 +46,43 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
   im_data += start_height * im_width + start_width;
   col_data += col_start_height * col_width + col_start_width;
   for (int i = start_height; i < end_height; i += stride_h) {
+    int s = 0;
     if (stride_w == 1) {
-      // memcpy(col_data, im_data, extract * sizeof(float));
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 3; s += 4) {
-        float32x4_t img = vld1q_f32(im_data + s);
-        vst1q_f32(col_data + s, img);
+        float32x4_t _img = vld1q_f32(im_data + s);
+        vst1q_f32(col_data + s, _img);
       }
 #endif
       for (; s < extract; ++s) {
         col_data[s] = im_data[s];
       }
     } else if (stride_w == 2) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 3; s += 4) {
-        float32x4x2_t img = vld2q_f32(im_data + s * 2);
-        vst1q_f32(col_data + s, img.val[0]);
+        float32x4x2_t _img = vld2q_f32(im_data + s * 2);
+        vst1q_f32(col_data + s, _img.val[0]);
      }
 #endif
       for (; s < extract; ++s) {
         col_data[s] = im_data[s * 2];
       }
     } else if (stride_w == 3) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 3; s += 4) {
-        float32x4x3_t img = vld3q_f32(im_data + s * 3);
-        vst1q_f32(col_data + s, img.val[0]);
+        float32x4x3_t _img = vld3q_f32(im_data + s * 3);
+        vst1q_f32(col_data + s, _img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
         col_data[s] = im_data[s * 3];
       }
     } else if (stride_w == 4) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 3; s += 4) {
-        float32x4x4_t img = vld4q_f32(im_data + s * 4);
-        vst1q_f32(col_data + s, img.val[0]);
+        float32x4x4_t _img = vld4q_f32(im_data + s * 4);
+        vst1q_f32(col_data + s, _img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
@@ -96,77 +96,13 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
   }
 }
 
-/*
- * im = [input_channels, input_height, input_width]
- * col =
- *   [input_channels, filter_height, filter_width, output_height,
- *   output_width]
- */
 template <>
-void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
-    const framework::Tensor &im, const std::vector<int> &dilation,
-    const std::vector<int> &stride, const std::vector<int> &padding,
-    framework::Tensor *col) {
-  int im_channels = im.dims()[0];
-  int im_height = im.dims()[1];
-  int im_width = im.dims()[2];
-  int filter_height = col->dims()[1];
-  int filter_width = col->dims()[2];
-  int col_height = col->dims()[3];
-  int col_width = col->dims()[4];
-
-  int channels_col = im_channels * filter_height * filter_width;
-  const float *im_data = im.data<float>();
-  float *col_data = col->data<float>();
-#if __ARM_NEON
-  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
-    int im_spatial_size = im_height * im_width;
-    int col_spatial_size = col_height * col_width;
-    // pad 0
-    memset(col_data, 0, col->numel() * sizeof(float));
-#pragma omp parallel for
-    for (int ic = 0; ic < im_channels; ++ic) {
-      const float *local_im_data = im_data + ic * im_spatial_size;
-      float *local_col_data =
-          col_data + ic * filter_height * filter_width * col_spatial_size;
-      for (int kh = 0; kh < filter_height; ++kh) {
-        for (int kw = 0; kw < filter_width; ++kw) {
-          ExtractToImg(local_im_data, local_col_data, im_height, im_width,
-                       col_height, col_width, padding[0], padding[1], stride[0],
-                       stride[1], kh, kw);
-          local_col_data += col_spatial_size;
-        }
-      }
-    }
-  } else {
-#endif
-    for (int c = 0; c < channels_col; ++c) {
-      int w_offset = c % filter_width;
-      int h_offset = (c / filter_width) % filter_height;
-      int c_im = c / (filter_width * filter_height);
-      for (int h = 0; h < col_height; ++h) {
-        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
-        for (int w = 0; w < col_width; ++w) {
-          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
-          int col_idx = (c * col_height + h) * col_width + w;
-          int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
-          col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
-                               im_col_idx < 0 || im_col_idx >= im_width)
-                                  ? static_cast<float>(0)
-                                  : im_data[im_idx];
-        }
-      }
-    }
-#if __ARM_NEON
-  }
-#endif
-}
-
-void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
-                  const int im_width, const int col_height, const int col_width,
-                  const int padding_h, const int padding_w, const int stride_h,
-                  const int stride_w, const int kh, const int kw) {
+void ExtractToImg<int8_t>(const int8_t *im_data, int8_t *col_data,
+                          const int im_height, const int im_width,
+                          const int col_height, const int col_width,
+                          const int padding_h, const int padding_w,
+                          const int stride_h, const int stride_w, const int kh,
+                          const int kw) {
   int h = padding_h - kh;
   int w = padding_w - kw;
   int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
@@ -183,21 +119,26 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
   im_data += start_height * im_width + start_width;
   col_data += col_start_height * col_width + col_start_width;
   for (int i = start_height; i < end_height; i += stride_h) {
+    int s = 0;
     if (stride_w == 1) {
-      memcpy(col_data, im_data, extract * sizeof(int8_t));
+      for (; s < extract - 15; s += 16) {
+        int8x16_t _img = vld1q_s8(im_data + s);
+        vst1q_s8(col_data + s, _img);
+      }
+      for (; s < extract; ++s) {
+        col_data[s] = im_data[s];
+      }
     } else if (stride_w == 2) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 15; s += 16) {
-        int8x16x2_t img = vld2q_s8(im_data + s * 2);
-        vst1q_s8(col_data + s, img.val[0]);
+        int8x16x2_t _img = vld2q_s8(im_data + s * 2);
+        vst1q_s8(col_data + s, _img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
         col_data[s] = im_data[s * 2];
       }
     } else if (stride_w == 3) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 15; s += 16) {
         int8x16x3_t img = vld3q_s8(im_data + s * 3);
@@ -208,7 +149,6 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
         col_data[s] = im_data[s * 3];
       }
     } else if (stride_w == 4) {
-      int s = 0;
 #if __ARM_NEON
       for (; s < extract - 15; s += 16) {
         int8x16x4_t img = vld4q_s8(im_data + s * 4);
@@ -232,65 +172,295 @@ void ExtractToImg(const int8_t *im_data, int8_t *col_data, const int im_height,
  *   [input_channels, filter_height, filter_width, output_height,
  *   output_width]
  */
-template <>
-void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
-    const framework::Tensor &im, const std::vector<int> &dilation,
-    const std::vector<int> &stride, const std::vector<int> &padding,
-    framework::Tensor *col) {
-  int im_channels = im.dims()[0];
-  int im_height = im.dims()[1];
-  int im_width = im.dims()[2];
-  int filter_height = col->dims()[1];
-  int filter_width = col->dims()[2];
-  int col_height = col->dims()[3];
-  int col_width = col->dims()[4];
-
-  int channels_col = im_channels * filter_height * filter_width;
-  const int8_t *im_data = im.data<int8_t>();
-  int8_t *col_data = col->mutable_data<int8_t>();
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
-    int im_spatial_size = im_height * im_width;
-    int col_spatial_size = col_height * col_width;
-    // pad 0
-    memset(col_data, 0, col->numel() * sizeof(int8_t));
-#pragma omp parallel for
-    for (int ic = 0; ic < im_channels; ++ic) {
-      const int8_t *local_im_data = im_data + ic * im_spatial_size;
-      int8_t *local_col_data =
-          col_data + ic * filter_height * filter_width * col_spatial_size;
-      for (int kh = 0; kh < filter_height; ++kh) {
-        for (int kw = 0; kw < filter_width; ++kw) {
-          ExtractToImg(local_im_data, local_col_data, im_height, im_width,
-                       col_height, col_width, padding[0], padding[1], stride[0],
-                       stride[1], kh, kw);
-          local_col_data += col_spatial_size;
-        }
-      }
-    }
-  } else {
-#endif
-    for (int c = 0; c < channels_col; ++c) {
-      int w_offset = c % filter_width;
-      int h_offset = (c / filter_width) % filter_height;
-      int c_im = c / (filter_width * filter_height);
-      for (int h = 0; h < col_height; ++h) {
-        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
-        for (int w = 0; w < col_width; ++w) {
-          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
-          int col_idx = (c * col_height + h) * col_width + w;
-          int im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
-
-          col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
-                               im_col_idx < 0 || im_col_idx >= im_width)
-                                  ? static_cast<int8_t>(0)
-                                  : im_data[im_idx];
-        }
-      }
-    }
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  }
-#endif
+template <class T>
+class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
+ public:
+  void operator()(const framework::Tensor &im, const std::vector<int> &dilation,
+                  const std::vector<int> &stride,
+                  const std::vector<int> &padding, framework::Tensor *col) {
+    int im_channels = im.dims()[0];
+    int im_height = im.dims()[1];
+    int im_width = im.dims()[2];
+    int filter_height = col->dims()[1];
+    int filter_width = col->dims()[2];
+    int col_height = col->dims()[3];
+    int col_width = col->dims()[4];
+
+    int channels_col = im_channels * filter_height * filter_width;
+    const T *im_data = im.data<T>();
+    T *col_data = col->data<T>();
+#if __ARM_NEON
+    if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
+      int im_spatial_size = im_height * im_width;
+      int col_spatial_size = col_height * col_width;
+      // pad 0
+      memset(col_data, 0, col->numel() * sizeof(T));
+
+#pragma omp parallel for
+      for (int ic = 0; ic < im_channels; ++ic) {
+        const T *local_im_data = im_data + ic * im_spatial_size;
+        T *local_col_data =
+            col_data + ic * filter_height * filter_width * col_spatial_size;
+        for (int kh = 0; kh < filter_height; ++kh) {
+          for (int kw = 0; kw < filter_width; ++kw) {
+            ExtractToImg<T>(local_im_data, local_col_data, im_height, im_width,
+                            col_height, col_width, padding[0], padding[1],
+                            stride[0], stride[1], kh, kw);
+            local_col_data += col_spatial_size;
+          }
+        }
+      }
+    } else {
+#endif
+      for (int c = 0; c < channels_col; ++c) {
+        int w_offset = c % filter_width;
+        int h_offset = (c / filter_width) % filter_height;
+        int c_im = c / (filter_width * filter_height);
+        for (int h = 0; h < col_height; ++h) {
+          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
+          for (int w = 0; w < col_width; ++w) {
+            int im_col_idx =
+                w * stride[1] - padding[1] + w_offset * dilation[1];
+            int col_idx = (c * col_height + h) * col_width + w;
+            int im_idx =
+                (im_row_idx + c_im * im_height) * im_width + im_col_idx;
+            col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
+                                 im_col_idx < 0 || im_col_idx >= im_width)
+                                    ? static_cast<T>(0)
+                                    : im_data[im_idx];
+          }
+        }
+      }
+#if __ARM_NEON
+    }
+#endif
+  }
+};
+
+template <>
+void ExtendToImg<float>(const float *col_data, float *im_data,
+                        const int im_height, const int im_width,
+                        const int col_height, const int col_width,
+                        const int padding_h, const int padding_w,
+                        const int stride_h, const int stride_w, const int kh,
+                        const int kw) {
+  int h = padding_h - kh;
+  int w = padding_w - kw;
+  int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
+  int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
+  int start_height = kh + col_start_height * stride_h - padding_h;
+  int start_width = kw + col_start_width * stride_w - padding_w;
+  int end_height = (col_height - col_start_height) * stride_h + start_height;
+  end_height = end_height > im_height ? im_height : end_height;
+  int end_width = (col_width - col_start_width) * stride_w + start_width;
+  end_width = end_width > im_width ? im_width : end_width;
+  // int extract = (end_width - start_width + stride_w - 1) / stride_w;
+  int extend = end_width - start_width;
+  im_data += start_height * im_width + start_width;
+  col_data += col_start_height * col_width + col_start_width;
+  for (int i = start_height; i < end_height; i += stride_h) {
+    int s = 0;
+    if (stride_w == 1) {
+#if __ARM_NEON
+      for (; s < extend - 3; s += 4) {
+        float32x4_t _col = vld1q_f32(col_data + s);
+        float32x4_t _img = vld1q_f32(im_data + s);
+        _img = vaddq_f32(_img, _col);
+        vst1q_f32(im_data + s, _img);
+      }
+#endif
+      for (; s < extend; ++s) {
+        im_data[s] += col_data[s];
+      }
+    } else if (stride_w == 2) {
+#if __ARM_NEON
+      for (; s < extend - 7; s += 8) {
+        float32x4_t _col = vld1q_f32(col_data + s / 2);
+        float32x4x2_t _img = vld2q_f32(im_data + s);
+        _img.val[0] = vaddq_f32(_img.val[0], _col);
+        vst2q_f32(im_data + s, _img);
+      }
+#endif
+      for (; s < extend; s += 2) {
+        im_data[s] += col_data[s / 2];
+      }
+    } else {
+      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
+    }
+    im_data += im_width * stride_h;
+    col_data += col_width;
+  }
+}
+
+template <>
+void ExtendToImgV2<float>(const float *col_data, float *im_data,
+                          const int im_height, const int im_width,
+                          const int col_height, const int col_width,
+                          const int padding_h, const int padding_w,
+                          const int stride_h, const int stride_w, const int kh,
+                          const int kernel_w) {
+  int col_spatial_size = col_height * col_width;
+  int h = padding_h - kh;
+  int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
+  int start_height = kh + col_start_height * stride_h - padding_h;
+  int end_height = (col_height - col_start_height) * stride_h + start_height;
+  end_height = end_height > im_height ? im_height : end_height;
+  im_data += start_height * im_width;
+  col_data += col_start_height * col_width;
+
+  int kw = 0;
+  for (; kw < kernel_w - 1; kw += 2) {
+    int w0 = padding_w - kw;
+    int w1 = padding_w - (kw + 1);
+    int col_start_width0 = w0 > 0 ? (w0 + stride_w - 1) / stride_w : 0;
+    int col_start_width1 = w1 > 0 ? (w1 + stride_w - 1) / stride_w : 0;
+    int start_width0 = kw + col_start_width0 * stride_w - padding_w;
+    int start_width1 = (kw + 1) + col_start_width1 * stride_w - padding_w;
+    int end_width0 = (col_width - col_start_width0) * stride_w + start_width0;
+    end_width0 = end_width0 > im_width ? im_width : end_width0;
+    int end_width1 = (col_width - col_start_width1) * stride_w + start_width1;
+    end_width1 = end_width1 > im_width ? im_width : end_width1;
+    int start_width = 0;
+    int end_width = 0;
+    if (stride_w == 1) {
+      start_width = std::max(start_width0, start_width1);
+      end_width = std::min(end_width0, end_width1);
+    } else if (stride_w == 2) {
+      start_width = std::min(start_width0, start_width1);
+      end_width = std::min(end_width0, end_width1);
+    } else {
+      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
+    }
+    // DLOG << "start_width0: " << start_width0 << ", end_width0: " <<
+    // end_width0; DLOG << "start_width1: " << start_width1 << ", end_width1:
+    // " << end_width1;
+    int extend = end_width - start_width;
+    float *im_data01 = im_data + start_width;
+    float *im_data0 = im_data + start_width0;
+    float *im_data1 = im_data + start_width1;
+    const float *col_data0 = col_data + col_start_width0;
+    const float *col_data1 = col_data + col_spatial_size + col_start_width1;
+    for (int i = start_height; i < end_height; i += stride_h) {
+      int s = 0;
+      if (stride_w == 1) {
+        int offset0 = start_width - start_width0;
+        int offset1 = start_width - start_width1;
+        for (int ss = 0; ss < start_width - start_width0; ++ss) {
+          im_data0[ss] += col_data0[ss];
+        }
+        for (int ss = 0; ss < start_width - start_width1; ++ss) {
+          im_data1[ss] += col_data1[ss];
+        }
+#if __ARM_NEON
+        for (; s < extend - 3; s += 4) {
+          float32x4_t _col0 = vld1q_f32(col_data0 + offset0 + s);
+          float32x4_t _col1 = vld1q_f32(col_data1 + offset1 + s);
+          float32x4_t _img = vld1q_f32(im_data01 + s);
+          _img = vaddq_f32(_img, _col0);
+          _img = vaddq_f32(_img, _col1);
+          vst1q_f32(im_data01 + s, _img);
+        }
+#endif
+        for (int ss = s; ss < end_width0 - start_width0; ++ss) {
+          im_data0[ss] += col_data0[ss];
+        }
+        for (int ss = s; ss < end_width1 - start_width1; ++ss) {
+          im_data1[ss] += col_data1[ss];
+        }
+      } else if (stride_w == 2) {
+        if (start_width0 < start_width1) {
+#if __ARM_NEON
+          for (; s < extend - 7; s += 8) {
+            float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
+            float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
+            float32x4x2_t _img = vld2q_f32(im_data01 + s);
+            _img.val[0] = vaddq_f32(_img.val[0], _col0);
+            _img.val[1] = vaddq_f32(_img.val[1], _col1);
+            vst2q_f32(im_data01 + s, _img);
+          }
+#endif
+        } else {
+#if __ARM_NEON
+          for (; s < extend - 7; s += 8) {
+            float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
+            float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
+            float32x4x2_t _img = vld2q_f32(im_data01 + s);
+            _img.val[0] = vaddq_f32(_img.val[0], _col1);
+            _img.val[1] = vaddq_f32(_img.val[1], _col0);
+            vst2q_f32(im_data01 + s, _img);
+          }
+#endif
+        }
+        for (int ss = s; ss < end_width0 - start_width0; ss += 2) {
+          im_data0[ss] += col_data0[ss / 2];
+        }
+        for (int ss = s; ss < end_width1 - start_width1; ss += 2) {
+          im_data1[ss] += col_data1[ss / 2];
+        }
+      }
+      im_data0 += im_width * stride_h;
+      im_data1 += im_width * stride_h;
+      im_data01 += im_width * stride_h;
+      col_data0 += col_width;
+      col_data1 += col_width;
+    }
+    col_data += 2 * col_spatial_size;
+  }
+  for (; kw < kernel_w; ++kw) {
+    int w = padding_w - kw;
+    int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
+    int start_width = kw + col_start_width * stride_w - padding_w;
+    int end_width = (col_width - col_start_width) * stride_w + start_width;
+    end_width = end_width > im_width ? im_width : end_width;
+    int extend = end_width - start_width;
+    float *im_data0 = im_data + start_width;
+    const float *col_data0 = col_data + col_start_width;
+    for (int i = start_height; i < end_height; i += stride_h) {
+      int s = 0;
+      if (stride_w == 1) {
+#if __ARM_NEON
+        for (; s < extend - 3; s += 4) {
+          float32x4_t _col = vld1q_f32(col_data + s);
+          float32x4_t _img = vld1q_f32(im_data + s);
+          _img = vaddq_f32(_img, _col);
+          vst1q_f32(im_data + s, _img);
+        }
+#endif
+        for (; s < extend; ++s) {
+          im_data[s] += col_data[s];
+        }
+      } else if (stride_w == 2) {
+#if __ARM_NEON
+        for (; s < extend - 7; s += 8) {
+          float32x4_t _col = vld1q_f32(col_data + s / 2);
+          float32x4x2_t _img = vld2q_f32(im_data + s);
+          _img.val[0] = vaddq_f32(_img.val[0], _col);
+          vst2q_f32(im_data + s, _img);
+        }
+#endif
+        for (; s < extend; s += 2) {
+          im_data[s] += col_data[s / 2];
+        }
+      } else {
+        PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
+      }
+      im_data += im_width * stride_h;
+      col_data += col_width;
+    }
+    col_data += col_spatial_size;
+  }
 }
 
 /*
@@ -306,8 +476,6 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
                   const std::vector<int> &dilation,
                   const std::vector<int> &stride,
                   const std::vector<int> &padding, framework::Tensor *im) {
-    // PADDLE_ENFORCE(im->dims().size() == 3);
-    // PADDLE_ENFORCE(col.dims().size() == 5);
     int im_channels = im->dims()[0];
     int im_height = im->dims()[1];
     int im_width = im->dims()[2];
@@ -317,34 +485,66 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
     int col_width = col.dims()[4];
 
     int channels_col = im_channels * filter_height * filter_width;
-    T *im_data = im->data<T>();
     const T *col_data = col.data<T>();
+    T *im_data = im->data<T>();
     memset(static_cast<void *>(im_data), 0, sizeof(T) * im->numel());
-    for (int c = 0; c < channels_col; ++c) {
-      int w_offset = c % filter_width;
-      int h_offset = (c / filter_width) % filter_height;
-      int c_im = c / (filter_width * filter_height);
-      for (int h = 0; h < col_height; ++h) {
-        int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
-        for (int w = 0; w < col_width; ++w) {
-          int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
-          if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
-              (im_col_idx) >= 0 && (im_col_idx) < im_width) {
-            im_data[(im_row_idx + c_im * im_height) * im_width + im_col_idx] +=
-                col_data[(c * col_height + h) * col_width + w];
-          }
-        }
-      }
+#if __ARM_NEON
+    if (stride[0] <= 2 && dilation[0] == 1 && dilation[0] == dilation[1]) {
+      int im_spatial_size = im_height * im_width;
+      int col_spatial_size = col_height * col_width;
+
+#pragma omp parallel for
+      for (int ic = 0; ic < im_channels; ++ic) {
+        T *local_im_data = im_data + ic * im_spatial_size;
+        const T *local_col_data =
+            col_data + ic * filter_height * filter_width * col_spatial_size;
+        for (int kh = 0; kh < filter_height; ++kh) {
+#if 0
+          for (int kw = 0; kw < filter_width; ++kw) {
+            ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width,
+                           col_height, col_width, padding[0], padding[1],
+                           stride[0], stride[1], kh, kw);
+            local_col_data += col_spatial_size;
+          }
+#else
+          ExtendToImgV2<T>(local_col_data, local_im_data, im_height, im_width,
+                           col_height, col_width, padding[0], padding[1],
+                           stride[0], stride[1], kh, filter_width);
+          local_col_data += col_spatial_size * filter_width;
+#endif
+        }
+      }
+    } else {
+#endif
+      for (int c = 0; c < channels_col; ++c) {
+        int w_offset = c % filter_width;
+        int h_offset = (c / filter_width) % filter_height;
+        int c_im = c / (filter_width * filter_height);
+        for (int h = 0; h < col_height; ++h) {
+          int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
+          for (int w = 0; w < col_width; ++w) {
+            int im_col_idx =
+                w * stride[1] - padding[1] + w_offset * dilation[1];
+            if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
+                (im_col_idx) >= 0 && (im_col_idx) < im_width) {
+              im_data[(im_row_idx + c_im * im_height) * im_width +
+                      im_col_idx] +=
+                  col_data[(c * col_height + h) * col_width + w];
+            }
+          }
+        }
+      }
+#if __ARM_NEON
     }
+#endif
   }
 };
 
 template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
 template class Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>;
 template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
-template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
+// template class Col2ImFunctor<ColFormat::kCFO, CPU, int8_t>;
 
 /*
  * im = [input_channels, input_height, input_width]
...
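For reference, ExtendToImg above accumulates one (kh, kw) slice of the col buffer into the image; the boundary handling is hoisted into precomputed start/end indices and the row loop is vectorized, with vld2q_f32/vst2q_f32 de-interleaving even and odd image columns so the stride-2 case also reduces to contiguous vector adds. ExtendToImgV2 additionally fuses kernel columns kw and kw+1 per pass, so for stride 2 the two col rows land in the two lanes of a single de-interleaved load/store. A scalar model of the per-offset accumulation (illustrative names, not the library API):

// Scalar model of one ExtendToImg call: accumulate the col slice for kernel
// offset (kh, kw) into the image. Illustrative only; bounds checks are kept
// inline here, whereas the NEON version precomputes start/end indices.
void ExtendToImgScalar(const float *col_data, float *im_data, int im_height,
                       int im_width, int col_height, int col_width,
                       int padding_h, int padding_w, int stride_h,
                       int stride_w, int kh, int kw) {
  for (int oh = 0; oh < col_height; ++oh) {
    int ih = oh * stride_h - padding_h + kh;  // image row fed by this col row
    if (ih < 0 || ih >= im_height) continue;
    for (int ow = 0; ow < col_width; ++ow) {
      int iw = ow * stride_w - padding_w + kw;  // image column
      if (iw < 0 || iw >= im_width) continue;
      im_data[ih * im_width + iw] += col_data[oh * col_width + ow];
    }
  }
}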
@@ -25,6 +25,25 @@ namespace math {
  * Col2ImFunctor. */
 enum class ColFormat { kCFO = 0, kOCF = 1 };
 
+template <class T>
+void ExtractToImg(const T *im_data, T *col_data, const int im_height,
+                  const int im_width, const int col_height, const int col_width,
+                  const int padding_h, const int padding_w, const int stride_h,
+                  const int stride_w, const int kh, const int kw);
+
+template <class T>
+void ExtendToImg(const T *col_data, T *im_data, const int im_height,
+                 const int im_width, const int col_height, const int col_width,
+                 const int padding_h, const int padding_w, const int stride_h,
+                 const int stride_w, const int kh, const int kw);
+
+template <class T>
+void ExtendToImgV2(const T *col_data, T *im_data, const int im_height,
+                   const int im_width, const int col_height,
+                   const int col_width, const int padding_h,
+                   const int padding_w, const int stride_h, const int stride_w,
+                   const int kh, const int kernel_w);
+
 /*
  * \brief Converts the image data of three dimensions(CHW) into a
  * colData of
...