diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc index 5834461b8fe0b2d37f174d5f66269fb58f2504a1..ca27d181842a5f7faaf9497de1f947161279eefb 100644 --- a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -24,29 +24,48 @@ namespace paddle { namespace lite { namespace arm { namespace math { -void input_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride); -void output_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride); -void output_trans_c4_post(const float* src, - int src_stride, - float* dest, - int dest_stride, - float* bias_value, - bool has_relu); -void weight_trans_c4( +void input_trans_c4_8x8(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4_post_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu); +void input_trans_c4_4x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride); +void output_trans_c4_post_2x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride, + float* bias_value, + bool has_relu); +void weight_trans_c4_8x8( + float* dest, const float* src, int ic, int oc, void* workspace); +void weight_trans_c4_4x4( float* dest, const float* src, int ic, int oc, void* workspace); /* -*The following function conv_compute_6x6_3x3 is base on +*The following function conv_compute_6x6_3x3 and conv_compute_2x2_3x3[_small] is +*base on *MNN[https://github.com/alibaba/MNN] * *Copyright © 2018, Alibaba Group Holding Limited */ + +// F(6,3) void conv_compute_6x6_3x3(const float* input, float* output, int num, @@ -75,11 +94,14 @@ void conv_compute_6x6_3x3(const float* input, int tile_w = (wout + 5) / 6; int tile_h = (hout + 5) / 6; int size_tile = tile_h * tile_w; - float zero_ptr[8]; - memset(zero_ptr, 0, 8 * sizeof(float)); int w_pad = win + pad_w * 2; int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + float* input_c4 = tmp_work_space; int new_h_stride = w_pad * 4; int new_c_stride = new_h_stride * h_pad; @@ -88,9 +110,6 @@ void conv_compute_6x6_3x3(const float* input, int oc_4_stride = wout * hout * 4; int tile_block = 8; -#ifdef __aarch64__ - tile_block = 16; -#endif int block_count = (size_tile + tile_block - 1) / tile_block; int threads = ctx->threads(); @@ -102,7 +121,8 @@ void conv_compute_6x6_3x3(const float* input, // begin compute for (int ni = 0; ni < num; ++ni) { - // trans input to c4 +// trans input to c4 +#pragma omp parallel for num_threads(threads) for (int i = 0; i < ic_4; ++i) { prepack_input_nxwc4_dw(input + ni * in_n_stride, input_c4 + i * new_c_stride, @@ -161,14 +181,14 @@ void conv_compute_6x6_3x3(const float* input, const float* src_ci = src_ptr + ci * ic_4_stride; for (int i = 0; i < 8; ++i) { const float* ci_ptr = src_ci + i * w_pad * 4; - input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); } float* dst_ci = dst_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - input_trans_c4(trans_tmp_data + i * 32, - 4, - dst_ci + i * b_gi_stride * 8, - b_gi_stride); + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + 
i * b_gi_stride * 8, + b_gi_stride); } } } else { @@ -189,14 +209,14 @@ void conv_compute_6x6_3x3(const float* input, // trans for (int i = 0; i < 8; ++i) { float* ci_ptr = trans_remain_tmp_data + i * 32; - input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32); } float* dst_ci = dst_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - input_trans_c4(trans_tmp_data + i * 32, - 4, - dst_ci + i * b_gi_stride * 8, - b_gi_stride); + input_trans_c4_8x8(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); } } // for ci_4 } @@ -213,16 +233,8 @@ void conv_compute_6x6_3x3(const float* input, float* origin_C = dst_temp_data + gi * c_gi_stride; float* origin_B = b_ptr + gi * b_gi_stride; const float* origin_A = weight + gi * w_gi_stride; - sgemm_prepack_c4_small(oc_4 * 4, - tile_count, - ic_4 * 4, - origin_A, - origin_B, - origin_C, - nullptr, - false, - false, - ctx); + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); } //*/ //* @@ -258,18 +270,18 @@ void conv_compute_6x6_3x3(const float* input, float* dst_ci = dst_ptr + ci * oc_4_stride; float* src_ci = src_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - output_trans_c4(src_ci + i * c_gi_stride * 8, - c_gi_stride, - trans_tmp_data + i * 4, - 32); + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); } for (int i = 0; i < ey; ++i) { - output_trans_c4_post(trans_tmp_data + i * 32, - 4, - trans_remain_tmp_data + i * 24, - 4, - bias_value, - param.fuse_relu); + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); } write_to_output_c4_fp32(trans_remain_tmp_data, output_ptr, @@ -297,18 +309,18 @@ void conv_compute_6x6_3x3(const float* input, float* dst_ci = dst_ptr + ci * oc_4_stride; float* src_ci = src_ptr + ci * tile_count * 4; for (int i = 0; i < 8; ++i) { - output_trans_c4(src_ci + i * c_gi_stride * 8, - c_gi_stride, - trans_tmp_data + i * 4, - 32); + output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); } for (int i = 0; i < ey; ++i) { - output_trans_c4_post(trans_tmp_data + i * 32, - 4, - trans_remain_tmp_data + i * 24, - 4, - bias_value, - param.fuse_relu); + output_trans_c4_post_6x8(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); } // copy to dest memset(trans_tmp_data, 0, 144 * sizeof(float)); @@ -338,10 +350,522 @@ void conv_compute_6x6_3x3(const float* input, } // for num } // conv_compute -void output_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride) { +// F(2,3) +void conv_compute_2x2_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = 
w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { +// trans input to c4 +#pragma omp parallel for num_threads(threads) + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + float* tmp_data = + g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride; + float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 64; + float* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 64; +#else + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? 
h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + // trans + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void conv_compute_2x2_3x3_small(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 1) / 2; + int tile_h = (hout + 1) / 2; + int size_tile = tile_h * tile_w; + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + + const int zero_len = w_pad; + float zero_ptr[zero_len]; // NOLINT + memset(zero_ptr, 0, zero_len * sizeof(float)); + + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64; + memset(g_tmp_data, 0, tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + 64; + + // begin compute + for (int ni = 0; ni < num; ++ni) { +// trans input to c4 + +#pragma omp parallel for num_threads(threads) + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + 
zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; + for (int tbi = 0; tbi < block_count; ++tbi) { + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4( + src_ci, 4, w_pad * 4, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 64 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 16; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + float* dst_ci = dst_ptr + ci * tile_count * 4; + input_trans_c4_4x4(trans_remain_tmp_data, + 4, + 16, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 64; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; +#pragma omp parallel for num_threads(threads) + for (int gi = 0; gi < 16; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small( + oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 2) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + output_trans_c4_post_2x4(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_data, + 4, + 8, + bias_value, + param.fuse_relu); + // copy to dest + memset(trans_tmp_data, 0, 16 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 8, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute +void output_trans_c4_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { const float32x4_t src0 = vld1q_f32(src); const float32x4_t src1 = vld1q_f32(src + src_stride); const float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -381,12 +905,13 @@ void output_trans_c4(const float* src, vst1q_f32(dest + dest_stride * 4, dest4); vst1q_f32(dest + dest_stride * 5, dest5); } -void output_trans_c4_post(const float* src, - int src_stride, - float* dest, - int dest_stride, - float* bias_value, - bool has_relu = false) { + +void output_trans_c4_post_6x8(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu = false) { const float32x4_t src0 = vld1q_f32(src); const float32x4_t src1 = vld1q_f32(src + src_stride); const float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -447,10 +972,10 @@ void output_trans_c4_post(const float* src, vst1q_f32(dest + dest_stride * 5, dest5); } -void input_trans_c4(const float* src, - int src_stride, - float* dest, - int dest_stride) { +void input_trans_c4_8x8(const float* src, + int src_stride, + float* dest, + int dest_stride) { float32x4_t src0 = vld1q_f32(src); float32x4_t src1 = vld1q_f32(src + src_stride); float32x4_t src2 = vld1q_f32(src + src_stride * 2); @@ -497,7 +1022,165 @@ void input_trans_c4(const float* src, vst1q_f32(dest + dest_stride * 6, dst6); vst1q_f32(dest + dest_stride * 7, dst7); } -void weight_trans_c4( + +// BT=[1, 0, -1, 0, +// 0, 1, 1, 0, +// 0, -1, 1, 0, +// 0, 1, 0, -1] +void input_trans_c4_4x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src03 = vld1q_f32(src + 
src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src10 = vld1q_f32(src); + float32x4_t src11 = vld1q_f32(src + src_stride); + float32x4_t src12 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src20 = vld1q_f32(src); + float32x4_t src21 = vld1q_f32(src + src_stride); + float32x4_t src22 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride); + src += src_h_stride; + float32x4_t src30 = vld1q_f32(src); + float32x4_t src31 = vld1q_f32(src + src_stride); + float32x4_t src32 = vld1q_f32(src + src_stride + src_stride); + float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride); + + float32x4_t dst00 = vsubq_f32(src00, src02); + float32x4_t dst10 = vaddq_f32(src01, src02); + float32x4_t dst20 = vsubq_f32(src02, src01); + float32x4_t dst30 = vsubq_f32(src01, src03); + + float32x4_t dst01 = vsubq_f32(src10, src12); + float32x4_t dst11 = vaddq_f32(src11, src12); + float32x4_t dst21 = vsubq_f32(src12, src11); + float32x4_t dst31 = vsubq_f32(src11, src13); + + float32x4_t dst02 = vsubq_f32(src20, src22); + float32x4_t dst12 = vaddq_f32(src21, src22); + float32x4_t dst22 = vsubq_f32(src22, src21); + float32x4_t dst32 = vsubq_f32(src21, src23); + + float32x4_t dst03 = vsubq_f32(src30, src32); + float32x4_t dst13 = vaddq_f32(src31, src32); + float32x4_t dst23 = vsubq_f32(src32, src31); + float32x4_t dst33 = vsubq_f32(src31, src33); + + float32x4_t dest00 = vsubq_f32(dst00, dst02); + float32x4_t dest10 = vaddq_f32(dst01, dst02); + float32x4_t dest20 = vsubq_f32(dst02, dst01); + float32x4_t dest30 = vsubq_f32(dst01, dst03); + + float32x4_t dest01 = vsubq_f32(dst10, dst12); + float32x4_t dest11 = vaddq_f32(dst11, dst12); + float32x4_t dest21 = vsubq_f32(dst12, dst11); + float32x4_t dest31 = vsubq_f32(dst11, dst13); + + float32x4_t dest02 = vsubq_f32(dst20, dst22); + float32x4_t dest12 = vaddq_f32(dst21, dst22); + float32x4_t dest22 = vsubq_f32(dst22, dst21); + float32x4_t dest32 = vsubq_f32(dst21, dst23); + + float32x4_t dest03 = vsubq_f32(dst30, dst32); + float32x4_t dest13 = vaddq_f32(dst31, dst32); + float32x4_t dest23 = vsubq_f32(dst32, dst31); + float32x4_t dest33 = vsubq_f32(dst31, dst33); + + vst1q_f32(dest, dest00); + vst1q_f32(dest + dest_stride, dest10); + vst1q_f32(dest + dest_stride + dest_stride, dest20); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest30); + dest += dest_h_stride; + vst1q_f32(dest, dest01); + vst1q_f32(dest + dest_stride, dest11); + vst1q_f32(dest + dest_stride + dest_stride, dest21); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest31); + dest += dest_h_stride; + vst1q_f32(dest, dest02); + vst1q_f32(dest + dest_stride, dest12); + vst1q_f32(dest + dest_stride + dest_stride, dest22); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest32); + dest += dest_h_stride; + vst1q_f32(dest, dest03); + vst1q_f32(dest + dest_stride, dest13); + vst1q_f32(dest + dest_stride + dest_stride, dest23); + vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest33); +} + +// AT=[1, 1, 1, 0, +// 0, 1, -1, -1] +void output_trans_c4_post_2x4(const float* src, + int src_stride, + int src_h_stride, + float* dest, + int dest_stride, + int dest_h_stride, + float* bias_value, + bool has_relu) { + float32x4_t src00 = vld1q_f32(src); + float32x4_t src01 = vld1q_f32(src + src_stride); + float32x4_t src02 = vld1q_f32(src + src_stride + 
src_stride);
+  float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src10 = vld1q_f32(src);
+  float32x4_t src11 = vld1q_f32(src + src_stride);
+  float32x4_t src12 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src20 = vld1q_f32(src);
+  float32x4_t src21 = vld1q_f32(src + src_stride);
+  float32x4_t src22 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src30 = vld1q_f32(src);
+  float32x4_t src31 = vld1q_f32(src + src_stride);
+  float32x4_t src32 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride);
+
+  float32x4_t dst00 = vaddq_f32(vaddq_f32(src00, src01), src02);
+  float32x4_t dst10 = vsubq_f32(vsubq_f32(src01, src02), src03);
+  float32x4_t dst01 = vaddq_f32(vaddq_f32(src10, src11), src12);
+  float32x4_t dst11 = vsubq_f32(vsubq_f32(src11, src12), src13);
+  float32x4_t dst02 = vaddq_f32(vaddq_f32(src20, src21), src22);
+  float32x4_t dst12 = vsubq_f32(vsubq_f32(src21, src22), src23);
+  float32x4_t dst03 = vaddq_f32(vaddq_f32(src30, src31), src32);
+  float32x4_t dst13 = vsubq_f32(vsubq_f32(src31, src32), src33);
+
+  float32x4_t dest00 = vaddq_f32(vaddq_f32(dst00, dst01), dst02);
+  float32x4_t dest10 = vsubq_f32(vsubq_f32(dst01, dst02), dst03);
+  float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12);
+  float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13);
+
+  if (bias_value) {
+    float32x4_t bias = vld1q_f32(bias_value);
+    dest00 = vaddq_f32(dest00, bias);
+    dest10 = vaddq_f32(dest10, bias);
+    dest01 = vaddq_f32(dest01, bias);
+    dest11 = vaddq_f32(dest11, bias);
+  }
+
+  if (has_relu) {
+    float32x4_t zeros = vdupq_n_f32(0);
+    dest00 = vmaxq_f32(dest00, zeros);
+    dest10 = vmaxq_f32(dest10, zeros);
+    dest01 = vmaxq_f32(dest01, zeros);
+    dest11 = vmaxq_f32(dest11, zeros);
+  }
+
+  vst1q_f32(dest, dest00);
+  vst1q_f32(dest + dest_stride, dest10);
+  dest += dest_h_stride;
+  vst1q_f32(dest, dest01);
+  vst1q_f32(dest + dest_stride, dest11);
+}
+void weight_trans_c4_8x8(
     float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
   const float coeff[8][3] = {{1.0f, 0.0f, 0.0f},
                              {-2.0f / 9, -2.0f / 9, -2.0f / 9},
@@ -558,6 +1241,63 @@
   }
 }
 
+void weight_trans_c4_4x4(
+    float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
+  const float coeff[4][3] = {{1.0f, 0.0f, 0.0f},
+                             {0.5f, 0.5f, 0.5f},
+                             {0.5f, -0.5f, 0.5f},
+                             {0.0f, 0.0f, 1.0f}};
+
+  float* ptr_out = static_cast<float*>(workspace);
+
+  for (int i = 0; i < ch_out; i++) {
+    for (int j = 0; j < ch_in; j++) {
+      const float* kernel0 =
+          static_cast<const float*>(din) + (i * ch_in + j) * 9;
+      float* ptr_channel = ptr_out + (i * ch_in + j) * 16;
+
+      //! transform kernel, transposed
+      const float* k0 = kernel0;
+      const float* k1 = kernel0 + 3;
+      const float* k2 = kernel0 + 6;
+
+      //! h
+      float tmp[4][3];
+      for (int i = 0; i < 4; i++) {
+        tmp[i][0] =
+            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
+        tmp[i][1] =
+            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
+        tmp[i][2] =
+            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
+      }
+
+      //! 
v + for (int j = 0; j < 4; j++) { + float* tmpp = &tmp[j][0]; + for (int i = 0; i < 4; i++) { + ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] + + tmpp[1] * coeff[i][1] + + tmpp[2] * coeff[i][2]; + } + } + } + } + + int oc_pad = (ch_out + 3) / 4 * 4; + int ic_pad = (ch_in + 3) / 4 * 4; + int c_stride = ic_pad * oc_pad; + for (int i = 0; i < ch_out * ch_in * 16; ++i) { + int new_c = i % 16; + int new_oc = i / ch_in / 16 / 4; + int new_ic = i / 16 % (ch_in * 4) % ch_in; + int new_inner = i / ch_in / 16 % 4; + int dest_ind = + new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; + dest[dest_ind] = ptr_out[i]; + } +} + } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h index f4d00039aaa635d0ffb31846fd9ff9077ac0c621..60f74b7feecc91a2fe8262a1fea4dce26430031d 100644 --- a/lite/backends/arm/math/conv_impl.h +++ b/lite/backends/arm/math/conv_impl.h @@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor, int channel_size); // new winograd -void weight_trans_c4( +void weight_trans_c4_8x8( + float* dest, const float* src, int ic, int oc, void* workspace); +void weight_trans_c4_4x4( float* dest, const float* src, int ic, int oc, void* workspace); void conv_compute_6x6_3x3(const float* input, float* output, @@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input, const float* bias, const operators::ConvParam& param, ARMContext* ctx); +void conv_compute_2x2_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); +void conv_compute_2x2_3x3_small(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc index 8087e0337bda0866f5d399a07ecb674f0fa55a3e..d78a7fd1e254a6cdbc493a02e4b9316b278938ae 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.cc +++ b/lite/backends/arm/math/packed_sgemm_c4.cc @@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M, } } } - void sgemm_prepack_c4_small(int M, int N, int K, @@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M, } } +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + ARMContext* ctx) { + const int m_round = (M + 3) / 4 * 4; + const int k_round = (K + 3) / 4 * 4; + const int mloop = m_round >> 2; + const int lda = 4 * k_round; + const int ldb_byte = 4 * N * sizeof(float); + const int kcnt = k_round >> 2; +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); +#endif + for (int m = 0; m < mloop; ++m) { + const float* b = B; + int n = N; +#ifdef __aarch64__ + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + // clang-format off + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + /* load a2, a3 */ + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v16.4s, v1.s[0] \n" + "fmul v10.4s, v16.4s, v2.s[0] \n" + "fmul v11.4s, v16.4s, v3.s[0] \n" + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + 
"fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmul v12.4s, v16.4s, v4.s[0] \n" + "fmul v13.4s, v16.4s, v5.s[0] \n" + "fmul v14.4s, v16.4s, v6.s[0] \n" + "fmul v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "prfm pldl1keep, [%[b]] \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + /* load b4, b5 */ + "ld1 {v4.4s, v5.4s}, [%[b]], #32 \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load b6, b7 */ + "ld1 {v6.4s, v7.4s}, [%[b]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "sub %[b], %[b], #128 \n" + "fmla v12.4s, v16.4s, v4.s[0] \n" + "fmla v13.4s, v16.4s, v5.s[0] \n" + "fmla v14.4s, v16.4s, v6.s[0] \n" + "fmla v15.4s, v16.4s, v7.s[0] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v12.4s, v17.4s, v4.s[1] \n" + "fmla v13.4s, v17.4s, v5.s[1] \n" + "fmla v14.4s, v17.4s, v6.s[1] \n" + "fmla v15.4s, v17.4s, v7.s[1] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v12.4s, v18.4s, v4.s[2] \n" + "fmla v13.4s, v18.4s, v5.s[2] \n" + "fmla v14.4s, v18.4s, v6.s[2] \n" + "fmla v15.4s, v18.4s, v7.s[2] \n" + /* load b0, b1 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "fmla v12.4s, v19.4s, v4.s[3] \n" + "fmla v13.4s, v19.4s, v5.s[3] \n" + "fmla v14.4s, v19.4s, v6.s[3] \n" + "fmla v15.4s, v19.4s, v7.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, 
a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v16.4s, v1.s[0] \n" + "fmul v10.4s, v16.4s, v2.s[0] \n" + "fmul v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "ld1 {v0.4s, v1.4s}, [%[b]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b]], #32 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v16.4s, v1.s[0] \n" + "fmla v10.4s, v16.4s, v2.s[0] \n" + "fmla v11.4s, v16.4s, v3.s[0] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #64 \n" + "fmla v8.4s, v17.4s, v0.s[1] \n" + "fmla v9.4s, v17.4s, v1.s[1] \n" + "fmla v10.4s, v17.4s, v2.s[1] \n" + "fmla v11.4s, v17.4s, v3.s[1] \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v18.4s, v1.s[2] \n" + "fmla v10.4s, v18.4s, v2.s[2] \n" + "fmla v11.4s, v18.4s, v3.s[2] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "fmla v8.4s, v19.4s, v0.s[3] \n" + "fmla v9.4s, v19.4s, v1.s[3] \n" + "fmla v10.4s, v19.4s, v2.s[3] \n" + "fmla v11.4s, v19.4s, v3.s[3] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v1", "v2", "v3", "v8", "v9", + "v10", "v11", "v16", "v17", "v18", + "v19", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmul v8.4s, v16.4s, v0.s[0] \n" + "fmul v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "ld1 {v0.4s}, [%[b]], #16 \n" + "fmla v8.4s, v16.4s, v0.s[0] \n" + "fmla v9.4s, v17.4s, v0.s[1] \n" + /* load a2, a3 */ + "ld1 {v18.4s, v19.4s}, [%[a]], #32 \n" + "sub %[b], %[b], #16 \n" + "subs %w[cnt], %w[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "fmla v8.4s, v18.4s, v0.s[2] \n" + "fmla v9.4s, v19.4s, v0.s[3] \n" + /* load a0, a1 */ + "ld1 {v16.4s, v17.4s}, [%[a]], #32 \n" + "bne 1b \n" + "fadd v8.4s, v8.4s, v9.4s \n" + "2:\n" + "st1 {v8.4s}, [%[c]], #16 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte), + [vzero] "w" (vzero) + : "v0", "v8", "v9", "v16", "v17", + "v18", "v19", "cc", "memory" + ); + b += 4; + } +#else + for (; n > 7; n -= 8) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const 
float* b_ptr = b; + // clang-format off + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vld1.32 {d0-d3}, [%[b]]! \n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q12, q4, d0[0] \n" + "vmul.f32 q13, q4, d2[0] \n" + "vmul.f32 q14, q4, d4[0] \n" + "vmul.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "beq 2f \n" + "1:\n" + /* load b2, b3 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]! \n" + "pld [%[b]] \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + "pld [%[b], #64] \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + /* load b4, b5 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + /* load b6, b7 */ + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q12, q4, d0[0] \n" + "vmla.f32 q13, q4, d2[0] \n" + "vmla.f32 q14, q4, d4[0] \n" + "vmla.f32 q15, q4, d6[0] \n" + "sub %[b], %[b], #128 \n" + "vmla.f32 q12, q5, d0[1] \n" + "vmla.f32 q13, q5, d2[1] \n" + "vmla.f32 q14, q5, d4[1] \n" + "vmla.f32 q15, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q12, q6, d1[0] \n" + "vmla.f32 q13, q6, d3[0] \n" + "vmla.f32 q14, q6, d5[0] \n" + "vmla.f32 q15, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q12, q7, d1[1] \n" + "vmla.f32 q13, q7, d3[1] \n" + /* load b0, b1 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vmla.f32 q14, q7, d5[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]! \n" + "vst1.32 {d20-d23}, [%[c]]! \n" + "vst1.32 {d24-d27}, [%[c]]! \n" + "vst1.32 {d28-d31}, [%[c]]! 
\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "memory" + ); + b += 4 * 8; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmul.f32 q8, q4, d0[0] \n" + "vmul.f32 q9, q4, d2[0] \n" + "vmul.f32 q10, q4, d4[0] \n" + "vmul.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "beq 2f \n" + "1:\n" + /* load b0-b3 */ + "vld1.32 {d0-d3}, [%[b]]! \n" + "vld1.32 {d4-d7}, [%[b]]! \n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q4, d2[0] \n" + "vmla.f32 q10, q4, d4[0] \n" + "vmla.f32 q11, q4, d6[0] \n" + /* load a2, a3 */ + "vld1.32 {d12-d15}, [%[a]]!\n" + "sub %[b], %[b], #64 \n" + "vmla.f32 q8, q5, d0[1] \n" + "vmla.f32 q9, q5, d2[1] \n" + "vmla.f32 q10, q5, d4[1] \n" + "vmla.f32 q11, q5, d6[1] \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q8, q6, d1[0] \n" + "vmla.f32 q9, q6, d3[0] \n" + "vmla.f32 q10, q6, d5[0] \n" + "vmla.f32 q11, q6, d7[0] \n" + /* load a0, a1 */ + "vld1.32 {d8-d11}, [%[a]]! \n" + "vmla.f32 q8, q7, d1[1] \n" + "vmla.f32 q9, q7, d3[1] \n" + "vmla.f32 q10, q7, d5[1] \n" + "vmla.f32 q11, q7, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "bne 1b \n" + "2:\n" + "vst1.32 {d16-d19}, [%[c]]!\n" + "vst1.32 {d20-d23}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", "q5", + "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "memory" + ); + b += 4 * 4; + } + for (; n > 0; n--) { + int cnt = kcnt; + const float* a_ptr = A_packed; + const float* b_ptr = b; + asm volatile( + "0:\n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmul.f32 q5, q1, d0[0] \n" + "vmul.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! \n" + "beq 2f \n" + "1:\n" + /* load b0 */ + "vld1.32 {d0-d1}, [%[b]]! \n" + "vmla.f32 q5, q1, d0[0] \n" + "vmla.f32 q6, q2, d0[1] \n" + /* load a2, a3 */ + "vld1.32 {d6-d9}, [%[a]]! \n" + "sub %[b], %[b], #16 \n" + "subs %[cnt], %[cnt], #1 \n" + "add %[b], %[b], %[ldb] \n" + "vmla.f32 q5, q3, d1[0] \n" + "vmla.f32 q6, q4, d1[1] \n" + /* load a0, a1 */ + "vld1.32 {d2-d5}, [%[a]]! 
\n" + "bne 1b \n" + "vadd.f32 q5, q5, q6 \n" + "2:\n" + "vst1.32 {d10-d11}, [%[c]]!\n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4", + "q5", "q6", "q7", "q8", "cc", "memory" + ); + // clang-format on + b += 4; + } +#endif + A_packed += lda; + } +} + void sgemm_prepack_c4(int M, int N, int K, diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h index 21e5af634315a7da66914bb04775088fec55550c..3229ff3e0774ce8bff02b12d79d7ec50ed873cea 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.h +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M, bool has_bias, bool has_relu, ARMContext* ctx); +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 69e507ba347583b3761fe38d86136a22f2576c15..383934e5d51e0756cd3fdd3269a916dcc1431037 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -68,19 +68,9 @@ void ConvCompute::PrepareForRun() { VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation) { - bool use_winograd = - (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 && - pads_equal) || - (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal); - if (use_winograd) { - /// winograd conv impl - impl_ = new WinogradConv; - VLOG(3) << "invoking winograd conv"; - } else { - /// direct conv impl - impl_ = new DirectConv; - VLOG(3) << "invoking direct conv"; - } + /// winograd conv impl + impl_ = new WinogradConv; + VLOG(3) << "invoking winograd conv"; } else if (param.groups == 1 && kw == 3 && stride == 2 && chin * chout < 4 * hin * win && kps_equal && no_dilation) { /// direct conv impl diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index d02cabf277a5e25e2dc731b5bcf0eabe601c9aae..c5cf0b237fc0548ac2bb7549d3950b3cead2b74c 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -43,79 +43,47 @@ void WinogradConv::ReInitWhenNeeded() { int oh = o_dims[2]; int ow = o_dims[3]; int tile_block = 8; -#ifdef __aarch64__ - tile_block = 16; -#endif - int parallel_threads = - (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block; - if (threads <= 2 && parallel_threads >= threads) { - if (last_kernel_is_c4_ == 1) { + choose_small_ = ow * oh / (tile_block * threads) < 36 ? 
true : false;
+  if (choose_small_) {
+    wino_iw = 4;
+
+    if (last_function_ == 0) {
       return;
     }
-    last_kernel_is_c4_ = 1;
-    auto pad = *(param.paddings);
-    int pad_h = pad[0];
-    int pad_w = pad[2];
-    int oc_pad = (oc + 3) / 4 * 4;
-    int ic_pad = (ic + 3) / 4 * 4;
-    const int new_input_size =
-        (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
-    const int temp_size =
-        (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-
-    weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    auto weights_data_ = weights_.mutable_data<float>();
-    lite::arm::math::weight_trans_c4(
-        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
-    free(trans_tmp_ptr);
+    last_function_ = 0;
   } else {
-    if (last_kernel_is_c4_ == 0) {
+    wino_iw = 8;
+    if (last_function_ == 1) {
       return;
     }
-    last_kernel_is_c4_ = 0;
-    int tile_w = (ow + 5) / 6;
-    int tile_h = (oh + 5) / 6;
-
-    int size_tile = tile_h * tile_w;
-    int size_trans_channel = 8 * 8 * size_tile;
-    int max_ch = ic > oc ? ic : oc;
-
-    const int n_wino = size_tile;
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-
-    const int m_wino = oc;
-    int hblock = lite::arm::math::get_hblock(&ctx);
-    int m_round = hblock * ((m_wino + hblock - 1) / hblock);
-    weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-    auto weights_wino =
-        static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    lite::arm::math::winograd_transform_weights(
-        weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
-    auto weights_trans = weights_.mutable_data<float>();
-    for (int i = 0; i < 64; ++i) {
-      float* packed_weights = weights_trans + i * m_round * ic;
-      const float* weights_wino_ptr = weights_wino + i * oc * ic;
-      lite::arm::math::prepackA(packed_weights,
-                                weights_wino_ptr,
-                                1.f,
-                                ic,
-                                0,
-                                m_wino,
-                                0,
-                                ic,
-                                false,
-                                &ctx);
-    }
-    free(trans_tmp_ptr);
-    free(weights_wino);
+    last_function_ = 1;
   }
+  auto pad = *(param.paddings);
+  int pad_h = pad[0];
+  int pad_w = pad[2];
+  int oc_pad = (oc + 3) / 4 * 4;
+  int ic_pad = (ic + 3) / 4 * 4;
+  const int new_input_size =
+      (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
+  const int temp_size =
+      (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
+       8 * wino_iw * wino_iw) *
+      threads;
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+
+  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+  void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
+  auto weights_data_ = weights_.mutable_data<float>();
+  if (!choose_small_) {
+    lite::arm::math::weight_trans_c4_8x8(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  } else {
+    lite::arm::math::weight_trans_c4_4x4(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  }
+  free(trans_tmp_ptr);
+
   last_shape_ = x_dims;
 }
 
@@ -145,14 +113,7 @@ void WinogradConv::Run() {
   int ow = o_dims[3];
   int oc = o_dims[1];
 
-  int tile_block = 8;
-#ifdef __aarch64__
-  tile_block = 16;
-#endif
-  int threads = ctx.threads();
-  int parallel_threads =
-      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
-  if (threads <= 2 && parallel_threads >= threads) {
+  
if (!choose_small_) { lite::arm::math::conv_compute_6x6_3x3(i_data, o_data, bs, @@ -167,19 +128,38 @@ void WinogradConv::Run() { param, &ctx); } else { - lite::arm::math::conv_winograd3x3(i_data, - o_data, - bs, - oc, - oh, - ow, - ic, - ih, - iw, - w_data, - b_data, - param, - &ctx); + int tile_block = 8; + int block_count = + (((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block; + if (block_count != 1) { + lite::arm::math::conv_compute_2x2_3x3(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } else { + lite::arm::math::conv_compute_2x2_3x3_small(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } } } diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h index 40ea54b2918ad6c1b18d36a6df287c7e3eb312a6..1a184ac0ccae1967a2f77110ce2a6fb619cf2e8e 100644 --- a/lite/kernels/arm/conv_winograd.h +++ b/lite/kernels/arm/conv_winograd.h @@ -40,7 +40,9 @@ class WinogradConv : public KernelLite { Tensor weights_; DDim last_shape_; int workspace_size_{0}; - int last_kernel_is_c4_{-1}; + int last_function_{-1}; + bool choose_small_{false}; + int wino_iw{8}; }; } // namespace arm