提交 6beb66ab 编写于 作者: H hjchen2

Optimize col2im again

上级 5245344d
......@@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/im2col.h"
#include <vector>
#ifdef __ARM_NEON
#include <arm_neon.h>
#endif
#include <algorithm>
#include "common/types.h"
#include "operators/math/im2col.h"
namespace paddle_mobile {
namespace operators {
namespace math {
......@@ -294,6 +296,173 @@ void ExtendToImg<float>(const float *col_data, float *im_data,
}
}
/*
 * \brief Accumulate (col2im) every kernel column of one kernel row `kh`
 *        from the column buffer into a single image plane.
 *
 * col_data points at `kernel_w` consecutive planes, each holding
 * col_height * col_width floats (one plane per kernel column kw).
 * im_data points at an im_height * im_width plane that is updated in place
 * (values are added, never overwritten).
 *
 * For kernel tap (kh, kw), column element (ch, cw) is added to the pixel
 *   (ch * stride_h + kh - padding_h, cw * stride_w + kw - padding_w)
 * whenever that pixel lies inside the image.
 *
 * Kernel columns are processed two at a time so that both taps can share a
 * single load/store of the image row; an odd trailing column is processed on
 * its own.
 *
 * Only stride_w == 1 and stride_w == 2 are supported; any other value raises
 * PADDLE_MOBILE_THROW_EXCEPTION.
 */
template <>
void ExtendToImgV2<float>(const float *col_data, float *im_data,
                          const int im_height, const int im_width,
                          const int col_height, const int col_width,
                          const int padding_h, const int padding_w,
                          const int stride_h, const int stride_w, const int kh,
                          const int kernel_w) {
  const int col_spatial_size = col_height * col_width;

  // Vertical range: skip column rows that would land in the top padding and
  // clip the bottom against the image height.
  const int h = padding_h - kh;
  const int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
  const int start_height = kh + col_start_height * stride_h - padding_h;
  int end_height = (col_height - col_start_height) * stride_h + start_height;
  end_height = end_height > im_height ? im_height : end_height;

  im_data += start_height * im_width;
  col_data += col_start_height * col_width;

  int kw = 0;
  // Process kernel columns in pairs (kw, kw + 1).
  for (; kw < kernel_w - 1; kw += 2) {
    // Horizontal range of each tap: first valid column index in the column
    // buffer, matching first image column, and exclusive end image column.
    const int w0 = padding_w - kw;
    const int w1 = padding_w - (kw + 1);
    const int col_start_width0 = w0 > 0 ? (w0 + stride_w - 1) / stride_w : 0;
    const int col_start_width1 = w1 > 0 ? (w1 + stride_w - 1) / stride_w : 0;
    const int start_width0 = kw + col_start_width0 * stride_w - padding_w;
    const int start_width1 =
        (kw + 1) + col_start_width1 * stride_w - padding_w;
    int end_width0 = (col_width - col_start_width0) * stride_w + start_width0;
    end_width0 = end_width0 > im_width ? im_width : end_width0;
    int end_width1 = (col_width - col_start_width1) * stride_w + start_width1;
    end_width1 = end_width1 > im_width ? im_width : end_width1;

    // [start_width, end_width) is the image span that is handled jointly for
    // both taps by the vectorized loops.
    int start_width = 0;
    int end_width = 0;
    if (stride_w == 1) {
      start_width = std::max(start_width0, start_width1);
      end_width = std::min(end_width0, end_width1);
    } else if (stride_w == 2) {
      start_width = std::min(start_width0, start_width1);
      end_width = std::min(end_width0, end_width1);
    } else {
      PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
    }

    const int extend = end_width - start_width;
    const int width0 = end_width0 - start_width0;
    const int width1 = end_width1 - start_width1;

    // Row cursors; advanced once per processed image row.
    float *im_data01 = im_data + start_width;
    float *im_data0 = im_data + start_width0;
    float *im_data1 = im_data + start_width1;
    const float *col_data0 = col_data + col_start_width0;
    const float *col_data1 = col_data + col_spatial_size + col_start_width1;

    for (int i = start_height; i < end_height; i += stride_h) {
      int s = 0;
      if (stride_w == 1) {
        // Elements of each tap that precede the shared span.
        const int offset0 = start_width - start_width0;
        const int offset1 = start_width - start_width1;
        // Clamp the heads to the tap's own valid width so a degenerate
        // (empty) range can never be written.
        const int head0 = std::min(offset0, width0);
        const int head1 = std::min(offset1, width1);
        for (int ss = 0; ss < head0; ++ss) {
          im_data0[ss] += col_data0[ss];
        }
        for (int ss = 0; ss < head1; ++ss) {
          im_data1[ss] += col_data1[ss];
        }
#if __ARM_NEON
        for (; s < extend - 3; s += 4) {
          float32x4_t _col0 = vld1q_f32(col_data0 + offset0 + s);
          float32x4_t _col1 = vld1q_f32(col_data1 + offset1 + s);
          float32x4_t _img = vld1q_f32(im_data01 + s);
          _img = vaddq_f32(_img, _col0);
          _img = vaddq_f32(_img, _col1);
          vst1q_f32(im_data01 + s, _img);
        }
#endif
        // Scalar tails. The shared loop consumed `s` elements starting at
        // `offset` within each tap, so resume at offset + s; resuming at `s`
        // would accumulate already-processed elements a second time.
        for (int ss = offset0 + s; ss < width0; ++ss) {
          im_data0[ss] += col_data0[ss];
        }
        for (int ss = offset1 + s; ss < width1; ++ss) {
          im_data1[ss] += col_data1[ss];
        }
      } else if (stride_w == 2) {
        // With stride 2 the two taps hit interleaved image columns, so a
        // de-interleaving load gives one lane set per tap. Which tap owns the
        // even lanes depends on which one starts first.
        if (start_width0 < start_width1) {
#if __ARM_NEON
          for (; s < extend - 7; s += 8) {
            float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
            float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
            float32x4x2_t _img = vld2q_f32(im_data01 + s);
            _img.val[0] = vaddq_f32(_img.val[0], _col0);
            _img.val[1] = vaddq_f32(_img.val[1], _col1);
            vst2q_f32(im_data01 + s, _img);
          }
#endif
        } else {
#if __ARM_NEON
          for (; s < extend - 7; s += 8) {
            float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
            float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
            float32x4x2_t _img = vld2q_f32(im_data01 + s);
            _img.val[0] = vaddq_f32(_img.val[0], _col1);
            _img.val[1] = vaddq_f32(_img.val[1], _col0);
            vst2q_f32(im_data01 + s, _img);
          }
#endif
        }
        // Scalar tails: `s` is even, and index ss relative to each tap's own
        // start maps to column element ss / 2.
        for (int ss = s; ss < width0; ss += 2) {
          im_data0[ss] += col_data0[ss / 2];
        }
        for (int ss = s; ss < width1; ss += 2) {
          im_data1[ss] += col_data1[ss / 2];
        }
      }
      im_data0 += im_width * stride_h;
      im_data1 += im_width * stride_h;
      im_data01 += im_width * stride_h;
      col_data0 += col_width;
      col_data1 += col_width;
    }
    col_data += 2 * col_spatial_size;
  }

  // Remaining single kernel column (kernel_w odd).
  for (; kw < kernel_w; ++kw) {
    const int w = padding_w - kw;
    const int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
    const int start_width = kw + col_start_width * stride_w - padding_w;
    int end_width = (col_width - col_start_width) * stride_w + start_width;
    end_width = end_width > im_width ? im_width : end_width;
    const int extend = end_width - start_width;

    // Local row cursors that include the horizontal offsets. The base
    // pointers im_data / col_data must stay untouched inside the row loop.
    float *im_data0 = im_data + start_width;
    const float *col_data0 = col_data + col_start_width;

    for (int i = start_height; i < end_height; i += stride_h) {
      int s = 0;
      if (stride_w == 1) {
#if __ARM_NEON
        for (; s < extend - 3; s += 4) {
          float32x4_t _col = vld1q_f32(col_data0 + s);
          float32x4_t _img = vld1q_f32(im_data0 + s);
          _img = vaddq_f32(_img, _col);
          vst1q_f32(im_data0 + s, _img);
        }
#endif
        for (; s < extend; ++s) {
          im_data0[s] += col_data0[s];
        }
      } else if (stride_w == 2) {
#if __ARM_NEON
        for (; s < extend - 7; s += 8) {
          float32x4_t _col = vld1q_f32(col_data0 + s / 2);
          float32x4x2_t _img = vld2q_f32(im_data0 + s);
          _img.val[0] = vaddq_f32(_img.val[0], _col);
          vst2q_f32(im_data0 + s, _img);
        }
#endif
        for (; s < extend; s += 2) {
          im_data0[s] += col_data0[s / 2];
        }
      } else {
        PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
      }
      im_data0 += im_width * stride_h;
      col_data0 += col_width;
    }
    col_data += col_spatial_size;
  }
}
/*
* im = [input_channels, input_height, input_width]
* col =
......@@ -331,12 +500,19 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
const T *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
#if 0
for (int kw = 0; kw < filter_width; ++kw) {
ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size;
}
#else
ExtendToImgV2<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, filter_width);
local_col_data += col_spatial_size * filter_width;
#endif
}
}
} else {
......
......@@ -37,6 +37,13 @@ void ExtendToImg(const T *col_data, T *im_data, const int im_height,
const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw);
/*
 * \brief Accumulate all kernel columns of one kernel row into an image
 *        plane (col2im for a whole kernel row at once).
 *
 * Unlike ExtendToImg, which takes a single (kh, kw) kernel tap, this variant
 * takes the kernel row index `kh` and the kernel width `kernel_w`, and
 * consumes `kernel_w` consecutive planes of col_height * col_width elements
 * starting at col_data. Values are added into im_data in place.
 *
 * NOTE(review): the float specialization only accepts stride_w of 1 or 2 and
 * raises an exception otherwise - confirm whether other specializations
 * exist with different limits.
 */
template <class T>
void ExtendToImgV2(const T *col_data, T *im_data, const int im_height,
const int im_width, const int col_height,
const int col_width, const int padding_h,
const int padding_w, const int stride_h, const int stride_w,
const int kh, const int kernel_w);
/*
* \brief Converts the image data of three dimensions(CHW) into a
* colData of
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册