提交 4664b8f5 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #429 from Eclipsess/develop

fix #428 optimize im2col
...@@ -14,8 +14,10 @@ limitations under the License. */ ...@@ -14,8 +14,10 @@ limitations under the License. */
#include "operators/math/im2col.h" #include "operators/math/im2col.h"
#include <vector> #include <vector>
#ifdef __ARM_NEON
#include "arm_neon.h"
#endif
#include "common/types.h" #include "common/types.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
namespace math { namespace math {
...@@ -65,9 +67,350 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> { ...@@ -65,9 +67,350 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
// are " "inconsistent."); // are " "inconsistent.");
int channels_col = im_channels * filter_height * filter_width; int channels_col = im_channels * filter_height * filter_width;
const T *im_data = im.data<T>(); const T *im_data = im.data<T>();
T *col_data = col->data<T>(); T *col_data = col->data<T>();
#ifdef __ARM_NEON
const int osize = col_height;
const int isize = im_height;
bool pad1 = padding[0] > 0;
bool pad2 =
(pad1 &&
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
float *col0 = col_data + 0 * oosize + 2 * osize + 2;
float *col1 = col_data + 1 * oosize + 2 * osize + 1;
float *col2 = col_data + 2 * oosize + 2 * osize;
float *col3 = col_data + 3 * oosize + osize + 2;
float *col4 = col_data + 4 * oosize + osize + 1;
float *col5 = col_data + 5 * oosize + osize;
float *col6 = col_data + 6 * oosize + 2;
float *col7 = col_data + 7 * oosize + 1;
float *col8 = col_data + 8 * oosize;
float32x4_t im1;
const float *im_tmp_data = im_data + osize + 1;
int rrsize = oosize - osize - 1;
int nr4 = rrsize / 4;
int mr4 = rrsize % 4;
for (int i = 0; i < nr4; ++i) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
vst1q_f32(col6, im1);
vst1q_f32(col7, im1);
vst1q_f32(col8, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mr4; ++i) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
*col6 = *im_tmp_data;
*col7 = *im_tmp_data;
*col8 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data++;
}
im_tmp_data = im_data + 1;
col0 = col_data + 0 * oosize + osize + 2;
col1 = col_data + 1 * oosize + osize + 1;
col2 = col_data + 2 * oosize + osize;
col3 = col_data + 3 * oosize + 2;
col4 = col_data + 4 * oosize + 1;
col5 = col_data + 5 * oosize;
for (int i = 0; i < nk4; i++) {
im1 = vld1q_f32(im_tmp_data);
vst1q_f32(col0, im1);
vst1q_f32(col1, im1);
vst1q_f32(col2, im1);
vst1q_f32(col3, im1);
vst1q_f32(col4, im1);
vst1q_f32(col5, im1);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
im_tmp_data += 4;
}
for (int i = 0; i < mk4; i++) {
*col0 = *im_tmp_data;
*col1 = *im_tmp_data;
*col2 = *im_tmp_data;
*col3 = *im_tmp_data;
*col4 = *im_tmp_data;
*col5 = *im_tmp_data;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
im_tmp_data++;
}
// fill 0 1 11;
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
col_data[0 * oosize + osize + 1] = im_data[0];
col_data[3 * oosize + 1] = im_data[0];
col_data[6 * oosize + 1] = im_data[osize];
col_data[1 * oosize + osize] = im_data[0];
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[osize];
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
col_data += 9 * oosize;
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 &&
dilation[0] == 1) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
int mk4 = osize % 4;
// 3 2 3 1 0 1 3 2 3
float *col0 = col_data + 0 * oosize + osize + 1;
float *col1 = col_data + 1 * oosize + osize;
float *col2 = col_data + 2 * oosize + osize;
float *col3 = col_data + 3 * oosize + 1;
float *col4 = col_data + 4 * oosize;
float *col5 = col_data + 5 * oosize;
float *col6 = col_data + 6 * oosize + 1;
float *col7 = col_data + 7 * oosize;
float *col8 = col_data + 8 * oosize;
float32x4x2_t im01;
float32x4x2_t im23;
const float *im_tmp_data0 = im_data;
const float *im_tmp_data2 = im_data + isize;
for (int j = 0; j < osize; ++j) {
for (int i = 0; i < nk4; ++i) {
im01 = vld2q_f32(im_tmp_data0);
im23 = vld2q_f32(im_tmp_data2);
vst1q_f32(col0, im23.val[1]);
vst1q_f32(col1, im23.val[0]);
vst1q_f32(col2, im23.val[1]);
vst1q_f32(col3, im01.val[1]);
vst1q_f32(col4, im01.val[0]);
vst1q_f32(col5, im01.val[1]);
vst1q_f32(col6, im23.val[1]);
vst1q_f32(col7, im23.val[0]);
vst1q_f32(col8, im23.val[1]);
col0 += 4;
col1 += 4;
col2 += 4;
col3 += 4;
col4 += 4;
col5 += 4;
col6 += 4;
col7 += 4;
col8 += 4;
im_tmp_data0 += 8;
im_tmp_data2 += 8;
}
const float *im_tmp_data1 = im_tmp_data0 + 1;
const float *im_tmp_data3 = im_tmp_data2 + 1;
for (int i = 0; i < mk4; ++i) {
*col0 = *im_tmp_data3;
*col1 = *im_tmp_data2;
*col2 = *im_tmp_data3;
*col3 = *im_tmp_data1;
*col4 = *im_tmp_data0;
*col5 = *im_tmp_data1;
*col6 = *im_tmp_data3;
*col7 = *im_tmp_data2;
*col8 = *im_tmp_data3;
col0++;
col1++;
col2++;
col3++;
col4++;
col5++;
col6++;
col7++;
col8++;
im_tmp_data0 += 2;
im_tmp_data1 += 2;
im_tmp_data2 += 2;
im_tmp_data3 += 2;
}
im_tmp_data0 += (isize - fill);
im_tmp_data2 += (isize - fill);
}
for (int i = 0; i < osize; ++i) {
col_data[0 * oosize + i * osize] = 0.0;
col_data[3 * oosize + i * osize] = 0.0;
col_data[6 * oosize + i * osize] = 0.0;
if (pad2) {
col_data[2 * oosize + osize - 1 + i * osize] = 0.0;
col_data[5 * oosize + osize - 1 + i * osize] = 0.0;
col_data[8 * oosize + osize - 1 + i * osize] = 0.0;
}
}
float32x4_t zero4;
zero4 = vdupq_n_f32(0.0);
auto col_z0 = col_data;
auto col_z1 = col_data + oosize;
auto col_z2 = col_data + 2 * oosize;
auto col_z6 = col_data + 6 * oosize + osize * (osize - 1);
auto col_z7 = col_data + 7 * oosize + osize * (osize - 1);
auto col_z8 = col_data + 8 * oosize + osize * (osize - 1);
for (int i = 0; i < nk4; ++i) {
vst1q_f32(col_z0, zero4);
vst1q_f32(col_z1, zero4);
vst1q_f32(col_z2, zero4);
if (pad2) {
vst1q_f32(col_z6, zero4);
vst1q_f32(col_z7, zero4);
vst1q_f32(col_z8, zero4);
}
col_z0 += 4;
col_z1 += 4;
col_z2 += 4;
col_z6 += 4;
col_z7 += 4;
col_z8 += 4;
}
for (int i = 0; i < mk4; ++i) {
col_z0[i] = 0.0;
col_z1[i] = 0.0;
col_z2[i] = 0.0;
if (pad2) {
col_z6[i] = 0.0;
col_z7[i] = 0.0;
col_z8[i] = 0.0;
}
}
col_data[1 * oosize + osize] = im_data[isize];
for (int i = 1; i < osize; ++i) {
col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1];
}
col_data[4 * oosize] = im_data[0];
col_data[7 * oosize] = im_data[isize];
col_data += 9 * oosize;
im_data += isize * isize;
}
} else {
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / (filter_width * filter_height);
for (int h = 0; h < col_height; ++h) {
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
for (int w = 0; w < col_width; ++w) {
int im_col_idx =
w * stride[1] - padding[1] + w_offset * dilation[1];
int col_idx = (c * col_height + h) * col_width + w;
int im_idx =
(im_row_idx + c_im * im_height) * im_width + im_col_idx;
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
im_col_idx < 0 || im_col_idx >= im_width)
? static_cast<T>(0)
: im_data[im_idx];
}
}
}
}
#else
for (int c = 0; c < channels_col; ++c) { for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width; int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height; int h_offset = (c / filter_width) % filter_height;
...@@ -86,6 +429,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> { ...@@ -86,6 +429,7 @@ class Im2ColFunctor<ColFormat::kCFO, CPU, T> {
} }
} }
} }
#endif
} }
}; };
...@@ -158,7 +502,7 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> { ...@@ -158,7 +502,7 @@ class Col2ImFunctor<ColFormat::kCFO, CPU, T> {
}; };
template class Im2ColFunctor<ColFormat::kCFO, CPU, float>; template class Im2ColFunctor<ColFormat::kCFO, CPU, float>;
template class Im2ColFunctor<ColFormat::kCFO, CPU, double>; // template class Im2ColFunctor<ColFormat::kCFO, CPU, double>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, float>; template class Col2ImFunctor<ColFormat::kCFO, CPU, float>;
template class Col2ImFunctor<ColFormat::kCFO, CPU, double>; template class Col2ImFunctor<ColFormat::kCFO, CPU, double>;
......
#!/usr/bin/env sh #!/usr/bin/env sh
# auto build and run # auto build and run
BUILDNET="googlenet" BUILDNET="mobilenetssd"
TESTUNIT="test-googlenet" TESTUNIT="test-mobilenetssd"
push_fn () { push_fn () {
sh build.sh android ${BUILDNET} sh build.sh android ${BUILDNET}
MODELS_PATH="../test/models/*" MODELS_PATH="../test/models/*"
MODELS_SRC="../../test/models" MODELS_SRC="../test/models"
IMAGE_PATH="../test/images/*" IMAGE_PATH="../test/images/*"
EXE_FILE="../test/build/*" EXE_FILE="../test/build/*"
EXE_DIR="data/local/tmp/bin" EXE_DIR="data/local/tmp/bin"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册