提交 00350575 编写于 作者: H hjchen2

Fix im2col bug and support do winograd with multi-threads

上级 4645c6dd
......@@ -166,7 +166,8 @@ void ConvCompute(const ConvParam<CPU> &param) {
param.Strides()[0] == param.Strides()[1] &&
param.Dilations()[0] == param.Dilations()[1] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 &&
param.Dilations()[0] == 1 && param.Input()->dims()[1] >= 16) {
param.Dilations()[0] == 1 && param.Output()->dims()[1] >= 16 &&
param.Output()->dims()[2] >= 16) {
BatchConv3x3Winograd(param);
} else {
ConvBasic<float, float>(param);
......
......@@ -41,48 +41,48 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height,
im_data += start_height * im_width + start_width;
col_data += col_start_height * col_width + col_start_width;
#pragma omp parallel for
for (int i = start_height; i < end_height; i += stride_h) {
const float *local_im_data = im_data + i * im_width * stride_h;
float *local_col_data = col_data + col_width;
if (stride_w == 1) {
memcpy(local_col_data, local_im_data, extract * sizeof(float));
memcpy(col_data, im_data, extract * sizeof(float));
} else if (stride_w == 2) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x2_t img = vld2q_f32(local_im_data + s * 2);
vst1q_f32(local_col_data + s, img.val[0]);
for (; s < extract - 3; s += 4) {
float32x4x2_t img = vld2q_f32(im_data + s * 2);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 2];
col_data[s] = im_data[s * 2];
}
} else if (stride_w == 3) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x3_t img = vld3q_f32(local_im_data + s * 3);
vst1q_f32(local_col_data + s, img.val[0]);
for (; s < extract - 3; s += 4) {
float32x4x3_t img = vld3q_f32(im_data + s * 3);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 3];
col_data[s] = im_data[s * 3];
}
} else if (stride_w == 4) {
int s = 0;
#if __ARM_NEON
for (; s < extract - 15; s += 16) {
float32x4x4_t img = vld4q_f32(local_im_data + s * 4);
vst1q_f32(local_col_data + s, img.val[0]);
for (; s < extract - 3; s += 4) {
float32x4x4_t img = vld4q_f32(im_data + s * 4);
vst1q_f32(col_data + s, img.val[0]);
}
#endif
for (; s < extract; ++s) {
local_col_data[s] = local_im_data[s * 4];
col_data[s] = im_data[s * 4];
}
} else {
PADDLE_MOBILE_THROW_EXCEPTION("stride_w must be one of 1, 2, 3 and 4.");
}
im_data += im_width * stride_h;
col_data += col_width;
}
}
......@@ -428,18 +428,23 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
im_data += isize * isize;
}
} else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(float));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const float *local_im_data = im_data + ic * im_spatial_size;
float *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height,
col_width, padding[0], padding[1], stride[0], stride[1],
kh, kw);
col_data += col_height * col_width;
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
im_data += im_height * im_width;
}
} else {
#endif
......@@ -553,18 +558,23 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, int8_t>::operator()(
int8_t *col_data = col->mutable_data<int8_t>();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) {
int im_spatial_size = im_height * im_width;
int col_spatial_size = col_height * col_width;
// pad 0
memset(col_data, 0, col->numel() * sizeof(int8_t));
#pragma omp parallel for
for (int ic = 0; ic < im_channels; ++ic) {
const int8_t *local_im_data = im_data + ic * im_spatial_size;
int8_t *local_col_data =
col_data + ic * filter_height * filter_width * col_spatial_size;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
ExtractToImg(im_data, col_data, im_height, im_width, col_height,
col_width, padding[0], padding[1], stride[0], stride[1],
kh, kw);
col_data += col_height * col_width;
ExtractToImg(local_im_data, local_col_data, im_height, im_width,
col_height, col_width, padding[0], padding[1], stride[0],
stride[1], kh, kw);
local_col_data += col_spatial_size;
}
}
im_data += im_height * im_width;
}
} else {
#endif
......
......@@ -30,6 +30,9 @@ void winograd_f6k3(const framework::Tensor &input,
const framework::Tensor &weight, framework::Tensor *output) {
framework::Tensor transformed_input;
framework::Tensor transformed_weight;
#if __aarch64__
// TODO(hjchen2)
#else
// transform weight
winograd_transform_weight<8, 3>(weight, &transformed_weight);
// tile input and transform
......@@ -37,6 +40,7 @@ void winograd_f6k3(const framework::Tensor &input,
// caculate output
winograd_transform_output<8, 3>(transformed_input, transformed_weight,
output);
#endif
}
// F(4X4, 5X5)
......
......@@ -12,8 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Inspired by https://arxiv.org/abs/1509.09308 and
// https://github.com/andravin/wincnn and refered from nnpack and ncnn project
// Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn
// project.
#ifdef CONV_OP
#ifndef __aarch64__
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
......@@ -47,12 +51,13 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180};
const float *inptr = weight.data<float>();
int remain_start = out_channel & 0xFFFC;
#ifdef __aarch64__
#if 0
remain_start = 0;
#else
#pragma omp parallel for
for (int oc = 0; oc < out_channel - 3; oc += 4) {
float gw[96]; // gw[3][8][4]
const float *inptr0 = inptr + oc * in_channel * 9; //
const float *inptr0 = inptr + oc * in_channel * 9;
const float *inptr1 = inptr + (oc + 1) * in_channel * 9;
const float *inptr2 = inptr + (oc + 2) * in_channel * 9;
const float *inptr3 = inptr + (oc + 3) * in_channel * 9;
......@@ -252,9 +257,10 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
"q13", "r0");
}
}
#endif // __aarch64__
#endif
// remain output channel
#pragma omp parallel for
for (int oc = remain_start; oc < out_channel; ++oc) {
float gw[3][8]; // gw[3][8]
const float *inptr0 = inptr + oc * in_channel * 9; //
......@@ -301,10 +307,6 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
outptr += 4;
}
}
// for (int i = 0; i < output->numel(); ++i) {
// DLOG << "TransK[" << i << "] = " << trans_outptr[i];
// }
}
template <>
......@@ -657,6 +659,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
#endif
// remainer channels
#pragma omp parallel for
for (int c = remain_c_start; c < channel; ++c) {
const float *in = inptr + c * image_size;
float d_bt[64]; // d * B_t
......@@ -867,15 +870,6 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
}
}
}
// for (int c = 0; c < channel; ++c) {
// for (int tile = 0; tile < output->numel()/channel/64; ++tile) {
// for (int i = 0; i < 64; ++i) {
// int offset = (((tile / 8) * 64 + i) * channel + c) * 8 + (tile % 8);
// DLOG << "TransInput[" << i << "] = " << outptr[offset];
// }
// }
// }
}
template <>
......@@ -897,6 +891,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
#pragma omp parallel for
for (int i = 0; i < out_channel; ++i) {
float *uv_ptr = uv_trans_ptr + (i * tiles * 64 * 32);
for (int k = 0; k < 64; ++k) {
......@@ -1017,15 +1012,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
}
}
// for (int c = 0; c < 4 * out_channel; ++c) {
// for (int tile = 0; tile < 8 * tiles; ++tile) {
// for (int i = 0; i < 64; ++i) {
// int offset = (c * 8 * tiles + tile) * 64 + i;
// DLOG << "uv_trans[" << i << "] = " << uv_trans_ptr[offset];
// }
// }
// }
/*
* s0 = m0 + (m1 + m2) + (m3 + m4) + 32 * (m5 + m6)
* s1 = (m1 - m2) + 2 * (m3 - m4) + 16 * (m5 - m6)
......@@ -1045,12 +1031,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
int uv_image_size = uv_trans.dims()[1] * 64;
float transform_matrix[8] = {2.f, 4.f, 8.f, 16.f};
// DLOG << "out_channel: " << out_channel;
// DLOG << "h_tiles: " << h_tiles;
// DLOG << "w_tiles: " << w_tiles;
// DLOG << "remain_h: " << remain_h;
// DLOG << "remain_w: " << remain_w;
#pragma omp parallel for
for (int oc = 0; oc < out_channel; ++oc) {
float at_m[48]; // [6][8]
float output_tmp[36]; // [6][6], temporarily restore results
......@@ -1118,9 +1099,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
: [tm_ptr] "r"((float *)transform_matrix)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0");
// for (int i = 0; i < 48; ++i) {
// DLOG << "at_m[" << i << "] = " << at_m[i];
// }
float *at_m_ptr0 = at_m;
float *at_m_ptr1 = at_m + 24;
......@@ -1252,9 +1230,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
float *out_ptr = output_ptr + offset;
int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
// for (int i = 0; i < 36; ++i) {
// DLOG << "output_tmp[" << i << "] = " << output_tmp[i];
// }
for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
}
......@@ -1391,3 +1366,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif // __aarch64__
#endif // CONV_OP
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册