Commit 00350575 authored by H hjchen2

Fix im2col bug and support running winograd convolution with multiple threads

Parent 4645c6dd
@@ -166,7 +166,8 @@ void ConvCompute(const ConvParam<CPU> &param) {
       param.Strides()[0] == param.Strides()[1] &&
       param.Dilations()[0] == param.Dilations()[1] &&
       param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
-      param.Dilations()[0] == 1 && param.Input()->dims()[1] >= 16) {
+      param.Dilations()[0] == 1 && param.Output()->dims()[1] >= 16 &&
+      param.Output()->dims()[2] >= 16) {
     BatchConv3x3Winograd(param);
   } else {
     ConvBasic<float, float>(param);
...
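Note on the hunk above: in NCHW layout `param.Output()->dims()[1]` is the output channel count and `dims()[2]` is the output height, so the Winograd fast path is now gated on the output shape instead of the input channel count. A minimal standalone sketch of the rule, with illustrative parameter names rather than the paddle-mobile API:

    // Hedged sketch of the dispatch predicate, assuming square, equal
    // strides/dilations as already checked by the surrounding condition.
    static bool ShouldUseWinograd(int filter_size, int stride, int dilation,
                                  int out_channels, int out_height) {
      return filter_size == 3 && stride == 1 && dilation == 1 &&
             out_channels >= 16 && out_height >= 16;
    }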
@@ -41,48 +41,48 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
   im_data += start_height * im_width + start_width;
   col_data += col_start_height * col_width + col_start_width;
-  #pragma omp parallel for
   for (int i = start_height; i < end_height; i += stride_h) {
-    const float *local_im_data = im_data + i * im_width * stride_h;
-    float *local_col_data = col_data + col_width;
     if (stride_w == 1) {
-      memcpy(local_col_data, local_im_data, extract * sizeof(float));
+      memcpy(col_data, im_data, extract * sizeof(float));
     } else if (stride_w == 2) {
       int s = 0;
 #if __ARM_NEON
-      for (; s < extract - 15; s += 16) {
-        float32x4x2_t img = vld2q_f32(local_im_data + s * 2);
-        vst1q_f32(local_col_data + s, img.val[0]);
+      for (; s < extract - 3; s += 4) {
+        float32x4x2_t img = vld2q_f32(im_data + s * 2);
+        vst1q_f32(col_data + s, img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
-        local_col_data[s] = local_im_data[s * 2];
+        col_data[s] = im_data[s * 2];
       }
     } else if (stride_w == 3) {
       int s = 0;
 #if __ARM_NEON
-      for (; s < extract - 15; s += 16) {
-        float32x4x3_t img = vld3q_f32(local_im_data + s * 3);
-        vst1q_f32(local_col_data + s, img.val[0]);
+      for (; s < extract - 3; s += 4) {
+        float32x4x3_t img = vld3q_f32(im_data + s * 3);
+        vst1q_f32(col_data + s, img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
-        local_col_data[s] = local_im_data[s * 3];
+        col_data[s] = im_data[s * 3];
       }
     } else if (stride_w == 4) {
       int s = 0;
 #if __ARM_NEON
-      for (; s < extract - 15; s += 16) {
-        float32x4x4_t img = vld4q_f32(local_im_data + s * 4);
-        vst1q_f32(local_col_data + s, img.val[0]);
+      for (; s < extract - 3; s += 4) {
+        float32x4x4_t img = vld4q_f32(im_data + s * 4);
+        vst1q_f32(col_data + s, img.val[0]);
       }
 #endif
       for (; s < extract; ++s) {
-        local_col_data[s] = local_im_data[s * 4];
+        col_data[s] = im_data[s * 4];
       }
     } else {
       PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
     }
+    im_data += im_width * stride_h;
+    col_data += col_width;
   }
 }
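The im2col bug fixed above is the NEON loop step: `vld2q_f32` loads eight interleaved floats and splits them into two 4-lane vectors, so each iteration produces exactly four column elements and the induction variable must advance by 4; the old `s += 16` therefore wrote only 4 of every 16 outputs (and likewise for the `vld3q_f32`/`vld4q_f32` cases). A self-contained sketch of the corrected stride-2 copy, assuming the source row holds at least 2 * n floats:

    #include <arm_neon.h>

    // Copy every second element of src into dst, producing n outputs.
    void CopyEverySecond(const float *src, float *dst, int n) {
      int s = 0;
      for (; s < n - 3; s += 4) {
        // Loads src[2s .. 2s+7], deinterleaving even/odd elements into lanes.
        float32x4x2_t v = vld2q_f32(src + s * 2);
        vst1q_f32(dst + s, v.val[0]);  // keep the even-indexed elements
      }
      for (; s < n; ++s) {
        dst[s] = src[s * 2];  // scalar tail for the last n % 4 elements
      }
    }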
@@ -428,18 +428,23 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
       im_data += isize * isize;
     }
   } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
+    int im_spatial_size = im_height * im_width;
+    int col_spatial_size = col_height * col_width;
     // pad 0
     memset(col_data, 0, col->numel() * sizeof(float));
+    #pragma omp parallel for
     for (int ic = 0; ic < im_channels; ++ic) {
+      const float *local_im_data = im_data + ic * im_spatial_size;
+      float *local_col_data =
+          col_data + ic * filter_height * filter_width * col_spatial_size;
       for (int kh = 0; kh < filter_height; ++kh) {
         for (int kw = 0; kw < filter_width; ++kw) {
-          ExtractToImg(im_data, col_data, im_height, im_width, col_height,
-                       col_width, padding[0], padding[1], stride[0], stride[1],
-                       kh, kw);
-          col_data += col_height * col_width;
+          ExtractToImg(local_im_data, local_col_data, im_height, im_width,
+                       col_height, col_width, padding[0], padding[1], stride[0],
+                       stride[1], kh, kw);
+          local_col_data += col_spatial_size;
         }
       }
-      im_data += im_height * im_width;
     }
   } else {
 #endif
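The other half of the fix shows in this hunk: the `#pragma omp parallel for` moves from the row loop inside ExtractToImg up to the channel loop, and each iteration now derives `local_im_data`/`local_col_data` from `ic` instead of advancing the shared `im_data`/`col_data` pointers, a loop-carried dependence that OpenMP cannot split across threads. A minimal sketch of the pattern, with a hypothetical `extract_patch` standing in for ExtractToImg:

    #include <omp.h>
    #include <cstring>

    // Hypothetical stand-in for ExtractToImg: fills one filter-offset plane.
    static void extract_patch(const float *src, float *dst, int n) {
      std::memcpy(dst, src, n * sizeof(float));
    }

    void Im2ColOverChannels(const float *im, float *col, int channels,
                            int im_size, int col_size, int ksize) {
      // Iterations are independent once each computes its own base pointers,
      // so the channel loop is safe to parallelize.
      #pragma omp parallel for
      for (int ic = 0; ic < channels; ++ic) {
        const float *im_c = im + ic * im_size;
        float *col_c = col + ic * ksize * col_size;
        for (int k = 0; k < ksize; ++k, col_c += col_size) {
          extract_patch(im_c, col_c, col_size);
        }
      }
    }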
@@ -553,18 +558,23 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
   int8_t *col_data = col->mutable_data<int8_t>();
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
   if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
+    int im_spatial_size = im_height * im_width;
+    int col_spatial_size = col_height * col_width;
     // pad 0
     memset(col_data, 0, col->numel() * sizeof(int8_t));
+    #pragma omp parallel for
     for (int ic = 0; ic < im_channels; ++ic) {
+      const int8_t *local_im_data = im_data + ic * im_spatial_size;
+      int8_t *local_col_data =
+          col_data + ic * filter_height * filter_width * col_spatial_size;
       for (int kh = 0; kh < filter_height; ++kh) {
         for (int kw = 0; kw < filter_width; ++kw) {
-          ExtractToImg(im_data, col_data, im_height, im_width, col_height,
-                       col_width, padding[0], padding[1], stride[0], stride[1],
-                       kh, kw);
-          col_data += col_height * col_width;
+          ExtractToImg(local_im_data, local_col_data, im_height, im_width,
+                       col_height, col_width, padding[0], padding[1], stride[0],
+                       stride[1], kh, kw);
+          local_col_data += col_spatial_size;
         }
       }
-      im_data += im_height * im_width;
     }
   } else {
 #endif
...
@@ -30,6 +30,9 @@ void winograd_f6k3(const framework::Tensor &input,
                    const framework::Tensor &weight, framework::Tensor *output) {
   framework::Tensor transformed_input;
   framework::Tensor transformed_weight;
+#if __aarch64__
+  // TODO(hjchen2)
+#else
   // transform weight
   winograd_transform_weight<8, 3>(weight, &transformed_weight);
   // tile input and transform
@@ -37,6 +40,7 @@ void winograd_f6k3(const framework::Tensor &input,
   // caculate output
   winograd_transform_output<8, 3>(transformed_input, transformed_weight,
                                   output);
+#endif
 }
 // F(4X4, 5X5)
...
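For reference, the three calls guarded above implement Winograd F(6x6, 3x3) as described in the paper cited in winograd_transform.cpp (arXiv:1509.09308): each 8x8 input tile d and 3x3 kernel g yield a 6x6 output tile

    Y = A^T [ (G g G^T) \odot (B^T d B) ] A

where G (8x3) is the weight transform, B^T (8x8) the input transform, A^T (6x8) the output transform, and \odot the elementwise product. winograd_transform_weight computes G g G^T once per kernel, winograd_transform_input computes B^T d B per tile, and winograd_transform_output performs the elementwise products followed by the final multiplications with A^T and A.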
@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-// Inspired by https://arxiv.org/abs/1509.09308 and
-// https://github.com/andravin/wincnn and refered from nnpack and ncnn project
+// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn
+// project.
+#ifdef CONV_OP
+#ifndef __aarch64__
 #include "operators/math/pad.h"
 #include "operators/math/winograd/winograd_transform.h"
@@ -47,12 +51,13 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
   const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
   const float *inptr = weight.data<float>();
   int remain_start = out_channel & 0xFFFC;
-#ifdef __aarch64__
+#if 0
   remain_start = 0;
 #else
+  #pragma omp parallel for
   for (int oc = 0; oc < out_channel - 3; oc += 4) {
     float gw[96];  // gw[3][8][4]
-    const float *inptr0 = inptr + oc * in_channel * 9;
+    // const float *inptr0 = inptr + oc * in_channel * 9;
     const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
     const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
     const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
@@ -252,9 +257,10 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
         "q13", "r0");
     }
   }
-#endif  // __aarch64__
+#endif
   // remain output channel
+  #pragma omp parallel for
   for (int oc = remain_start; oc < out_channel; ++oc) {
     float gw[3][8];  // gw[3][8]
     const float *inptr0 = inptr + oc * in_channel * 9;  //
@@ -301,10 +307,6 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
       outptr += 4;
     }
   }
-  // for (int i = 0; i < output->numel(); ++i) {
-  //   DLOG << "TransK[" << i << "] = " << trans_outptr[i];
-  // }
 }
 template <>
@@ -657,6 +659,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
 #endif
   // remainer channels
+  #pragma omp parallel for
   for (int c = remain_c_start; c < channel; ++c) {
     const float *in = inptr + c * image_size;
     float d_bt[64];  // d * B_t
@@ -867,15 +870,6 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
       }
     }
   }
-  // for (int c = 0; c < channel; ++c) {
-  //   for (int tile = 0; tile < output->numel()/channel/64; ++tile) {
-  //     for (int i = 0; i < 64; ++i) {
-  //       int offset = (((tile / 8) * 64 + i) * channel + c) * 8 + (tile % 8);
-  //       DLOG << "TransInput[" << i << "] = " << outptr[offset];
-  //     }
-  //   }
-  // }
 }
 template <>
@@ -897,6 +891,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
   const float *input_ptr = input.data<float>();
   const float *weight_ptr = weight.data<float>();
+  #pragma omp parallel for
   for (int i = 0; i < out_channel; ++i) {
     float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
     for (int k = 0; k < 64; ++k) {
@@ -1017,15 +1012,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
     }
   }
-  // for (int c = 0; c < 4 * out_channel; ++c) {
-  //   for (int tile = 0; tile < 8 * tiles; ++tile) {
-  //     for (int i = 0; i < 64; ++i) {
-  //       int offset = (c * 8 * tiles + tile) * 64 + i;
-  //       DLOG << "uv_trans[" << i << "] = " << uv_trans_ptr[offset];
-  //     }
-  //   }
-  // }
   /*
    * s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
    * s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
@@ -1045,12 +1031,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
   int uv_image_size = uv_trans.dims()[1] * 64;
   float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
-  // DLOG << "out_channel: " << out_channel;
-  // DLOG << "h_tiles: " << h_tiles;
-  // DLOG << "w_tiles: " << w_tiles;
-  // DLOG << "remain_h: " << remain_h;
-  // DLOG << "remain_w: " << remain_w;
+  #pragma omp parallel for
   for (int oc = 0; oc < out_channel; ++oc) {
     float at_m[48];        // [6][8]
     float output_tmp[36];  // [6][6], temporarily restore results
@@ -1118,9 +1099,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
         : [tm_ptr] "r"((float *)transform_matrix)
         : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
           "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
-    // for (int i = 0; i < 48; ++i) {
-    //   DLOG << "at_m[" << i << "] = " << at_m[i];
-    // }
     float *at_m_ptr0 = at_m;
     float *at_m_ptr1 = at_m + 24;
@@ -1252,9 +1230,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
       float *out_ptr = output_ptr + offset;
       int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
       int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
-      // for (int i = 0; i < 36; ++i) {
-      //   DLOG << "output_tmp[" << i << "] = " << output_tmp[i];
-      // }
       for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
         memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
       }
@@ -1391,3 +1366,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle_mobile
+#endif  // __aarch64__
+#endif  // CONV_OP