提交 6beb66ab 编写于 作者: H hjchen2

Optimize col2im again

上级 5245344d
...@@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "operators/math/im2col.h"
#include <vector> #include <vector>
#ifdef __ARM_NEON #ifdef __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif #endif
#include <algorithm>
#include "common/types.h" #include "common/types.h"
#include "operators/math/im2col.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -294,6 +296,173 @@ void ExtendToImg<float>(const float *col_data, float *im_data, ...@@ -294,6 +296,173 @@ void ExtendToImg<float>(const float *col_data, float *im_data,
} }
} }
template <>
void ExtendToImgV2<float>(const float *col_data, float *im_data,
const int im_height, const int im_width,
const int col_height, const int col_width,
const int padding_h, const int padding_w,
const int stride_h, const int stride_w, const int kh,
const int kernel_w) {
int col_spatial_size = col_height * col_width;
int h = padding_h - kh;
int col_start_height = h > 0 ? (h + stride_h - 1) / stride_h : 0;
int start_height = kh + col_start_height * stride_h - padding_h;
int end_height = (col_height - col_start_height) * stride_h + start_height;
end_height = end_height > im_height ? im_height : end_height;
im_data += start_height * im_width;
col_data += col_start_height * col_width;
int kw = 0;
for (; kw < kernel_w - 1; kw += 2) {
int w0 = padding_w - kw;
int w1 = padding_w - (kw + 1);
int col_start_width0 = w0 > 0 ? (w0 + stride_w - 1) / stride_w : 0;
int col_start_width1 = w1 > 0 ? (w1 + stride_w - 1) / stride_w : 0;
int start_width0 = kw + col_start_width0 * stride_w - padding_w;
int start_width1 = (kw + 1) + col_start_width1 * stride_w - padding_w;
int end_width0 = (col_width - col_start_width0) * stride_w + start_width0;
end_width0 = end_width0 > im_width ? im_width : end_width0;
int end_width1 = (col_width - col_start_width1) * stride_w + start_width1;
end_width1 = end_width1 > im_width ? im_width : end_width1;
int start_width = 0;
int end_width = 0;
if (stride_w == 1) {
start_width = std::max(start_width0, start_width1);
end_width = std::min(end_width0, end_width1);
} else if (stride_w == 2) {
start_width = std::min(start_width0, start_width1);
end_width = std::min(end_width0, end_width1);
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
}
// DLOG << "start_width0: " << start_width0 << ", end_width0: " <<
// end_width0; DLOG << "start_width1: " << start_width1 << ", end_width1:
// " << end_width1;
int extend = end_width - start_width;
float *im_data01 = im_data + start_width;
float *im_data0 = im_data + start_width0;
float *im_data1 = im_data + start_width1;
const float *col_data0 = col_data + col_start_width0;
const float *col_data1 = col_data + col_spatial_size + col_start_width1;
for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) {
int offset0 = start_width - start_width0;
int offset1 = start_width - start_width1;
for (int ss = 0; ss < start_width - start_width0; ++ss) {
im_data0[ss] += col_data0[ss];
}
for (int ss = 0; ss < start_width - start_width1; ++ss) {
im_data1[ss] += col_data1[ss];
}
#if __ARM_NEON
for (; s < extend - 3; s += 4) {
float32x4_t _col0 = vld1q_f32(col_data0 + offset0 + s);
float32x4_t _col1 = vld1q_f32(col_data1 + offset1 + s);
float32x4_t _img = vld1q_f32(im_data01 + s);
_img = vaddq_f32(_img, _col0);
_img = vaddq_f32(_img, _col1);
vst1q_f32(im_data01 + s, _img);
}
#endif
for (int ss = s; ss < end_width0 - start_width0; ++ss) {
im_data0[ss] += col_data0[ss];
}
for (int ss = s; ss < end_width1 - start_width1; ++ss) {
im_data1[ss] += col_data1[ss];
}
} else if (stride_w == 2) {
if (start_width0 < start_width1) {
#if __ARM_NEON
for (; s < extend - 7; s += 8) {
float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
float32x4x2_t _img = vld2q_f32(im_data01 + s);
_img.val[0] = vaddq_f32(_img.val[0], _col0);
_img.val[1] = vaddq_f32(_img.val[1], _col1);
vst2q_f32(im_data01 + s, _img);
}
#endif
} else {
#if __ARM_NEON
for (; s < extend - 7; s += 8) {
float32x4_t _col0 = vld1q_f32(col_data0 + s / 2);
float32x4_t _col1 = vld1q_f32(col_data1 + s / 2);
float32x4x2_t _img = vld2q_f32(im_data01 + s);
_img.val[0] = vaddq_f32(_img.val[0], _col1);
_img.val[1] = vaddq_f32(_img.val[1], _col0);
vst2q_f32(im_data01 + s, _img);
}
#endif
}
for (int ss = s; ss < end_width0 - start_width0; ss += 2) {
im_data0[ss] += col_data0[ss / 2];
}
for (int ss = s; ss < end_width1 - start_width1; ss += 2) {
im_data1[ss] += col_data1[ss / 2];
}
}
im_data0 += im_width * stride_h;
im_data1 += im_width * stride_h;
im_data01 += im_width * stride_h;
col_data0 += col_width;
col_data1 += col_width;
}
col_data += 2 * col_spatial_size;
}
for (; kw < kernel_w; ++kw) {
int w = padding_w - kw;
int col_start_width = w > 0 ? (w + stride_w - 1) / stride_w : 0;
int start_width = kw + col_start_width * stride_w - padding_w;
int end_width = (col_width - col_start_width) * stride_w + start_width;
end_width = end_width > im_width ? im_width : end_width;
int extend = end_width - start_width;
float *im_data0 = im_data + start_width;
const float *col_data0 = col_data + col_start_width;
for (int i = start_height; i < end_height; i += stride_h) {
int s = 0;
if (stride_w == 1) {
#if __ARM_NEON
for (; s < extend - 3; s += 4) {
float32x4_t _col = vld1q_f32(col_data + s);
float32x4_t _img = vld1q_f32(im_data + s);
_img = vaddq_f32(_img, _col);
vst1q_f32(im_data + s, _img);
}
#endif
for (; s < extend; ++s) {
im_data[s] += col_data[s];
}
} else if (stride_w == 2) {
#if __ARM_NEON
for (; s < extend - 7; s += 8) {
float32x4_t _col = vld1q_f32(col_data + s / 2);
float32x4x2_t _img = vld2q_f32(im_data + s);
_img.val[0] = vaddq_f32(_img.val[0], _col);
vst2q_f32(im_data + s, _img);
}
#endif
for (; s < extend; s += 2) {
im_data[s] += col_data[s / 2];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1 and 2.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
col_data += col_spatial_size;
}
}
/* /*
* im = [input_channels, input_height, input_width] * im = [input_channels, input_height, input_width]
* col = * col =
...@@ -331,12 +500,19 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> { ...@@ -331,12 +500,19 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
const T *local_col_data = const T *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size; col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) { for (int kh = 0; kh < filter_height; ++kh) {
#if 0
for (int kw = 0; kw < filter_width; ++kw) { for (int kw = 0; kw < filter_width; ++kw) {
ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width, ExtendToImg<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, kw); stride[0], stride[1], kh, kw);
local_col_data += col_spatial_size; local_col_data += col_spatial_size;
} }
#else
ExtendToImgV2<T>(local_col_data, local_im_data, im_height, im_width,
col_height, col_width, padding[0], padding[1],
stride[0], stride[1], kh, filter_width);
local_col_data += col_spatial_size * filter_width;
#endif
} }
} }
} else { } else {
......
...@@ -37,6 +37,13 @@ void ExtendToImg(const T *col_data, T *im_data, const int im_height, ...@@ -37,6 +37,13 @@ void ExtendToImg(const T *col_data, T *im_data, const int im_height,
const int padding_h, const int padding_w, const int stride_h, const int padding_h, const int padding_w, const int stride_h,
const int stride_w, const int kh, const int kw); const int stride_w, const int kh, const int kw);
template <class T>
void ExtendToImgV2(const T *col_data, T *im_data, const int im_height,
const int im_width, const int col_height,
const int col_width, const int padding_h,
const int padding_w, const int stride_h, const int stride_w,
const int kh, const int kernel_w);
/* /*
* \brief Converts the image data of three dimensions(CHW) into a * \brief Converts the image data of three dimensions(CHW) into a
* colData of * colData of
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册