Unverified commit e0eee83c, authored by TianXiaogang, committed by GitHub

add winograd c4 implementation (#2494)

* fix: conv_block prepack_input_nxwc4 bug
* fix: optimize sgemm_c4 on armv7;
  change the condition for choosing the winograd kernel
* fix: change the conv kernel-selection condition
Parent 93cfddb5
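Reviewer note: the kernel-selection change at the heart of this commit lands in ConvCompute::PrepareForRun() (see the conv_compute.cc hunk below). A minimal, self-contained restatement of the new predicate, with parameter names mirroring the diff; the wrapper function itself is illustrative, not part of the commit:

// Restates the dispatch rule from the conv_compute.cc hunk below: on a
// single thread the new c4 winograd path is admitted for much smaller
// shapes, while the old >=32-channel rule remains for the general case.
static bool use_winograd_kernel(
    int threads, int ic, int oc, int hout, int wout, bool pads_equal) {
  return (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
          pads_equal) ||
         (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
}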
@@ -79,6 +79,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
       conv5x5s1_depthwise_int8.cc
       conv5x5s1_depthwise_fp32.cc
       conv5x5s2_depthwise_fp32.cc
+      conv3x3_winograd_fp32_c4.cc
       conv_winograd_3x3.cc
       conv_impl.cc
       softmax.cc
......
This diff is collapsed.
@@ -254,6 +254,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
     LOG(FATAL) << "prepack_dw_input, valid height must > zero";
   }
   float32x4_t vzero = vdupq_n_f32(0.f);
+  auto out_data = dout;
   int size_w = we - ws;
   int w0 = ws < 0 ? 0 : ws;
@@ -269,6 +270,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
   bool flag_ext_l = left_remain > 0;
   int left_sl = 4 - left_remain;
+  int left_valid_sl = left_sl > width ? width : left_sl;
   uint32x4_t vmask_padl;
   bool flag_mask_l = false;
   if (flag_ext_l) {
@@ -290,6 +292,7 @@ inline void prepack_input_nxwc4_dw(const float* din,
   }
   int size_c = width * height;
   for (int h = hs; h < he; ++h) {
+    dout = out_data + (h - hs) * 4 * size_w;
     auto ptr_c0 = din + cs * size_c + h * width;
     auto ptr_c1 = ptr_c0 + size_c;
     auto ptr_c2 = ptr_c1 + size_c;
@@ -351,10 +354,10 @@ inline void prepack_input_nxwc4_dw(const float* din,
       }
       transpose_4x4(vc0, vc1, vc2, vc3, dout);
       dout += 16;
-      ptr_c0 += left_sl;
-      ptr_c1 += left_sl;
-      ptr_c2 += left_sl;
-      ptr_c3 += left_sl;
+      ptr_c0 += left_valid_sl;
+      ptr_c1 += left_valid_sl;
+      ptr_c2 += left_valid_sl;
+      ptr_c3 += left_valid_sl;
     }
     /// valid
     for (int i = 0; i < cnt_valid; ++i) {
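Reviewer note on the prepack_input_nxwc4 fixes above: each packed output row is now rebased from the saved out_data pointer, so a row that writes fewer elements cannot drift the cursor for the rows after it, and the pointer advance after the left-padded vector is clamped to left_valid_sl so that rows narrower than left_sl = 4 - left_remain are not overrun. A runnable sketch of the clamp arithmetic; the concrete widths are illustrative:

#include <algorithm>
#include <cstdio>

int main() {
  int left_remain = 1;            // one padded column consumed by the vector
  int left_sl = 4 - left_remain;  // old advance: 3 real columns, always
  int widths[] = {8, 2};          // a normal row and a narrower-than-3 row
  for (int width : widths) {
    int left_valid_sl = std::min(left_sl, width);  // new, clamped advance
    std::printf("width=%d: advance %d columns (old code advanced %d)\n",
                width, left_valid_sl, left_sl);
  }
  return 0;
}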
@@ -986,7 +989,9 @@ inline bool write_to_output_c4_fp32(const float* din,
   int size_h = (he > height ? height : he) - hs;  // size_h == hei_n
-  int cnt = (width - ws) / w4;
+  int valid_we = we > width ? width : we;
+  int cnt = (valid_we - ws) / w4;
+  int remain = valid_we - ws - cnt * w4;
   for (int i = 0; i < size_h; i++) {
     int size_w = i * width;
@@ -1087,12 +1092,12 @@ inline bool write_to_output_c4_fp32(const float* din,
 #endif
       }
     }
-    if (we > width) {
+    if (remain > 0) {
       int offset = i * w_round * c4 + c4 * w4 * cnt;
       din_hei_ptr = ptr_din + offset;
-      int j = we - w4;
+      int j = 0;
       if (flag_relu) {
-        for (; j < width; ++j) {
+        for (; j < remain; ++j) {
           *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f);
           *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f);
           *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f);
@@ -1100,7 +1105,7 @@ inline bool write_to_output_c4_fp32(const float* din,
           din_hei_ptr += w4;
         }
       } else {
-        for (; j < width; ++j) {
+        for (; j < remain; ++j) {
           *(doutc0_ptr++) = din_hei_ptr[0];
           *(doutc1_ptr++) = din_hei_ptr[1];
           *(doutc2_ptr++) = din_hei_ptr[2];
......
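Reviewer note on the write_to_output_c4_fp32 change above: the old tail loop (int j = we - w4; j < width) is only correct when the block overhang we - width is smaller than one store width w4; a larger overhang starts j at or beyond width, the loop never runs, and the tail columns are silently dropped. The new code clips the block end first and derives the tail length explicitly. A worked example; the block geometry is hypothetical:

#include <cstdio>

int main() {
  int width = 10, ws = 0, we = 16, w4 = 4;  // hypothetical block geometry
  // old logic: vectors counted against width, tail handled only if we > width
  int old_cnt = (width - ws) / w4;                   // 2 vectors = 8 columns
  int old_tail = 0;
  for (int j = we - w4; j < width; ++j) ++old_tail;  // j starts at 12: 0 iters
  // new logic: clip the block end, then split into vectors plus a remainder
  int valid_we = we > width ? width : we;            // 10
  int cnt = (valid_we - ws) / w4;                    // 2 vectors
  int remain = valid_we - ws - cnt * w4;             // 2 tail columns
  std::printf("old wrote %d columns, new writes %d of %d\n",
              old_cnt * w4 + old_tail, cnt * w4 + remain, width - ws);
  return 0;
}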
@@ -314,7 +314,23 @@ void fill_bias_int8(int* tensor,
                     const int* bias,
                     int channel,
                     int channel_size);
+// new winograd
+void weight_trans_c4(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void conv_compute_6x6_3x3(const float* input,
+                          float* output,
+                          int num,
+                          int chout,
+                          int hout,
+                          int wout,
+                          int chin,
+                          int hin,
+                          int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
......
@@ -37,6 +37,16 @@ void sgemm_prepack_c4(int M,
                       bool has_bias,
                       bool has_relu,
                       ARMContext* ctx);
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            const float* bias,
+                            bool has_bias,
+                            bool has_relu,
+                            ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
......
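Reviewer note: sgemm_prepack_c4_small shares the exact signature of sgemm_prepack_c4, so a call site can route small problems to it without changing how A is packed; the commit message's "optimize sgemm_c4 in armv7" suggests that is its purpose. A hypothetical dispatch helper, assuming the header above is included; the N <= 32 cutoff is a made-up placeholder, not taken from the commit:

// Hypothetical helper (not in this commit): route small GEMMs to the
// _small kernel, everything else to the regular c4 kernel.
inline void sgemm_prepack_c4_auto(int M, int N, int K, const float* A_packed,
                                  const float* B, float* C, const float* bias,
                                  bool has_bias, bool has_relu,
                                  ARMContext* ctx) {
  if (N <= 32) {  // assumed threshold, tune per target
    sgemm_prepack_c4_small(
        M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx);
  } else {
    sgemm_prepack_c4(M, N, K, A_packed, B, C, bias, has_bias, has_relu, ctx);
  }
}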
@@ -40,6 +40,7 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
   int kw = w_dims[3];
   int pad = paddings[0];
   int stride = param.strides[0];
+  int threads = ctx.threads();
   bool pads_equal =
       ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3]));
@@ -67,7 +68,15 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
              no_dilation) {
-    if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) {
+    int tile_block = 8;
+#ifdef __aarch64__
+    tile_block = 16;
+#endif
+    bool use_winograd =
+        (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
+         pads_equal) ||
+        (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
+    if (use_winograd) {
       /// winograd conv impl
       impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
       VLOG(3) << "invoking winograd conv";
......
@@ -26,6 +26,7 @@ template <>
 void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
+  int threads = ctx.threads();
   auto x_dims = param.x->dims();
   auto w_dims = param.filter->dims();
@@ -36,77 +37,89 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
   }
   int ic = x_dims[1];
-  int ow = o_dims[3];
-  int oh = o_dims[2];
+  int ih = x_dims[2];
+  int iw = x_dims[3];
   int oc = o_dims[1];
-  int tile_w = (ow + 5) / 6;
-  int tile_h = (oh + 5) / 6;
-  int size_tile = tile_h * tile_w;
-  int size_trans_channel = 8 * 8 * size_tile;
-  int max_ch = ic > oc ? ic : oc;
-  const int n_wino = size_tile;
-  workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int tile_block = 8;
+#ifdef __aarch64__
+  tile_block = 16;
+#endif
+  int parallel_threads =
+      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
+  if (threads <= 2 && parallel_threads >= threads) {
+    auto pad = *(param.paddings);
+    int pad_h = pad[0];
+    int pad_w = pad[2];
+    int oc_pad = (oc + 3) / 4 * 4;
+    int ic_pad = (ic + 3) / 4 * 4;
+    const int new_input_size =
+        (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
+    const int temp_size =
+        (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
+    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+    weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
+    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
+    auto weights_data_ = weights_.mutable_data<float>();
+    lite::arm::math::weight_trans_c4(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+    free(trans_tmp_ptr);
+  } else {
+    int tile_w = (ow + 5) / 6;
+    int tile_h = (oh + 5) / 6;
+    int size_tile = tile_h * tile_w;
+    int size_trans_channel = 8 * 8 * size_tile;
+    int max_ch = ic > oc ? ic : oc;
+    const int n_wino = size_tile;
+    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
+                        sizeof(float));
+    const int m_wino = oc;
+    int hblock = lite::arm::math::get_hblock(&ctx);
+    int m_round = hblock * ((m_wino + hblock - 1) / hblock);
+    weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
+    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
+                        sizeof(float));
+    auto weights_wino =
+        static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
+    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
+    lite::arm::math::winograd_transform_weights(
+        weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
+    auto weights_trans = weights_.mutable_data<float>();
+    for (int i = 0; i < 64; ++i) {
+      float* packed_weights = weights_trans + i * m_round * ic;
+      const float* weights_wino_ptr = weights_wino + i * oc * ic;
+      lite::arm::math::prepackA(packed_weights,
+                                weights_wino_ptr,
+                                1.f,
+                                ic,
+                                0,
+                                m_wino,
+                                0,
+                                ic,
+                                false,
+                                &ctx);
+    }
+    free(trans_tmp_ptr);
+    free(weights_wino);
+  }
   last_shape_ = x_dims;
 }

 template <>
 void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
-  auto& param = this->Param<param_t>();
-  auto& ctx = this->ctx_->template As<ARMContext>();
-  auto x_dims = param.x->dims();
-  auto w_dims = param.filter->dims();
-  auto o_dims = param.output->dims();
-  last_shape_ = x_dims;
-  int ic = x_dims[1];
-  int ow = o_dims[3];
-  int oh = o_dims[2];
-  int oc = o_dims[1];
-  int tile_w = (ow + 5) / 6;
-  int tile_h = (oh + 5) / 6;
-  int size_tile = tile_h * tile_w;
-  int size_trans_channel = 8 * 8 * size_tile;
-  int max_ch = ic > oc ? ic : oc;
-  const int m_wino = oc;
-  const int n_wino = size_tile;
-  int hblock = lite::arm::math::get_hblock(&ctx);
-  int m_round = hblock * ((m_wino + hblock - 1) / hblock);
-  weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
-  workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float);
-  auto weights_wino =
-      static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
-  void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-  lite::arm::math::winograd_transform_weights(
-      weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
-  auto weights_trans = weights_.mutable_data<float>();
-  for (int i = 0; i < 64; ++i) {
-    float* packed_weights = weights_trans + i * m_round * ic;
-    const float* weights_wino_ptr = weights_wino + i * oc * ic;
-    lite::arm::math::prepackA(packed_weights,
-                              weights_wino_ptr,
-                              1.f,
-                              ic,
-                              0,
-                              m_wino,
-                              0,
-                              ic,
-                              false,
-                              &ctx);
-  }
-  free(trans_tmp_ptr);
-  free(weights_wino);
+  ReInitWhenNeeded();
 }

 template <>
 void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   auto& param = this->Param<param_t>();
   auto& ctx = this->ctx_->template As<ARMContext>();
-  // extend workspace
-  ctx.ExtendWorkspace(workspace_size_);
   const auto* i_data = param.x->data<float>();
   const auto* w_data = weights_.data<float>();
   const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
@@ -124,8 +137,42 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   int ow = o_dims[3];
   int oc = o_dims[1];
-  lite::arm::math::conv_winograd3x3(
-      i_data, o_data, bs, oc, oh, ow, ic, ih, iw, w_data, b_data, param, &ctx);
+  int tile_block = 8;
+#ifdef __aarch64__
+  tile_block = 16;
+#endif
+  int threads = ctx.threads();
+  int parallel_threads =
+      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
+  if (threads <= 2 && parallel_threads >= threads) {
+    lite::arm::math::conv_compute_6x6_3x3(i_data,
+                                          o_data,
+                                          bs,
+                                          oc,
+                                          oh,
+                                          ow,
+                                          ic,
+                                          ih,
+                                          iw,
+                                          w_data,
+                                          b_data,
+                                          param,
+                                          &ctx);
+  } else {
+    lite::arm::math::conv_winograd3x3(i_data,
+                                      o_data,
+                                      bs,
+                                      oc,
+                                      oh,
+                                      ow,
+                                      ic,
+                                      ih,
+                                      iw,
+                                      w_data,
+                                      b_data,
+                                      param,
+                                      &ctx);
+  }
 }

 }  // namespace arm
......
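Reviewer note: to make the new sizing formulas in ReInitWhenNeeded concrete, here is a worked example. The shape (ic = oc = 32, 32x32 spatial, 3x3/s1/p1, two threads, armv8) is illustrative, not from the commit:

#include <cstdio>

int main() {
  int ic = 32, ih = 32, iw = 32, oc = 32, oh = 32, ow = 32;
  int pad_h = 1, pad_w = 1, threads = 2;
  int tile_block = 16;  // armv8; the commit uses 8 on armv7
  // 6x6 output tiles: (37/6)^2 = 36 tiles, grouped 16 per block -> 3 blocks
  int parallel_threads =
      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
  int use_c4 = (threads <= 2 && parallel_threads >= threads);  // 1: c4 path
  // padded input copy plus per-thread transform scratch, in floats
  int new_input_size = (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
  int temp_size =
      (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
  long workspace_bytes =
      (long)(temp_size + new_input_size) * (long)sizeof(float);
  std::printf("parallel_threads=%d use_c4=%d workspace=%ld bytes\n",
              parallel_threads, use_c4, workspace_bytes);
  return 0;
}

For this shape the c4 branch is taken and the workspace request comes to 36992 + 132096 = 169088 floats (676352 bytes).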