diff --git a/docs/user_guides/tutorial.md b/docs/user_guides/tutorial.md index 338449bfcb92e4029763c4357eb6d1fd5b820272..ee156038a6ea144921258734c92e9a2ea757d6ec 100644 --- a/docs/user_guides/tutorial.md +++ b/docs/user_guides/tutorial.md @@ -49,4 +49,4 @@ $ ./opt \ ## 五. 测试工具 -为了使您更好的了解并使用Lite框架,我们向有进一步使用需求的用户开放了 [Debug工具](debug#debug) 和 [Profile工具](debug#profiler)。Lite Model Debug Tool可以用来查找Lite框架与PaddlePaddle框架在执行预测时模型中的对应变量值是否有差异,进一步快速定位问题Op,方便复现与排查问题。Profile Monitor Tool可以帮助您了解每个Op的执行时间消耗,其会自动统计Op执行的次数,最长、最短、平均执行时间等等信息,为性能调优做一个基础参考。您可以通过 [相关专题](debug) 了解更多内容。 +为了使您更好的了解并使用Lite框架,我们向有进一步使用需求的用户开放了 [Debug工具](debug) 和 [Profile工具](debug)。Lite Model Debug Tool可以用来查找Lite框架与PaddlePaddle框架在执行预测时模型中的对应变量值是否有差异,进一步快速定位问题Op,方便复现与排查问题。Profile Monitor Tool可以帮助您了解每个Op的执行时间消耗,其会自动统计Op执行的次数,最长、最短、平均执行时间等等信息,为性能调优做一个基础参考。您可以通过 [相关专题](debug) 了解更多内容。 diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index aecec295ae0269fb34a3c4fa38e396bdf98d4418..d50b46d5bd26e3186e5def2100042e5b22ce4977 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) conv5x5s2_depthwise_int8.cc conv5x5s2_depthwise_fp32.cc conv3x3_winograd_fp32_c4.cc + conv3x3_winograd_int8.cc conv_winograd_3x3.cc conv_impl.cc softmax.cc diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc index 35d9eeaee1b69bed423cd3b489217c71575b3079..2957085493f15016abf2bf50f0aabecbe95f5b36 100644 --- a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8( for (int i = 0; i < ch_out * ch_in * 64; ++i) { int new_c = i % 64; int new_oc = i / ch_in / 64 / 4; - int new_ic = i / 64 % (ch_in * 4) % ch_in; + int new_ic = i / 64 % ch_in; int new_inner = i / ch_in / 64 % 4; int dest_ind = new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; @@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4( for (int i = 0; i < ch_out * ch_in * 16; ++i) { int new_c = i % 16; int new_oc = i / ch_in / 16 / 4; - int new_ic = i / 16 % (ch_in * 4) % ch_in; + int new_ic = i / 16 % ch_in; int new_inner = i / ch_in / 16 % 4; int dest_ind = new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; diff --git a/lite/backends/arm/math/conv3x3_winograd_int8.cc b/lite/backends/arm/math/conv3x3_winograd_int8.cc new file mode 100644 index 0000000000000000000000000000000000000000..7221559130f9363da81251af3d82410217cae5ad --- /dev/null +++ b/lite/backends/arm/math/conv3x3_winograd_int8.cc @@ -0,0 +1,601 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
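+//
+// Rough sketch of the F(2,3) Winograd relation these kernels implement, for
+// orientation only. The matrices follow the BT/AT comments and the integer G
+// coefficients used in weight_trans_c8_4x4_int8 further below:
+//
+//   Y = AT * [ (G * g * GT) (elementwise*) (BT * d * B) ] * A
+//
+// where d is a 4x4 input tile, g is the 3x3 kernel, and
+//   BT = [1 0 -1 0;  0 1 1 0;  0 -1 1 0;  0 1 0 -1]
+//   AT = [1 1 1 0;  0 1 -1 -1]
+//   G  = [1 0 0;  1/2 1/2 1/2;  1/2 -1/2 1/2;  0 0 1]
+//
+// The int8 path widens the input transform to int16 (vaddl_s8/vsubl_s8),
+// accumulates int16 x int16 products into int32 in
+// sgemm_prepack_c8_int16_small, and runs the output transform in int32.
+// The weight transform uses 2*G to stay in integers, so the transformed
+// weights carry an extra constant factor that is presumably folded into the
+// requantization scales passed to write_int32_nchwc8_to_nchw.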
+
+#include "lite/backends/arm/math/conv_block_utils.h"
+#include "lite/backends/arm/math/conv_impl.h"
+#include "lite/backends/arm/math/packed_sgemm_c4.h"
+#ifdef ARM_WITH_OMP
+#include <omp.h>
+#endif
+#include <arm_neon.h>
+namespace paddle {
+namespace lite {
+namespace arm {
+namespace math {
+void input_trans_c8_4x4_int8(const int8_t* src,
+                             int src_stride,
+                             int src_h_stride,
+                             int16_t* dest,
+                             int dest_stride,
+                             int dest_h_stride);
+void output_trans_c8_post_2x4_int8(const int32_t* src,
+                                   int src_stride,
+                                   int src_h_stride,
+                                   int32_t* dest,
+                                   int dest_stride,
+                                   int dest_h_stride);
+void weight_trans_c8_4x4_int8(
+    int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
+
+// F(2,3)
+template <typename Dtype>
+void conv_compute_2x2_3x3_int8(const int8_t* input,
+                               Dtype* output,
+                               int num,
+                               int chout,
+                               int hout,
+                               int wout,
+                               int chin,
+                               int hin,
+                               int win,
+                               const int16_t* weight,
+                               const float* bias,
+                               const float* scale,
+                               const operators::ConvParam& param,
+                               ARMContext* ctx) {
+  auto act_param = param.activation_param;
+  const int pad_h0 = (*param.paddings)[0];
+  const int pad_h1 = (*param.paddings)[1];
+  const int pad_w0 = (*param.paddings)[2];
+  const int pad_w1 = (*param.paddings)[3];
+  int8_t* tmp_work_space =
+      ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t);
+
+  int in_n_stride = chin * hin * win;
+  int out_n_stride = chout * hout * wout;
+  int ic_stride = win * hin;
+  int oc_stride = wout * hout;
+  int ic_8 = (chin + 7) / 8;
+  int oc_8 = (chout + 7) / 8;
+
+  int tile_w = (wout + 1) / 2;
+  int tile_h = (hout + 1) / 2;
+  int size_tile = tile_h * tile_w;
+
+  int w_pad = win + pad_w0 + pad_w1;
+  int h_pad = hin + pad_h0 + pad_h1;
+
+  const int zero_len = (w_pad + 3) / 4 * 4;
+  Dtype zero_ptr[zero_len];  // NOLINT
+  memset(zero_ptr, 0, zero_len * sizeof(Dtype));
+
+  int8_t* input_c8 = tmp_work_space;
+  int new_h_stride = w_pad * 8;
+  int new_c_stride = new_h_stride * h_pad;
+
+  int ic_8_stride = w_pad * h_pad * 8;
+  int oc_8_stride = wout * hout * 8;
+
+  int tile_block = 8;
+  int block_count = (size_tile + tile_block - 1) / tile_block;
+
+  int threads = ctx->threads();
+  int16_t* g_tmp_data =
+      (int16_t*)(tmp_work_space + ic_8 * ic_8_stride +  // NOLINT
+                 oc_8 * oc_8_stride * sizeof(int32_t));
+  int tmp_input_thread_stride = tile_block * ic_8 * 128;
+  int tmp_output_thread_stride = tile_block * oc_8 * 128;
+  int tmp_data_thread_stride_size = tmp_input_thread_stride * sizeof(int16_t) +
+                                    tmp_output_thread_stride * sizeof(int32_t);
+  memset(g_tmp_data, 0, tmp_data_thread_stride_size);
+  int8_t* g_trans_remain_tmp_data =
+      (int8_t*)(g_tmp_data +  // NOLINT
+                threads * (tmp_input_thread_stride +
+                           tmp_output_thread_stride * sizeof(int32_t) /
+                               sizeof(int16_t)));
+  int32_t* g_trans_tmp_data =
+      (int32_t*)(g_trans_remain_tmp_data + threads * 128);  // NOLINT
+  auto act_type = act_param.active_type;
+  int flag_act = 0;  // relu: 1, relu6: 2, leakey: 3
+  float alpha[4] = {0.f, 0.f, 0.f, 0.f};
+  if (act_param.has_active) {
+    if (act_type == lite_api::ActivationType::kRelu) {
+      flag_act = 1;
+    } else if (act_type == lite_api::ActivationType::kRelu6) {
+      flag_act = 2;
+      float local_alpha = act_param.Relu_clipped_coef;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    } else if (act_type == lite_api::ActivationType::kLeakyRelu) {
+      flag_act = 3;
+      float local_alpha = act_param.Leaky_relu_alpha;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    }
+  }
+  // begin compute
+  for (int ni = 0; ni < num;
++ni) { + // trans input to c8 + for (int i = 0; i < ic_8; ++i) { + prepack_input_nxwc8_int8_dw(input + ni * in_n_stride, + input_c8 + i * new_c_stride, + i * 8, + -pad_h0, + hin + pad_h1, + -pad_w0, + win + pad_w1, + chin, + win, + hin); + } + int32_t* output_c8 = (int32_t*)(input_c8 + ic_8 * ic_8_stride); // NOLINT + Dtype* output_ptr = output + ni * out_n_stride; + + const int16_t* weight_ptr = weight; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + int16_t* tmp_data = + g_tmp_data + + omp_get_thread_num() * tmp_data_thread_stride_size / sizeof(int16_t); + int32_t* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 32; + int8_t* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 128; +#else + int16_t* tmp_data = g_tmp_data; + int32_t* trans_tmp_data = g_trans_tmp_data; + int8_t* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_8 * 8; + int b_gi_stride = tile_count * ic_8 * 8; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index + tw_index; + int src_y = th_index + th_index; + int ex = src_x + 4 > w_pad ? w_pad - src_x : 4; + int ey = src_y + 4 > h_pad ? h_pad - src_y : 4; + + int16_t* dst_ptr = tmp_data + ti * 8; + const int8_t* src_ptr = input_c8 + (src_y * w_pad + src_x) * 8; + + if (ex == 4 && ey == 4) { + // trans input + for (int ci = 0; ci < ic_8; ++ci) { + const int8_t* src_ci = src_ptr + ci * ic_8_stride; + int16_t* dst_ci = dst_ptr + ci * tile_count * 8; + input_trans_c8_4x4_int8( + src_ci, 8, w_pad * 8, dst_ci, b_gi_stride, b_gi_stride * 4); + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_8; ++ci) { + const int8_t* src_ci = src_ptr + ci * ic_8_stride; + // pad + memset(trans_remain_tmp_data, 0, 128 * sizeof(int8_t)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + int8_t* dst_yi = trans_remain_tmp_data + yi * 32; + const int8_t* src_yi = src_ci + w_pad * yi * 8; + memcpy(dst_yi, src_yi, x_size * sizeof(int8_t) * 8); + } + } + + // trans + int16_t* dst_ci = dst_ptr + ci * tile_count * 8; + input_trans_c8_4x4_int8(trans_remain_tmp_data, + 8, + 32, + dst_ci, + b_gi_stride, + b_gi_stride * 4); + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + int32_t* dst_temp_data = + (int32_t*)(tmp_data + tmp_input_thread_stride); // NOLINT + int16_t* b_ptr = tmp_data; + int w_gi_stride = ic_8 * oc_8 * 64; + for (int gi = 0; gi < 16; ++gi) { + int32_t* origin_C = dst_temp_data + gi * c_gi_stride; + int16_t* origin_B = b_ptr + gi * b_gi_stride; + const int16_t* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c8_int16_small( + oc_8 * 8, tile_count, ic_8 * 8, origin_A, origin_B, origin_C, ctx); + } + //*/ + //* + // output trans + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 2; + int dst_y = th_index * 2; + + int ex = dst_x + 2 > wout ? wout - dst_x : 2; + int ey = dst_y + 2 > hout ? 
hout - dst_y : 2; + + int32_t* src_ptr = dst_temp_data + ti * 8; + int32_t* trans_remain_tmp_i32_data = + (int32_t*)(trans_remain_tmp_data); // NOLINT + int32_t* dst_ptr = output_c8 + (dst_y * wout + dst_x) * 8; + + if (ex == 2 && ey == 2) { + // trans output + for (int ci = 0; ci < oc_8; ++ci) { + int cur_ind = ci * 8; + + int32_t* src_ci = src_ptr + ci * tile_count * 8; + int32_t* dst_ci = dst_ptr + ci * oc_8_stride; + output_trans_c8_post_2x4_int8( + src_ci, c_gi_stride, c_gi_stride * 4, dst_ci, 8, wout * 8); + } + } else { + for (int ci = 0; ci < oc_8; ++ci) { + int cur_ind = ci * 8; + // trans output + int32_t* src_ci = src_ptr + ci * tile_count * 8; + output_trans_c8_post_2x4_int8(src_ci, + c_gi_stride, + c_gi_stride * 4, + trans_remain_tmp_i32_data, + 8, + 16); + // copy to dest + int32_t* dst_ci = dst_ptr + ci * oc_8_stride; + for (int i = 0; i < ey; ++i) { + memcpy(dst_ci + i * wout * 8, + trans_remain_tmp_i32_data + i * 16, + ex * sizeof(int32_t) * 8); + } + } + } + } + //*/ + } // for block_count + const float* bias_local_ptr = bias; + for (int ci = 0; ci < oc_8; ++ci) { + float bias_local[8] = {bias_local_ptr[0], + bias_local_ptr[1], + bias_local_ptr[2], + bias_local_ptr[3], + bias_local_ptr[4], + bias_local_ptr[5], + bias_local_ptr[6], + bias_local_ptr[7]}; + write_int32_nchwc8_to_nchw(output_c8 + ci * oc_8_stride, + output_ptr, + ci * 8, + ci * 8 + 8, + 0, + hout, + 0, + wout, + chout, + hout, + wout, + flag_act > 0, + bias_local, + param.bias, + zero_ptr, + scale + ci * 8); + bias_local_ptr += 8; + } + } // for num +} // conv compute +template void conv_compute_2x2_3x3_int8( + const int8_t* input, + int8_t* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int16_t* weight, + const float* bias, + const float* scale, + const operators::ConvParam& param, + ARMContext* ctx); +template void conv_compute_2x2_3x3_int8( + const int8_t* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const int16_t* weight, + const float* bias, + const float* scale, + const operators::ConvParam& param, + ARMContext* ctx); + +// BT=[1, 0, -1, 0, +// 0, 1, 1, 0, +// 0, -1, 1, 0, +// 0, 1, 0, -1] +void input_trans_c8_4x4_int8(const int8_t* src, + int src_stride, + int src_h_stride, + int16_t* dest, + int dest_stride, + int dest_h_stride) { + int8x8_t src00 = vld1_s8(src); + int8x8_t src01 = vld1_s8(src + src_stride); + int8x8_t src02 = vld1_s8(src + src_stride + src_stride); + int8x8_t src03 = vld1_s8(src + src_stride + src_stride + src_stride); + src += src_h_stride; + int8x8_t src10 = vld1_s8(src); + int8x8_t src11 = vld1_s8(src + src_stride); + int8x8_t src12 = vld1_s8(src + src_stride + src_stride); + int8x8_t src13 = vld1_s8(src + src_stride + src_stride + src_stride); + src += src_h_stride; + int8x8_t src20 = vld1_s8(src); + int8x8_t src21 = vld1_s8(src + src_stride); + int8x8_t src22 = vld1_s8(src + src_stride + src_stride); + int8x8_t src23 = vld1_s8(src + src_stride + src_stride + src_stride); + src += src_h_stride; + int8x8_t src30 = vld1_s8(src); + int8x8_t src31 = vld1_s8(src + src_stride); + int8x8_t src32 = vld1_s8(src + src_stride + src_stride); + int8x8_t src33 = vld1_s8(src + src_stride + src_stride + src_stride); + + int16x8_t dst00 = vsubl_s8(src00, src02); + int16x8_t dst10 = vaddl_s8(src01, src02); + int16x8_t dst20 = vsubl_s8(src02, src01); + int16x8_t dst30 = vsubl_s8(src01, src03); + + int16x8_t dst01 = vsubl_s8(src10, src12); + int16x8_t dst11 = vaddl_s8(src11, src12); + 
int16x8_t dst21 = vsubl_s8(src12, src11); + int16x8_t dst31 = vsubl_s8(src11, src13); + + int16x8_t dst02 = vsubl_s8(src20, src22); + int16x8_t dst12 = vaddl_s8(src21, src22); + int16x8_t dst22 = vsubl_s8(src22, src21); + int16x8_t dst32 = vsubl_s8(src21, src23); + + int16x8_t dst03 = vsubl_s8(src30, src32); + int16x8_t dst13 = vaddl_s8(src31, src32); + int16x8_t dst23 = vsubl_s8(src32, src31); + int16x8_t dst33 = vsubl_s8(src31, src33); + + int16x8_t dest00 = vsubq_s16(dst00, dst02); + int16x8_t dest10 = vaddq_s16(dst01, dst02); + int16x8_t dest20 = vsubq_s16(dst02, dst01); + int16x8_t dest30 = vsubq_s16(dst01, dst03); + + int16x8_t dest01 = vsubq_s16(dst10, dst12); + int16x8_t dest11 = vaddq_s16(dst11, dst12); + int16x8_t dest21 = vsubq_s16(dst12, dst11); + int16x8_t dest31 = vsubq_s16(dst11, dst13); + + int16x8_t dest02 = vsubq_s16(dst20, dst22); + int16x8_t dest12 = vaddq_s16(dst21, dst22); + int16x8_t dest22 = vsubq_s16(dst22, dst21); + int16x8_t dest32 = vsubq_s16(dst21, dst23); + + int16x8_t dest03 = vsubq_s16(dst30, dst32); + int16x8_t dest13 = vaddq_s16(dst31, dst32); + int16x8_t dest23 = vsubq_s16(dst32, dst31); + int16x8_t dest33 = vsubq_s16(dst31, dst33); + + vst1q_s16(dest, dest00); + vst1q_s16(dest + dest_stride, dest10); + vst1q_s16(dest + dest_stride + dest_stride, dest20); + vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest30); + dest += dest_h_stride; + vst1q_s16(dest, dest01); + vst1q_s16(dest + dest_stride, dest11); + vst1q_s16(dest + dest_stride + dest_stride, dest21); + vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest31); + dest += dest_h_stride; + vst1q_s16(dest, dest02); + vst1q_s16(dest + dest_stride, dest12); + vst1q_s16(dest + dest_stride + dest_stride, dest22); + vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest32); + dest += dest_h_stride; + vst1q_s16(dest, dest03); + vst1q_s16(dest + dest_stride, dest13); + vst1q_s16(dest + dest_stride + dest_stride, dest23); + vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest33); +} + +// AT=[1, 1, 1, 0, +// 0, 1, -1, -1] +void output_trans_c8_post_2x4_int8(const int32_t* src, + int src_stride, + int src_h_stride, + int32_t* dest, + int dest_stride, + int dest_h_stride) { + int32x4_t src400 = vld1q_s32(src); + int32x4_t src800 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src401 = vld1q_s32(src); + int32x4_t src801 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src402 = vld1q_s32(src); + int32x4_t src802 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src403 = vld1q_s32(src); + int32x4_t src803 = vld1q_s32(src + 4); + + src += src_h_stride - 3 * src_stride; + + int32x4_t src410 = vld1q_s32(src); + int32x4_t src810 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src411 = vld1q_s32(src); + int32x4_t src811 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src412 = vld1q_s32(src); + int32x4_t src812 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src413 = vld1q_s32(src); + int32x4_t src813 = vld1q_s32(src + 4); + + src += src_h_stride - 3 * src_stride; + + int32x4_t src420 = vld1q_s32(src); + int32x4_t src820 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src421 = vld1q_s32(src); + int32x4_t src821 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src422 = vld1q_s32(src); + int32x4_t src822 = vld1q_s32(src + 4); + src += src_stride; + int32x4_t src423 = vld1q_s32(src); + int32x4_t src823 = vld1q_s32(src + 4); + + src += src_h_stride - 3 * src_stride; + + int32x4_t src430 = vld1q_s32(src); + int32x4_t src830 = 
vld1q_s32(src + 4);
+  src += src_stride;
+  int32x4_t src431 = vld1q_s32(src);
+  int32x4_t src831 = vld1q_s32(src + 4);
+  src += src_stride;
+  int32x4_t src432 = vld1q_s32(src);
+  int32x4_t src832 = vld1q_s32(src + 4);
+  src += src_stride;
+  int32x4_t src433 = vld1q_s32(src);
+  int32x4_t src833 = vld1q_s32(src + 4);
+
+  int32x4_t dst400 = vaddq_s32(vaddq_s32(src400, src401), src402);
+  int32x4_t dst410 = vsubq_s32(vsubq_s32(src401, src402), src403);
+  int32x4_t dst401 = vaddq_s32(vaddq_s32(src410, src411), src412);
+  int32x4_t dst411 = vsubq_s32(vsubq_s32(src411, src412), src413);
+  int32x4_t dst402 = vaddq_s32(vaddq_s32(src420, src421), src422);
+  int32x4_t dst412 = vsubq_s32(vsubq_s32(src421, src422), src423);
+  int32x4_t dst403 = vaddq_s32(vaddq_s32(src430, src431), src432);
+  int32x4_t dst413 = vsubq_s32(vsubq_s32(src431, src432), src433);
+
+  int32x4_t dst800 = vaddq_s32(vaddq_s32(src800, src801), src802);
+  int32x4_t dst810 = vsubq_s32(vsubq_s32(src801, src802), src803);
+  int32x4_t dst801 = vaddq_s32(vaddq_s32(src810, src811), src812);
+  int32x4_t dst811 = vsubq_s32(vsubq_s32(src811, src812), src813);
+  int32x4_t dst802 = vaddq_s32(vaddq_s32(src820, src821), src822);
+  int32x4_t dst812 = vsubq_s32(vsubq_s32(src821, src822), src823);
+  int32x4_t dst803 = vaddq_s32(vaddq_s32(src830, src831), src832);
+  int32x4_t dst813 = vsubq_s32(vsubq_s32(src831, src832), src833);
+
+  int32x4_t dest400 = vaddq_s32(vaddq_s32(dst400, dst401), dst402);
+  int32x4_t dest410 = vsubq_s32(vsubq_s32(dst401, dst402), dst403);
+  int32x4_t dest401 = vaddq_s32(vaddq_s32(dst410, dst411), dst412);
+  int32x4_t dest411 = vsubq_s32(vsubq_s32(dst411, dst412), dst413);
+
+  int32x4_t dest800 = vaddq_s32(vaddq_s32(dst800, dst801), dst802);
+  int32x4_t dest810 = vsubq_s32(vsubq_s32(dst801, dst802), dst803);
+  int32x4_t dest801 = vaddq_s32(vaddq_s32(dst810, dst811), dst812);
+  int32x4_t dest811 = vsubq_s32(vsubq_s32(dst811, dst812), dst813);
+
+  vst1q_s32(dest, dest400);
+  vst1q_s32(dest + 4, dest800);
+  dest += dest_stride;
+  vst1q_s32(dest, dest410);
+  vst1q_s32(dest + 4, dest810);
+  dest += dest_h_stride - dest_stride;
+  vst1q_s32(dest, dest401);
+  vst1q_s32(dest + 4, dest801);
+  dest += dest_stride;
+  vst1q_s32(dest, dest411);
+  vst1q_s32(dest + 4, dest811);
+}
+
+void weight_trans_c8_4x4_int8(
+    int16_t* dest, const int8_t* din, int ch_in, int ch_out, void* workspace) {
+  const int16_t coeff[4][3] = {{2, 0, 0}, {1, 1, 1}, {1, -1, 1}, {0, 0, 2}};
+
+  int16_t* ptr_out = static_cast<int16_t*>(workspace);
+
+  for (int i = 0; i < ch_out; i++) {
+    for (int j = 0; j < ch_in; j++) {
+      const int8_t* kernel0 =
+          static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
+      int16_t* ptr_channel = ptr_out + (i * ch_in + j) * 16;
+
+      //! transform kernel, transposed
+      const int8_t* k0 = kernel0;
+      const int8_t* k1 = kernel0 + 3;
+      const int8_t* k2 = kernel0 + 6;
+
+      //! h
+      int16_t tmp[4][3];
+      for (int i = 0; i < 4; i++) {
+        tmp[i][0] =
+            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
+        tmp[i][1] =
+            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
+        tmp[i][2] =
+            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
+      }
+
+      //! v
+      for (int j = 0; j < 4; j++) {
+        int16_t* tmpp = &tmp[j][0];
+        for (int i = 0; i < 4; i++) {
+          ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
+                                   tmpp[1] * coeff[i][1] +
+                                   tmpp[2] * coeff[i][2];
+        }
+      }
+    }
+  }
+
+  int oc_pad = (ch_out + 7) / 8 * 8;
+  int ic_pad = (ch_in + 7) / 8 * 8;
+  int c_stride = ic_pad * oc_pad;
+  for (int i = 0; i < ch_out * ch_in * 16; ++i) {
+    int new_c = i % 16;
+    int new_oc = i / ch_in / 16 / 8;
+    int new_ic = i / 16 % ch_in;
+    int new_inner = i / ch_in / 16 % 8;
+    int dest_ind =
+        new_c * c_stride + new_oc * ic_pad * 8 + new_ic * 8 + new_inner;
+    dest[dest_ind] = ptr_out[i];
+  }
+}
+
+}  // namespace math
+}  // namespace arm
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h
index c4fb51021e5b0288a4bc1fd476764348fdc7e450..7a7f06fba3400e74d837c2644c02c9a6539e9fe9 100644
--- a/lite/backends/arm/math/conv_block_utils.h
+++ b/lite/backends/arm/math/conv_block_utils.h
@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
   int w_stride = we - ws;
   int valid_w = (we > width ? width : we) - ws;
   int cnt = valid_w / 4;
+  int remain = valid_w & 3;
 
   float32x4_t w_scale0 = vld1q_f32(scale);
   float32x4_t w_scale1 = vld1q_f32(scale + 4);
@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
                   w_bias1,
                   flag_relu);
     }
-    if (we > width) {
+    if (remain > 0) {
       int offset = 32 * cnt;
       din_hei_ptr = ptr_din + offset;
-      for (int j = ws + cnt * 4; j < width; ++j) {
+      for (int j = 0; j < remain; ++j) {
         if (flag_bias) {
           *(doutc0_ptr++) =
               cvt_kernel(din_hei_ptr[0], scale[0], bias[0], flag_relu);
diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h
index 28a2fb7e2a42a27e9ecd3d42b25f9942b481004e..495a13eec17a0c35e90fbf3ef47c505028721857 100644
--- a/lite/backends/arm/math/conv_impl.h
+++ b/lite/backends/arm/math/conv_impl.h
@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
                                 const float* bias,
                                 const operators::ConvParam& param,
                                 ARMContext* ctx);
+void input_trans_c8_4x4_int8(const int8_t* src,
+                             int src_stride,
+                             int src_h_stride,
+                             int16_t* dest,
+                             int dest_stride,
+                             int dest_h_stride);
+void output_trans_c8_post_2x4_int8(const int32_t* src,
+                                   int src_stride,
+                                   int src_h_stride,
+                                   int32_t* dest,
+                                   int dest_stride,
+                                   int dest_h_stride);
+void weight_trans_c8_4x4_int8(
+    int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
+template <typename Dtype>
+void conv_compute_2x2_3x3_int8(const int8_t* input,
+                               Dtype* output,
+                               int num,
+                               int chout,
+                               int hout,
+                               int wout,
+                               int chin,
+                               int hin,
+                               int win,
+                               const int16_t* weight,
+                               const float* bias,
+                               const float* scale,
+                               const operators::ConvParam& param,
+                               ARMContext* ctx);
 
 template <typename Dtype>
 void im2col(const Dtype* data_im,
diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc
index 08f88105e052322e13390b7482fed7d8dd15089b..fb68550725923bceffe6d4f00c687e83c7b32354 100644
--- a/lite/backends/arm/math/gemm_prepacked_int8.cc
+++ b/lite/backends/arm/math/gemm_prepacked_int8.cc
@@ -1922,19 +1922,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
   Dtype* tmp1 = nullptr;
   Dtype* tmp2 = nullptr;
   Dtype* tmp3 = nullptr;
-  float32_t scale_local[4];
+  float32_t scale_local[4] = {0, 0, 0, 0};
   float32_t bias_local[4] = {0, 0, 0, 0};
   if (is_bias) {
-    bias_local[0] = bias[y];
-    bias_local[1] = bias[y + 1];
-    bias_local[2] = bias[y + 2];
-    bias_local[3] = bias[y + 3];
+    if (y + 4 <= M) {
+      bias_local[0] = bias[y];
+      bias_local[1]
= bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + } else { + switch (M - y) { + case 3: + bias_local[2] = bias[y + 2]; + case 2: + bias_local[1] = bias[y + 1]; + case 1: + bias_local[0] = bias[y + 0]; + default: + break; + } + } } if (scale) { - scale_local[0] = scale[y]; - scale_local[1] = scale[y + 1]; - scale_local[2] = scale[y + 2]; - scale_local[3] = scale[y + 3]; + if (y + 4 <= M) { + scale_local[0] = scale[y]; + scale_local[1] = scale[y + 1]; + scale_local[2] = scale[y + 2]; + scale_local[3] = scale[y + 3]; + } else { + switch (M - y) { + case 3: + scale_local[2] = scale[y + 2]; + case 2: + scale_local[1] = scale[y + 1]; + case 1: + scale_local[0] = scale[y + 0]; + default: + break; + } + } } if (y + MBLOCK_INT8_OTH > M) { switch (y + MBLOCK_INT8_OTH - M) { diff --git a/lite/backends/arm/math/packed_sgemm_c4.cc b/lite/backends/arm/math/packed_sgemm_c4.cc index af4934e85756f03ec197520b2b5c130e27bdcad6..db1189a63c38bdb6ab33c6fa280a6f618b53ef7f 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.cc +++ b/lite/backends/arm/math/packed_sgemm_c4.cc @@ -1679,6 +1679,912 @@ void sgemm_prepack_c4_small(int M, } } +void sgemm_prepack_c8_int16_small(int M, + int N, + int K, + const int16_t* A_packed, + const int16_t* B, + int32_t* C, + ARMContext* ctx) { + const int m_round = (M + 7) / 8 * 8; + const int k_round = (K + 7) / 8 * 8; + const int mloop = m_round >> 3; + const int lda = 8 * k_round; + const int ldb_byte = 8 * N * sizeof(int16_t); + const int kcnt = k_round >> 3; +#ifdef __aarch64__ + float32x4_t vzero = vdupq_n_f32(0.f); +#endif + for (int m = 0; m < mloop; ++m) { + const int16_t* b = B; + int n = N; +#ifdef __aarch64__ + for (; n > 7; n -= 8) { + int cnt = kcnt; + const int16_t* a_ptr = A_packed; + const int16_t* b_ptr = b; + // clang-format off + asm volatile( + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1 + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1 + "ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3 + + "smull v20.4s, v0.4h, v4.h[0] \n" + "smull v21.4s, v0.4h, v5.h[0] \n" + "smull v22.4s, v0.4h, v6.h[0] \n" + "smull v23.4s, v0.4h, v7.h[0] \n" + "ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1 + "ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3 + + "smull2 v24.4s, v0.8h, v4.h[0] \n" + "smull2 v25.4s, v0.8h, v5.h[0] \n" + "smull2 v26.4s, v0.8h, v6.h[0] \n" + "smull2 v27.4s, v0.8h, v7.h[0] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3 + + "smlal v20.4s, v1.4h, v4.h[1] \n" + "smlal v21.4s, v1.4h, v5.h[1] \n" + "smlal v22.4s, v1.4h, v6.h[1] \n" + "smlal v23.4s, v1.4h, v7.h[1] \n" + + "smlal2 v24.4s, v1.8h, v4.h[1] \n" + "smlal2 v25.4s, v1.8h, v5.h[1] \n" + "smlal2 v26.4s, v1.8h, v6.h[1] \n" + "smlal2 v27.4s, v1.8h, v7.h[1] \n" + + "smull v12.4s, v0.4h, v8.h[0] \n" + "smull v13.4s, v0.4h, v9.h[0] \n" + "smull v14.4s, v0.4h, v10.h[0] \n" + "smull v15.4s, v0.4h, v11.h[0] \n" + + "smull2 v16.4s, v0.8h, v8.h[0] \n" + "smull2 v17.4s, v0.8h, v9.h[0] \n" + "smull2 v18.4s, v0.8h, v10.h[0] \n" + "smull2 v19.4s, v0.8h, v11.h[0] \n" + + "smlal v12.4s, v1.4h, v8.h[1] \n" + "smlal v13.4s, v1.4h, v9.h[1] \n" + "smlal v14.4s, v1.4h, v10.h[1] \n" + "smlal v15.4s, v1.4h, v11.h[1] \n" + + "smlal2 v16.4s, v1.8h, v8.h[1] \n" + "smlal2 v17.4s, v1.8h, v9.h[1] \n" + "smlal2 v18.4s, v1.8h, v10.h[1] \n" + "smlal2 v19.4s, v1.8h, v11.h[1] \n" + + "smlal v20.4s, v2.4h, v4.h[2] \n" + "smlal v21.4s, v2.4h, v5.h[2] \n" + "smlal v22.4s, v2.4h, v6.h[2] \n" + "smlal v23.4s, v2.4h, v7.h[2] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1 + "smlal2 
v24.4s, v2.8h, v4.h[2] \n" + "smlal2 v25.4s, v2.8h, v5.h[2] \n" + "smlal2 v26.4s, v2.8h, v6.h[2] \n" + "smlal2 v27.4s, v2.8h, v7.h[2] \n" + "smlal v12.4s, v2.4h, v8.h[2] \n" + "smlal v13.4s, v2.4h, v9.h[2] \n" + "smlal v14.4s, v2.4h, v10.h[2] \n" + "smlal v15.4s, v2.4h, v11.h[2] \n" + "smlal2 v16.4s, v2.8h, v8.h[2] \n" + "smlal2 v17.4s, v2.8h, v9.h[2] \n" + "smlal2 v18.4s, v2.8h, v10.h[2] \n" + "smlal2 v19.4s, v2.8h, v11.h[2] \n" + + "smlal v20.4s, v3.4h, v4.h[3] \n" + "smlal v21.4s, v3.4h, v5.h[3] \n" + "smlal v22.4s, v3.4h, v6.h[3] \n" + "smlal v23.4s, v3.4h, v7.h[3] \n" + "smlal2 v24.4s, v3.8h, v4.h[3] \n" + "smlal2 v25.4s, v3.8h, v5.h[3] \n" + "smlal2 v26.4s, v3.8h, v6.h[3] \n" + "smlal2 v27.4s, v3.8h, v7.h[3] \n" + "smlal v12.4s, v3.4h, v8.h[3] \n" + "smlal v13.4s, v3.4h, v9.h[3] \n" + "smlal v14.4s, v3.4h, v10.h[3] \n" + "smlal v15.4s, v3.4h, v11.h[3] \n" + "smlal2 v16.4s, v3.8h, v8.h[3] \n" + "smlal2 v17.4s, v3.8h, v9.h[3] \n" + "smlal2 v18.4s, v3.8h, v10.h[3] \n" + "smlal2 v19.4s, v3.8h, v11.h[3] \n" + + "smlal v20.4s, v0.4h, v4.h[4] \n" + "smlal v21.4s, v0.4h, v5.h[4] \n" + "smlal v22.4s, v0.4h, v6.h[4] \n" + "smlal v23.4s, v0.4h, v7.h[4] \n" + + "smlal2 v24.4s, v0.8h, v4.h[4] \n" + "smlal2 v25.4s, v0.8h, v5.h[4] \n" + "smlal2 v26.4s, v0.8h, v6.h[4] \n" + "smlal2 v27.4s, v0.8h, v7.h[4] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3 + + "smlal v20.4s, v1.4h, v4.h[5] \n" + "smlal v21.4s, v1.4h, v5.h[5] \n" + "smlal v22.4s, v1.4h, v6.h[5] \n" + "smlal v23.4s, v1.4h, v7.h[5] \n" + + "smlal2 v24.4s, v1.8h, v4.h[5] \n" + "smlal2 v25.4s, v1.8h, v5.h[5] \n" + "smlal2 v26.4s, v1.8h, v6.h[5] \n" + "smlal2 v27.4s, v1.8h, v7.h[5] \n" + + "smlal v12.4s, v0.4h, v8.h[4] \n" + "smlal v13.4s, v0.4h, v9.h[4] \n" + "smlal v14.4s, v0.4h, v10.h[4] \n" + "smlal v15.4s, v0.4h, v11.h[4] \n" + + "smlal2 v16.4s, v0.8h, v8.h[4] \n" + "smlal2 v17.4s, v0.8h, v9.h[4] \n" + "smlal2 v18.4s, v0.8h, v10.h[4] \n" + "smlal2 v19.4s, v0.8h, v11.h[4] \n" + + "smlal v12.4s, v1.4h, v8.h[5] \n" + "smlal v13.4s, v1.4h, v9.h[5] \n" + "smlal v14.4s, v1.4h, v10.h[5] \n" + "smlal v15.4s, v1.4h, v11.h[5] \n" + + "smlal2 v16.4s, v1.8h, v8.h[5] \n" + "smlal2 v17.4s, v1.8h, v9.h[5] \n" + "smlal2 v18.4s, v1.8h, v10.h[5] \n" + "smlal2 v19.4s, v1.8h, v11.h[5] \n" + + "smlal v20.4s, v2.4h, v4.h[6] \n" + "smlal v21.4s, v2.4h, v5.h[6] \n" + "smlal v22.4s, v2.4h, v6.h[6] \n" + "smlal v23.4s, v2.4h, v7.h[6] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1 + "smlal2 v24.4s, v2.8h, v4.h[6] \n" + "smlal2 v25.4s, v2.8h, v5.h[6] \n" + "smlal2 v26.4s, v2.8h, v6.h[6] \n" + "smlal2 v27.4s, v2.8h, v7.h[6] \n" + "sub %[b], %[b], #128 \n" + "add %[b], %[b], %[ldb] \n" + "smlal v20.4s, v3.4h, v4.h[7] \n" + "smlal v21.4s, v3.4h, v5.h[7] \n" + "smlal v22.4s, v3.4h, v6.h[7] \n" + "smlal v23.4s, v3.4h, v7.h[7] \n" + "smlal2 v24.4s, v3.8h, v4.h[7] \n" + "smlal2 v25.4s, v3.8h, v5.h[7] \n" + "smlal2 v26.4s, v3.8h, v6.h[7] \n" + "smlal2 v27.4s, v3.8h, v7.h[7] \n" + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1 + "ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3 + + "smlal v12.4s, v2.4h, v8.h[6] \n" + "smlal v13.4s, v2.4h, v9.h[6] \n" + "smlal v14.4s, v2.4h, v10.h[6] \n" + "smlal v15.4s, v2.4h, v11.h[6] \n" + "smlal2 v16.4s, v2.8h, v8.h[6] \n" + "smlal2 v17.4s, v2.8h, v9.h[6] \n" + "smlal2 v18.4s, v2.8h, v10.h[6] \n" + "smlal2 v19.4s, v2.8h, v11.h[6] \n" + "subs %w[cnt], %w[cnt], #1 \n" + + "smlal v12.4s, v3.4h, v8.h[7] \n" + "smlal v13.4s, v3.4h, v9.h[7] \n" + "smlal v14.4s, v3.4h, v10.h[7] \n" + "smlal v15.4s, v3.4h, v11.h[7] \n" + 
"smlal2 v16.4s, v3.8h, v8.h[7] \n" + "smlal2 v17.4s, v3.8h, v9.h[7] \n" + "smlal2 v18.4s, v3.8h, v10.h[7] \n" + "smlal2 v19.4s, v3.8h, v11.h[7] \n" + + "beq 2f \n" + "1:\n" + "smlal v20.4s, v0.4h, v4.h[0] \n" + "smlal v21.4s, v0.4h, v5.h[0] \n" + "smlal v22.4s, v0.4h, v6.h[0] \n" + "smlal v23.4s, v0.4h, v7.h[0] \n" + "ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1 + "ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3 + + "smlal2 v24.4s, v0.8h, v4.h[0] \n" + "smlal2 v25.4s, v0.8h, v5.h[0] \n" + "smlal2 v26.4s, v0.8h, v6.h[0] \n" + "smlal2 v27.4s, v0.8h, v7.h[0] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3 + + "smlal v20.4s, v1.4h, v4.h[1] \n" + "smlal v21.4s, v1.4h, v5.h[1] \n" + "smlal v22.4s, v1.4h, v6.h[1] \n" + "smlal v23.4s, v1.4h, v7.h[1] \n" + + "smlal2 v24.4s, v1.8h, v4.h[1] \n" + "smlal2 v25.4s, v1.8h, v5.h[1] \n" + "smlal2 v26.4s, v1.8h, v6.h[1] \n" + "smlal2 v27.4s, v1.8h, v7.h[1] \n" + + "smlal v12.4s, v0.4h, v8.h[0] \n" + "smlal v13.4s, v0.4h, v9.h[0] \n" + "smlal v14.4s, v0.4h, v10.h[0] \n" + "smlal v15.4s, v0.4h, v11.h[0] \n" + + "smlal2 v16.4s, v0.8h, v8.h[0] \n" + "smlal2 v17.4s, v0.8h, v9.h[0] \n" + "smlal2 v18.4s, v0.8h, v10.h[0] \n" + "smlal2 v19.4s, v0.8h, v11.h[0] \n" + + "smlal v12.4s, v1.4h, v8.h[1] \n" + "smlal v13.4s, v1.4h, v9.h[1] \n" + "smlal v14.4s, v1.4h, v10.h[1] \n" + "smlal v15.4s, v1.4h, v11.h[1] \n" + + "smlal2 v16.4s, v1.8h, v8.h[1] \n" + "smlal2 v17.4s, v1.8h, v9.h[1] \n" + "smlal2 v18.4s, v1.8h, v10.h[1] \n" + "smlal2 v19.4s, v1.8h, v11.h[1] \n" + + "smlal v20.4s, v2.4h, v4.h[2] \n" + "smlal v21.4s, v2.4h, v5.h[2] \n" + "smlal v22.4s, v2.4h, v6.h[2] \n" + "smlal v23.4s, v2.4h, v7.h[2] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1 + "smlal2 v24.4s, v2.8h, v4.h[2] \n" + "smlal2 v25.4s, v2.8h, v5.h[2] \n" + "smlal2 v26.4s, v2.8h, v6.h[2] \n" + "smlal2 v27.4s, v2.8h, v7.h[2] \n" + "smlal v12.4s, v2.4h, v8.h[2] \n" + "smlal v13.4s, v2.4h, v9.h[2] \n" + "smlal v14.4s, v2.4h, v10.h[2] \n" + "smlal v15.4s, v2.4h, v11.h[2] \n" + "smlal2 v16.4s, v2.8h, v8.h[2] \n" + "smlal2 v17.4s, v2.8h, v9.h[2] \n" + "smlal2 v18.4s, v2.8h, v10.h[2] \n" + "smlal2 v19.4s, v2.8h, v11.h[2] \n" + + "smlal v20.4s, v3.4h, v4.h[3] \n" + "smlal v21.4s, v3.4h, v5.h[3] \n" + "smlal v22.4s, v3.4h, v6.h[3] \n" + "smlal v23.4s, v3.4h, v7.h[3] \n" + "smlal2 v24.4s, v3.8h, v4.h[3] \n" + "smlal2 v25.4s, v3.8h, v5.h[3] \n" + "smlal2 v26.4s, v3.8h, v6.h[3] \n" + "smlal2 v27.4s, v3.8h, v7.h[3] \n" + "smlal v12.4s, v3.4h, v8.h[3] \n" + "smlal v13.4s, v3.4h, v9.h[3] \n" + "smlal v14.4s, v3.4h, v10.h[3] \n" + "smlal v15.4s, v3.4h, v11.h[3] \n" + "smlal2 v16.4s, v3.8h, v8.h[3] \n" + "smlal2 v17.4s, v3.8h, v9.h[3] \n" + "smlal2 v18.4s, v3.8h, v10.h[3] \n" + "smlal2 v19.4s, v3.8h, v11.h[3] \n" + + "smlal v20.4s, v0.4h, v4.h[4] \n" + "smlal v21.4s, v0.4h, v5.h[4] \n" + "smlal v22.4s, v0.4h, v6.h[4] \n" + "smlal v23.4s, v0.4h, v7.h[4] \n" + + "smlal2 v24.4s, v0.8h, v4.h[4] \n" + "smlal2 v25.4s, v0.8h, v5.h[4] \n" + "smlal2 v26.4s, v0.8h, v6.h[4] \n" + "smlal2 v27.4s, v0.8h, v7.h[4] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3 + + "smlal v20.4s, v1.4h, v4.h[5] \n" + "smlal v21.4s, v1.4h, v5.h[5] \n" + "smlal v22.4s, v1.4h, v6.h[5] \n" + "smlal v23.4s, v1.4h, v7.h[5] \n" + + "smlal2 v24.4s, v1.8h, v4.h[5] \n" + "smlal2 v25.4s, v1.8h, v5.h[5] \n" + "smlal2 v26.4s, v1.8h, v6.h[5] \n" + "smlal2 v27.4s, v1.8h, v7.h[5] \n" + + "smlal v12.4s, v0.4h, v8.h[4] \n" + "smlal v13.4s, v0.4h, v9.h[4] \n" + "smlal v14.4s, v0.4h, v10.h[4] \n" + "smlal v15.4s, v0.4h, v11.h[4] \n" + + 
"smlal2 v16.4s, v0.8h, v8.h[4] \n" + "smlal2 v17.4s, v0.8h, v9.h[4] \n" + "smlal2 v18.4s, v0.8h, v10.h[4] \n" + "smlal2 v19.4s, v0.8h, v11.h[4] \n" + + "smlal v12.4s, v1.4h, v8.h[5] \n" + "smlal v13.4s, v1.4h, v9.h[5] \n" + "smlal v14.4s, v1.4h, v10.h[5] \n" + "smlal v15.4s, v1.4h, v11.h[5] \n" + + "smlal2 v16.4s, v1.8h, v8.h[5] \n" + "smlal2 v17.4s, v1.8h, v9.h[5] \n" + "smlal2 v18.4s, v1.8h, v10.h[5] \n" + "smlal2 v19.4s, v1.8h, v11.h[5] \n" + + "smlal v20.4s, v2.4h, v4.h[6] \n" + "smlal v21.4s, v2.4h, v5.h[6] \n" + "smlal v22.4s, v2.4h, v6.h[6] \n" + "smlal v23.4s, v2.4h, v7.h[6] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1 + "smlal2 v24.4s, v2.8h, v4.h[6] \n" + "smlal2 v25.4s, v2.8h, v5.h[6] \n" + "smlal2 v26.4s, v2.8h, v6.h[6] \n" + "smlal2 v27.4s, v2.8h, v7.h[6] \n" + "sub %[b], %[b], #128 \n" + "add %[b], %[b], %[ldb] \n" + "smlal v20.4s, v3.4h, v4.h[7] \n" + "smlal v21.4s, v3.4h, v5.h[7] \n" + "smlal v22.4s, v3.4h, v6.h[7] \n" + "smlal v23.4s, v3.4h, v7.h[7] \n" + "smlal2 v24.4s, v3.8h, v4.h[7] \n" + "smlal2 v25.4s, v3.8h, v5.h[7] \n" + "smlal2 v26.4s, v3.8h, v6.h[7] \n" + "smlal2 v27.4s, v3.8h, v7.h[7] \n" + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1 + "ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3 + + "smlal v12.4s, v2.4h, v8.h[6] \n" + "smlal v13.4s, v2.4h, v9.h[6] \n" + "smlal v14.4s, v2.4h, v10.h[6] \n" + "smlal v15.4s, v2.4h, v11.h[6] \n" + "smlal2 v16.4s, v2.8h, v8.h[6] \n" + "smlal2 v17.4s, v2.8h, v9.h[6] \n" + "smlal2 v18.4s, v2.8h, v10.h[6] \n" + "smlal2 v19.4s, v2.8h, v11.h[6] \n" + "subs %w[cnt], %w[cnt], #1 \n" + + "smlal v12.4s, v3.4h, v8.h[7] \n" + "smlal v13.4s, v3.4h, v9.h[7] \n" + "smlal v14.4s, v3.4h, v10.h[7] \n" + "smlal v15.4s, v3.4h, v11.h[7] \n" + "smlal2 v16.4s, v3.8h, v8.h[7] \n" + "smlal2 v17.4s, v3.8h, v9.h[7] \n" + "smlal2 v18.4s, v3.8h, v10.h[7] \n" + "smlal2 v19.4s, v3.8h, v11.h[7] \n" + + "bne 1b \n" + "2: \n" + "stp q20, q24, [%[c]], #32 \n" + "stp q21, q25, [%[c]], #32 \n" + "stp q22, q26, [%[c]], #32 \n" + "stp q23, q27, [%[c]], #32 \n" + "stp q12, q16, [%[c]], #32 \n" + "stp q13, q17, [%[c]], #32 \n" + "stp q14, q18, [%[c]], #32 \n" + "stp q15, q19, [%[c]], #32 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "v0", "v1", "v2", "v3", "v4","v5", "v6", "v7", "v8", "v9", + "v10", "v11", "13", "14", "15", "16", "17", "18", "19","v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory" + ); + // clang format on + b += 64; + } + for (; n > 3; n -= 4) { + int cnt = kcnt; + const int16_t* a_ptr = A_packed; + const int16_t* b_ptr = b; + // clang-format off + asm volatile( + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" + + "smull v8.4s, v0.4h, v4.h[0] \n" + "smull v9.4s, v0.4h, v5.h[0] \n" + "ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" + "smull2 v10.4s, v0.8h, v4.h[0] \n" + "smull2 v11.4s, v0.8h, v5.h[0] \n" + + "smlal v8.4s, v1.4h, v4.h[1] \n" + "smlal v9.4s, v1.4h, v5.h[1] \n" + "smlal2 v10.4s, v1.8h, v4.h[1] \n" + "smlal2 v11.4s, v1.8h, v5.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + + "smull v12.4s, v0.4h, v6.h[0] \n" + "smull v13.4s, v0.4h, v7.h[0] \n" + "smull2 v14.4s, v0.8h, v6.h[0] \n" + "smull2 v15.4s, v0.8h, v7.h[0] \n" + "smlal v12.4s, v1.4h, v6.h[1] \n" + "smlal v13.4s, v1.4h, v7.h[1] \n" + "smlal2 v14.4s, v1.8h, v6.h[1] \n" + "smlal2 v15.4s, v1.8h, v7.h[1] \n" + + "smlal v8.4s, v2.4h, v4.h[2] \n" + "smlal v9.4s, v2.4h, v5.h[2] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal2 v10.4s, v2.8h, v4.h[2] \n" + "smlal2 
v11.4s, v2.8h, v5.h[2] \n" + "smlal v8.4s, v3.4h, v4.h[3] \n" + "smlal v9.4s, v3.4h, v5.h[3] \n" + "smlal2 v10.4s, v3.8h, v4.h[3] \n" + "smlal2 v11.4s, v3.8h, v5.h[3] \n" + + "smlal v12.4s, v2.4h, v6.h[2] \n" + "smlal v13.4s, v2.4h, v7.h[2] \n" + "smlal2 v14.4s, v2.8h, v6.h[2] \n" + "smlal2 v15.4s, v2.8h, v7.h[2] \n" + "smlal v12.4s, v3.4h, v6.h[3] \n" + "smlal v13.4s, v3.4h, v7.h[3] \n" + "smlal2 v14.4s, v3.8h, v6.h[3] \n" + "smlal2 v15.4s, v3.8h, v7.h[3] \n" + + "smlal v8.4s, v0.4h, v4.h[4] \n" + "smlal v9.4s, v0.4h, v5.h[4] \n" + "smlal2 v10.4s, v0.8h, v4.h[4] \n" + "smlal2 v11.4s, v0.8h, v5.h[4] \n" + + "smlal v8.4s, v1.4h, v4.h[5] \n" + "smlal v9.4s, v1.4h, v5.h[5] \n" + "smlal2 v10.4s, v1.8h, v4.h[5] \n" + "smlal2 v11.4s, v1.8h, v5.h[5] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + + "smlal v12.4s, v0.4h, v6.h[4] \n" + "smlal v13.4s, v0.4h, v7.h[4] \n" + "smlal2 v14.4s, v0.8h, v6.h[4] \n" + "smlal2 v15.4s, v0.8h, v7.h[4] \n" + "smlal v12.4s, v1.4h, v6.h[5] \n" + "smlal v13.4s, v1.4h, v7.h[5] \n" + "smlal2 v14.4s, v1.8h, v6.h[5] \n" + "smlal2 v15.4s, v1.8h, v7.h[5] \n" + + "smlal v8.4s, v2.4h, v4.h[6] \n" + "smlal v9.4s, v2.4h, v5.h[6] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal2 v10.4s, v2.8h, v4.h[6] \n" + "smlal2 v11.4s, v2.8h, v5.h[6] \n" + "smlal v8.4s, v3.4h, v4.h[7] \n" + "smlal v9.4s, v3.4h, v5.h[7] \n" + "smlal2 v10.4s, v3.8h, v4.h[7] \n" + "smlal2 v11.4s, v3.8h, v5.h[7] \n" + "sub %[b], %[b], #64 \n" + "add %[b], %[b], %[ldb] \n" + + "smlal v12.4s, v2.4h, v6.h[6] \n" + "smlal v13.4s, v2.4h, v7.h[6] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" + "smlal2 v14.4s, v2.8h, v6.h[6] \n" + "smlal2 v15.4s, v2.8h, v7.h[6] \n" + "smlal v12.4s, v3.4h, v6.h[7] \n" + "smlal v13.4s, v3.4h, v7.h[7] \n" + "smlal2 v14.4s, v3.8h, v6.h[7] \n" + "smlal2 v15.4s, v3.8h, v7.h[7] \n" + + "beq 2f \n" + "1: \n" + "smlal v8.4s, v0.4h, v4.h[0] \n" + "smlal v9.4s, v0.4h, v5.h[0] \n" + "ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" + "smlal2 v10.4s, v0.8h, v4.h[0] \n" + "smlal2 v11.4s, v0.8h, v5.h[0] \n" + + "smlal v8.4s, v1.4h, v4.h[1] \n" + "smlal v9.4s, v1.4h, v5.h[1] \n" + "smlal2 v10.4s, v1.8h, v4.h[1] \n" + "smlal2 v11.4s, v1.8h, v5.h[1] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + + "smlal v12.4s, v0.4h, v6.h[0] \n" + "smlal v13.4s, v0.4h, v7.h[0] \n" + "smlal2 v14.4s, v0.8h, v6.h[0] \n" + "smlal2 v15.4s, v0.8h, v7.h[0] \n" + "smlal v12.4s, v1.4h, v6.h[1] \n" + "smlal v13.4s, v1.4h, v7.h[1] \n" + "smlal2 v14.4s, v1.8h, v6.h[1] \n" + "smlal2 v15.4s, v1.8h, v7.h[1] \n" + + "smlal v8.4s, v2.4h, v4.h[2] \n" + "smlal v9.4s, v2.4h, v5.h[2] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal2 v10.4s, v2.8h, v4.h[2] \n" + "smlal2 v11.4s, v2.8h, v5.h[2] \n" + "smlal v8.4s, v3.4h, v4.h[3] \n" + "smlal v9.4s, v3.4h, v5.h[3] \n" + "smlal2 v10.4s, v3.8h, v4.h[3] \n" + "smlal2 v11.4s, v3.8h, v5.h[3] \n" + + "smlal v12.4s, v2.4h, v6.h[2] \n" + "smlal v13.4s, v2.4h, v7.h[2] \n" + "smlal2 v14.4s, v2.8h, v6.h[2] \n" + "smlal2 v15.4s, v2.8h, v7.h[2] \n" + "smlal v12.4s, v3.4h, v6.h[3] \n" + "smlal v13.4s, v3.4h, v7.h[3] \n" + "smlal2 v14.4s, v3.8h, v6.h[3] \n" + "smlal2 v15.4s, v3.8h, v7.h[3] \n" + + "smlal v8.4s, v0.4h, v4.h[4] \n" + "smlal v9.4s, v0.4h, v5.h[4] \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + "smlal2 v10.4s, v0.8h, v4.h[4] \n" + "smlal2 v11.4s, v0.8h, v5.h[4] \n" + + "smlal v8.4s, v1.4h, v4.h[5] \n" + "smlal v9.4s, v1.4h, v5.h[5] \n" + "smlal2 v10.4s, v1.8h, v4.h[5] \n" + "smlal2 v11.4s, v1.8h, v5.h[5] \n" + + "smlal v12.4s, v0.4h, v6.h[4] \n" + "smlal v13.4s, v0.4h, v7.h[4] 
\n" + "smlal2 v14.4s, v0.8h, v6.h[4] \n" + "smlal2 v15.4s, v0.8h, v7.h[4] \n" + "smlal v12.4s, v1.4h, v6.h[5] \n" + "smlal v13.4s, v1.4h, v7.h[5] \n" + "smlal2 v14.4s, v1.8h, v6.h[5] \n" + "smlal2 v15.4s, v1.8h, v7.h[5] \n" + + "smlal v8.4s, v2.4h, v4.h[6] \n" + "smlal v9.4s, v2.4h, v5.h[6] \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal2 v10.4s, v2.8h, v4.h[6] \n" + "smlal2 v11.4s, v2.8h, v5.h[6] \n" + "smlal v8.4s, v3.4h, v4.h[7] \n" + "smlal v9.4s, v3.4h, v5.h[7] \n" + "smlal2 v10.4s, v3.8h, v4.h[7] \n" + "smlal2 v11.4s, v3.8h, v5.h[7] \n" + "sub %[b], %[b], #64 \n" + "add %[b], %[b], %[ldb] \n" + + "smlal v12.4s, v2.4h, v6.h[6] \n" + "smlal v13.4s, v2.4h, v7.h[6] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" + "smlal2 v14.4s, v2.8h, v6.h[6] \n" + "smlal2 v15.4s, v2.8h, v7.h[6] \n" + "smlal v12.4s, v3.4h, v6.h[7] \n" + "smlal v13.4s, v3.4h, v7.h[7] \n" + "smlal2 v14.4s, v3.8h, v6.h[7] \n" + "smlal2 v15.4s, v3.8h, v7.h[7] \n" + + "bne 1b \n" + "2: \n" + "stp q8, q10, [%[c]], #32 \n" + "stp q9, q11, [%[c]], #32 \n" + "stp q12, q14, [%[c]], #32 \n" + "stp q13, q15, [%[c]], #32 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "v0", "v1", "v2", "v3", "v4","v5", "v6", "v7", "v8", "v9", + "v10", "v11","v12", "v13", "v14", "v15", "cc", "memory" + ); + // clang-format on + b += 32; + } + for (; n > 0; --n) { + int cnt = kcnt; + const int16_t* a_ptr = A_packed; + const int16_t* b_ptr = b; + // clang-format off + asm volatile( + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "ld1 {v4.8h}, [%[b]], #16 \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + "smull v5.4s, v0.4h, v4.h[0] \n" + "smull2 v6.4s, v0.8h, v4.h[0] \n" + "ld1 {v10.8h, v11.8h}, [%[a]], #32 \n" + "smlal v5.4s, v1.4h, v4.h[1] \n" + "smlal2 v6.4s, v1.8h, v4.h[1] \n" + "ld1 {v12.8h, v13.8h}, [%[a]], #32 \n" + "smlal v5.4s, v2.4h, v4.h[2] \n" + "smlal2 v6.4s, v2.8h, v4.h[2] \n" + "smlal v5.4s, v3.4h, v4.h[3] \n" + "smlal2 v6.4s, v3.8h, v4.h[3] \n" + "sub %[b], %[b], #16 \n" + "add %[b], %[b], %[ldb] \n" + "smlal v5.4s, v10.4h, v4.h[4] \n" + "smlal2 v6.4s, v10.8h, v4.h[4] \n" + "smlal v5.4s, v11.4h, v4.h[5] \n" + "smlal2 v6.4s, v11.8h, v4.h[5] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal v5.4s, v12.4h, v4.h[6] \n" + "smlal2 v6.4s, v12.8h, v4.h[6] \n" + "smlal v5.4s, v13.4h, v4.h[7] \n" + "smlal2 v6.4s, v13.8h, v4.h[7] \n" + + "beq 2f \n" + "1: \n" + "ld1 {v4.8h}, [%[b]], #16 \n" + "ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" + "smlal v5.4s, v0.4h, v4.h[0] \n" + "smlal2 v6.4s, v0.8h, v4.h[0] \n" + "ld1 {v10.8h, v11.8h}, [%[a]], #32 \n" + "smlal v5.4s, v1.4h, v4.h[1] \n" + "smlal2 v6.4s, v1.8h, v4.h[1] \n" + "ld1 {v12.8h, v13.8h}, [%[a]], #32 \n" + "smlal v5.4s, v2.4h, v4.h[2] \n" + "smlal2 v6.4s, v2.8h, v4.h[2] \n" + "smlal v5.4s, v3.4h, v4.h[3] \n" + "smlal2 v6.4s, v3.8h, v4.h[3] \n" + "sub %[b], %[b], #16 \n" + "add %[b], %[b], %[ldb] \n" + "smlal v5.4s, v10.4h, v4.h[4] \n" + "smlal2 v6.4s, v10.8h, v4.h[4] \n" + "smlal v5.4s, v11.4h, v4.h[5] \n" + "smlal2 v6.4s, v11.8h, v4.h[5] \n" + "subs %w[cnt], %w[cnt], #1 \n" + "ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" + "smlal v5.4s, v12.4h, v4.h[6] \n" + "smlal2 v6.4s, v12.8h, v4.h[6] \n" + "smlal v5.4s, v13.4h, v4.h[7] \n" + "smlal2 v6.4s, v13.8h, v4.h[7] \n" + "bne 1b \n" + + "2: \n" + "st1 {v5.4s, v6.4s}, [%[c]], #32 \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "v0", "v1", "v2", "v3", "v4","v5", "v6", "cc", "memory" + ); + 
// clang-format on + b += 8; + } +#else + for (; n > 3; n -= 4) { + int cnt = kcnt; + const int16_t* a_ptr = A_packed; + const int16_t* b_ptr = b; + // clang-format off + asm volatile ( + "vld1.16 {d0-d3}, [%[b]]! \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vld1.16 {d4-d7}, [%[b]]! \n" + "vmull.s16 q8, d8, d0[0] \n" + "vmull.s16 q9, d8, d2[0] \n" + "vld1.16 {d12-d15}, [%[a]]! \n" + "vmull.s16 q10, d9, d0[0] \n" + "vmull.s16 q11, d9, d2[0] \n" + "vmlal.s16 q8, d10, d0[1] \n" + "vmlal.s16 q9, d10, d2[1] \n" + "vmlal.s16 q10, d11, d0[1] \n" + "vmlal.s16 q11, d11, d2[1] \n" + "vmull.s16 q12, d8, d4[0] \n" + "vmull.s16 q13, d8, d6[0] \n" + "vmull.s16 q14, d9, d4[0] \n" + "vmull.s16 q15, d9, d6[0] \n" + "vmlal.s16 q12, d10, d4[1] \n" + "vmlal.s16 q13, d10, d6[1] \n" + "vmlal.s16 q14, d11, d4[1] \n" + "vmlal.s16 q15, d11, d6[1] \n" + + "vmlal.s16 q8, d12, d0[2] \n" + "vmlal.s16 q9, d12, d2[2] \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vmlal.s16 q10, d13, d0[2] \n" + "vmlal.s16 q11, d13, d2[2] \n" + "vmlal.s16 q8, d14, d0[3] \n" + "vmlal.s16 q9, d14, d2[3] \n" + "vmlal.s16 q10, d15, d0[3] \n" + "vmlal.s16 q11, d15, d2[3] \n" + + "vmlal.s16 q12, d12, d4[2] \n" + "vmlal.s16 q13, d12, d6[2] \n" + "vmlal.s16 q14, d13, d4[2] \n" + "vmlal.s16 q15, d13, d6[2] \n" + "vmlal.s16 q12, d14, d4[3] \n" + "vmlal.s16 q13, d14, d6[3] \n" + "vmlal.s16 q14, d15, d4[3] \n" + "vmlal.s16 q15, d15, d6[3] \n" + + "sub %[b], %[b], #64 \n" + "add %[b], %[b], %[ldb] \n" + "vld1.16 {d12-d15}, [%[a]]! \n" + "vmlal.s16 q8, d8, d1[0] \n" + "vmlal.s16 q9, d8, d3[0] \n" + "vmlal.s16 q10, d9, d1[0] \n" + "vmlal.s16 q11, d9, d3[0] \n" + "vmlal.s16 q8, d10, d1[1] \n" + "vmlal.s16 q9, d10, d3[1] \n" + "vmlal.s16 q10, d11, d1[1] \n" + "vmlal.s16 q11, d11, d3[1] \n" + "vmlal.s16 q8, d12, d1[2] \n" + "vmlal.s16 q9, d12, d3[2] \n" + "vmlal.s16 q10, d13, d1[2] \n" + "vmlal.s16 q11, d13, d3[2] \n" + "vmlal.s16 q8, d14, d1[3] \n" + "vmlal.s16 q9, d14, d3[3] \n" + "vmlal.s16 q10, d15, d1[3] \n" + "vmlal.s16 q11, d15, d3[3] \n" + "vld1.16 {d0-d3}, [%[b]]! \n" + "vmlal.s16 q12, d8, d5[0] \n" + "vmlal.s16 q13, d8, d7[0] \n" + "vmlal.s16 q14, d9, d5[0] \n" + "vmlal.s16 q15, d9, d7[0] \n" + "vmlal.s16 q12, d10, d5[1] \n" + "vmlal.s16 q13, d10, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmlal.s16 q14, d11, d5[1] \n" + "vmlal.s16 q15, d11, d7[1] \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vmlal.s16 q12, d12, d5[2] \n" + "vmlal.s16 q13, d12, d7[2] \n" + "vmlal.s16 q14, d13, d5[2] \n" + "vmlal.s16 q15, d13, d7[2] \n" + "vmlal.s16 q12, d14, d5[3] \n" + "vmlal.s16 q13, d14, d7[3] \n" + "vmlal.s16 q14, d15, d5[3] \n" + "vmlal.s16 q15, d15, d7[3] \n" + + "beq 2f \n" + "1: \n" + "vld1.16 {d4-d7}, [%[b]]! \n" + "vmlal.s16 q8, d8, d0[0] \n" + "vmlal.s16 q9, d8, d2[0] \n" + "vld1.16 {d12-d15}, [%[a]]! \n" + "vmlal.s16 q10, d9, d0[0] \n" + "vmlal.s16 q11, d9, d2[0] \n" + "vmlal.s16 q8, d10, d0[1] \n" + "vmlal.s16 q9, d10, d2[1] \n" + "vmlal.s16 q10, d11, d0[1] \n" + "vmlal.s16 q11, d11, d2[1] \n" + "vmlal.s16 q12, d8, d4[0] \n" + "vmlal.s16 q13, d8, d6[0] \n" + "vmlal.s16 q14, d9, d4[0] \n" + "vmlal.s16 q15, d9, d6[0] \n" + "vmlal.s16 q12, d10, d4[1] \n" + "vmlal.s16 q13, d10, d6[1] \n" + "vmlal.s16 q14, d11, d4[1] \n" + "vmlal.s16 q15, d11, d6[1] \n" + + "vmlal.s16 q8, d12, d0[2] \n" + "vmlal.s16 q9, d12, d2[2] \n" + "vld1.16 {d8-d11}, [%[a]]! 
\n" + "vmlal.s16 q10, d13, d0[2] \n" + "vmlal.s16 q11, d13, d2[2] \n" + "vmlal.s16 q8, d14, d0[3] \n" + "vmlal.s16 q9, d14, d2[3] \n" + "vmlal.s16 q10, d15, d0[3] \n" + "vmlal.s16 q11, d15, d2[3] \n" + + "vmlal.s16 q12, d12, d4[2] \n" + "vmlal.s16 q13, d12, d6[2] \n" + "vmlal.s16 q14, d13, d4[2] \n" + "vmlal.s16 q15, d13, d6[2] \n" + "vmlal.s16 q12, d14, d4[3] \n" + "vmlal.s16 q13, d14, d6[3] \n" + "vmlal.s16 q14, d15, d4[3] \n" + "vmlal.s16 q15, d15, d6[3] \n" + + "sub %[b], %[b], #64 \n" + "add %[b], %[b], %[ldb] \n" + "vld1.16 {d12-d15}, [%[a]]! \n" + "vmlal.s16 q8, d8, d1[0] \n" + "vmlal.s16 q9, d8, d3[0] \n" + "vmlal.s16 q10, d9, d1[0] \n" + "vmlal.s16 q11, d9, d3[0] \n" + "vmlal.s16 q8, d10, d1[1] \n" + "vmlal.s16 q9, d10, d3[1] \n" + "vmlal.s16 q10, d11, d1[1] \n" + "vmlal.s16 q11, d11, d3[1] \n" + "vmlal.s16 q8, d12, d1[2] \n" + "vmlal.s16 q9, d12, d3[2] \n" + "vmlal.s16 q10, d13, d1[2] \n" + "vmlal.s16 q11, d13, d3[2] \n" + "vmlal.s16 q8, d14, d1[3] \n" + "vmlal.s16 q9, d14, d3[3] \n" + "vmlal.s16 q10, d15, d1[3] \n" + "vmlal.s16 q11, d15, d3[3] \n" + "vld1.16 {d0-d3}, [%[b]]! \n" + "vmlal.s16 q12, d8, d5[0] \n" + "vmlal.s16 q13, d8, d7[0] \n" + "vmlal.s16 q14, d9, d5[0] \n" + "vmlal.s16 q15, d9, d7[0] \n" + "vmlal.s16 q12, d10, d5[1] \n" + "vmlal.s16 q13, d10, d7[1] \n" + "subs %[cnt], %[cnt], #1 \n" + "vmlal.s16 q14, d11, d5[1] \n" + "vmlal.s16 q15, d11, d7[1] \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vmlal.s16 q12, d12, d5[2] \n" + "vmlal.s16 q13, d12, d7[2] \n" + "vmlal.s16 q14, d13, d5[2] \n" + "vmlal.s16 q15, d13, d7[2] \n" + "vmlal.s16 q12, d14, d5[3] \n" + "vmlal.s16 q13, d14, d7[3] \n" + "vmlal.s16 q14, d15, d5[3] \n" + "vmlal.s16 q15, d15, d7[3] \n" + + "bne 1b \n" + "2: \n" + "vst1.32 {d16-d17}, [%[c]]! \n" + "vst1.32 {d20-d21}, [%[c]]! \n" + "vst1.32 {d18-d19}, [%[c]]! \n" + "vst1.32 {d22-d23}, [%[c]]! \n" + "vst1.32 {d24-d25}, [%[c]]! \n" + "vst1.32 {d28-d29}, [%[c]]! \n" + "vst1.32 {d26-d27}, [%[c]]! \n" + "vst1.32 {d30-d31}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory" + ); + // clang format on + b += 32; + } + for (; n > 0; --n) { + int cnt = kcnt; + const int16_t* a_ptr = A_packed; + const int16_t* b_ptr = b; + // clang format off + asm volatile ( + "vld1.16 {d0-d1}, [%[b]]! \n" + "vld1.16 {d4-d7}, [%[a]]! \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vmull.s16 q8, d4, d0[0] \n" + "vmull.s16 q9, d5, d0[0] \n" + "sub %[b], %[b], #16 \n" + "vmlal.s16 q8, d6, d0[1] \n" + "vmlal.s16 q9, d7, d0[1] \n" + "add %[b], %[b], %[ldb] \n" + "subs %[cnt], %[cnt], #1 \n" + + "vld1.16 {d4-d7}, [%[a]]! \n" + "vmlal.s16 q8, d8, d0[2] \n" + "vmlal.s16 q9, d9, d0[2] \n" + "vmlal.s16 q8, d10, d0[3] \n" + "vmlal.s16 q9, d11, d0[3] \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + + "vmlal.s16 q8, d4, d1[0] \n" + "vmlal.s16 q9, d5, d1[0] \n" + "vmlal.s16 q8, d6, d1[1] \n" + "vmlal.s16 q9, d7, d1[1] \n" + "vld1.16 {d4-d7}, [%[a]]! \n" + "vmlal.s16 q8, d8, d1[2] \n" + "vmlal.s16 q9, d9, d1[2] \n" + "vmlal.s16 q8, d10, d1[3] \n" + "vmlal.s16 q9, d11, d1[3] \n" + "beq 2f \n" + "1:\n" + "vld1.16 {d0-d1}, [%[b]]! \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + "vmlal.s16 q8, d4, d0[0] \n" + "vmlal.s16 q9, d5, d0[0] \n" + "sub %[b], %[b], #16 \n" + "vmlal.s16 q8, d6, d0[1] \n" + "vmlal.s16 q9, d7, d0[1] \n" + "add %[b], %[b], %[ldb] \n" + "subs %[cnt], %[cnt], #1 \n" + + "vld1.16 {d4-d7}, [%[a]]! 
\n" + "vmlal.s16 q8, d8, d0[2] \n" + "vmlal.s16 q9, d9, d0[2] \n" + "vmlal.s16 q8, d10, d0[3] \n" + "vmlal.s16 q9, d11, d0[3] \n" + "vld1.16 {d8-d11}, [%[a]]! \n" + + "vmlal.s16 q8, d4, d1[0] \n" + "vmlal.s16 q9, d5, d1[0] \n" + "vmlal.s16 q8, d6, d1[1] \n" + "vmlal.s16 q9, d7, d1[1] \n" + "vld1.16 {d4-d7}, [%[a]]! \n" + "vmlal.s16 q8, d8, d1[2] \n" + "vmlal.s16 q9, d9, d1[2] \n" + "vmlal.s16 q8, d10, d1[3] \n" + "vmlal.s16 q9, d11, d1[3] \n" + "bne 1b \n" + "2: \n" + "vst1.32 {d16-d19}, [%[c]]! \n" + : [a] "+r" (a_ptr), + [b] "+r" (b_ptr), + [c] "+r" (C), + [cnt] "+r" (cnt) + : [ldb] "r" (ldb_byte) + : "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8", + "q9", "cc", "memory" + ); + // clang-format on + b += 8; + } +#endif + A_packed += lda; + } +} + void sgemm_prepack_c4(int M, int N, int K, diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h index 3229ff3e0774ce8bff02b12d79d7ec50ed873cea..51457d57405396f68bf1991bfa43cc6aa9fbe050 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.h +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M, const float* B, float* C, ARMContext* ctx); +void sgemm_prepack_c8_int16_small(int M, + int N, + int K, + const int16_t* A_packed, + const int16_t* B, + int32_t* C, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index fdcbc7394b1be9e438686f91dfa407065d24f91a..0ea1f2e8f5e1430e14ec7035dad0caa7b1d64a90 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -206,6 +206,20 @@ void pooling_basic(const float* din, "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ +#define P2x2S2P1_MAX \ + "ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \ + "ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \ + "sub %[dr0], %[dr0], #4\n" /* sub */ \ + "sub %[dr1], %[dr1], #4\n" /* sub */ \ + "fmax v4.4s, v0.4s, v6.4s\n" /* max */ \ + "fmax v5.4s, v2.4s, v8.4s\n" /* max */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "fmax v6.4s, v4.4s, v5.4s\n" /* max reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* bne s3_max_loop_mid */ + #define P2x2S2P0_MAX \ "1: \n" \ "fmax v4.4s, v0.4s, v1.4s\n" /* max */ \ @@ -217,6 +231,21 @@ void pooling_basic(const float* din, "st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ "bne 1b\n" /* bne s3_max_loop_mid */ +#define P2x2S2P1_AVG \ + "ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \ + "ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \ + "sub %[dr0], %[dr0], #4\n" /* sub */ \ + "sub %[dr1], %[dr1], #4\n" /* sub */ \ + "fadd v4.4s, v0.4s, v6.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "fadd v5.4s, v2.4s, v8.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ + "ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \ + "ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \ + "fadd v6.4s, v4.4s, v5.4s\n" /* add reduce */ \ + "subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \ + "fmul v4.4s, v6.4s, %[vcoef_left].4s\n" /* mul coef */ \ + "st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ + "ble 2f\n" /* bne s3_max_loop_mid */ + #define P2x2S2P0_AVG \ "1: \n" 
/* load bias to q2, q3*/ \ "fadd v4.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \ @@ -228,6 +257,7 @@ void pooling_basic(const float* din, "fmul v4.4s, v6.4s, %[vcoef].4s\n" /* mul coef */ \ "st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \ "bne 1b\n" /* bne s3_max_loop_mid */ + #define P3x3S1_INIT \ "ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \ "ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \ @@ -518,16 +548,45 @@ void pooling_basic(const float* din, "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" +#define P2x2S2P1_MAX \ + "vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \ + "vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \ + "sub %[dr0], #4 @sub \n" \ + "sub %[dr1], #4 @sub \n" \ + "vmax.f32 q8, q0, q4 @ max \n" \ + "vmax.f32 q9, q2, q5 @ max \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ + "vmax.f32 q5, q9, q8 @ max reduce\n" \ + "subs %[cnt_num], #1 @ subs cnt_num \n" \ + "vst1.f32 {d10-d11}, [%[dr_out]]! @ store 4 out \n" \ + "ble 2f @ bne \n" + #define P2x2S2P0_MAX \ "1: @ main loop\n" \ "vmax.f32 q4, q0, q1 @ max \n" \ "vmax.f32 q5, q2, q3 @ max \n" \ "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ - "vmax.f32 q6, q4, q5 @ max reduce\n" \ + "vmax.f32 q8, q4, q5 @ max reduce\n" \ "subs %[cnt_num], #1 @ subs cnt_num \n" \ - "vst1.f32 {d12-d13}, [%[dr_out]]! @ store 4 out \n" \ - "bne 1b @ bne " + "vst1.f32 {d16-d17}, [%[dr_out]]! @ store 4 out \n" \ + "bne 1b @ bne \n" + +#define P2x2S2P1_AVG \ + "vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \ + "vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \ + "sub %[dr0], #4 @sub \n" \ + "sub %[dr1], #4 @sub \n" \ + "vadd.f32 q9, q0, q4 @ max \n" \ + "vadd.f32 q8, q2, q5 @ max \n" \ + "vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \ + "vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \ + "vadd.f32 q5, q9, q8 @ max reduce\n" \ + "subs %[cnt_num], #1 @ subs cnt_num \n" \ + "vmul.f32 q4, q5, %q[vcoef_left] @ mul coef \n" \ + "vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \ + "ble 2f @ bne\n" #define P2x2S2P0_AVG \ "1: @ main loop\n" \ @@ -535,9 +594,9 @@ void pooling_basic(const float* din, "vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \ "vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \ "vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \ - "vadd.f32 q6, q4, q5 @ add reduce \n" \ + "vadd.f32 q8, q4, q5 @ add reduce \n" \ "subs %[cnt_num], #1 @ subs \n" \ - "vmul.f32 q4, q6, %q[vcoef] @ mul coef \n" \ + "vmul.f32 q4, q8, %q[vcoef] @ mul coef \n" \ "vst1.f32 {d8-d9}, [%[dr_out]]! 
@ store 4 out \n" \ "bne 1b @ bne \n" @@ -1037,17 +1096,17 @@ void pooling1x1s2p0_max(const float* din, TargetFree(TARGET(kARM), write_ptr); } -void pooling2x2s2_max(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int pad_bottom, - int pad_right) { +void pooling2x2s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right) { int size_channel_out = wout * hout; int size_channel_in = win * hin; auto data_out = static_cast(dout); @@ -1095,7 +1154,7 @@ void pooling2x2s2_max(const float* din, [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8"); #endif dr0 -= 8; dr1 -= 8; @@ -1121,18 +1180,18 @@ void pooling2x2s2_max(const float* din, } } -void pooling2x2s2_avg(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - bool exclusive, - int pad_bottom, - int pad_right) { +void pooling2x2s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right) { int size_channel_out = wout * hout; int size_channel_in = win * hin; auto data_out = static_cast(dout); @@ -1158,12 +1217,14 @@ void pooling2x2s2_avg(const float* din, const float* data_in_channel = data_in_batch + c * size_channel_in; const float* r0 = data_in_channel; const float* r1 = r0 + win; + vcoef = vdupq_n_f32(0.25f); for (int h = 0; h < hout; h++) { float* dr_out = data_out_channel; auto dr0 = r0; auto dr1 = r1; if (h * S + K - P > hin) { dr1 = zero_ptr; + vcoef = vdupq_n_f32(0.5f); } int cnt_num = w_unroll_size; if (w_unroll_size > 0) { @@ -1184,7 +1245,7 @@ void pooling2x2s2_avg(const float* din, [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num) : [vcoef] "w"(vcoef) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8"); #endif dr0 -= 8; dr1 -= 8; @@ -1194,8 +1255,14 @@ void pooling2x2s2_avg(const float* din, int wstart = 0; for (int j = 0; j < w_unroll_remian; ++j) { int wend = std::min(wstart + K, rem); - float coef = 0.5f / (wend - wstart); + float coef = 0.25f; float tmp = 0.f; + if (wend - wstart == 1 && pad_right == 0) { + coef *= 2; + } + if (h * S + K - P > hin && pad_bottom == 0) { + coef *= 2; + } for (int i = wstart; i < wend; i++) { tmp += dr0[i] + dr1[i]; } @@ -1212,6 +1279,235 @@ void pooling2x2s2_avg(const float* din, TargetFree(TARGET(kARM), zero_ptr); } +void pooling2x2s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 2; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + float32x4_t vzero = vdupq_n_f32(std::numeric_limits::lowest()); + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) 
{ + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + if (h == 0) { + dr0 = r0; + dr1 = r0; + r0 = r1; + r1 = r0 + win; + } else { + r0 = r1 + win; + r1 = r0 + win; + } + if (h * S + K - P > hin) { + dr1 = dr0; + if (h * S + K - P > hin + 1) { + memset(dr_out, 0, wout * sizeof(float)); + continue; + } + } + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vzero] "w"(vzero) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8"); +#else + asm volatile( + P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */ + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9"); +#endif + dr0 -= 8; + dr1 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? wstart : 0; + float tmp = wend == st ? 0.f : dr0[0]; + for (int i = 0; i < wend - st; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + } + *(dr_out++) = tmp; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + wstart += S; + } + data_out_channel += wout; + } + } + } +} + +void pooling2x2s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right) { + int size_channel_out = wout * hout; + int size_channel_in = win * hin; + auto data_out = static_cast(dout); + auto data_in = static_cast(din); + + const int K = 2; + const int P = 1; + const int S = 2; + + int w_unroll_size = wout / 4; + int w_unroll_remian = wout - w_unroll_size * 4; + auto zero_ptr = + static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); + float32x4_t vzero = vdupq_n_f32(0.f); + memset(zero_ptr, 0, win * sizeof(float)); + + if (w_unroll_remian == 0) { + w_unroll_size -= 1; + w_unroll_remian = wout - w_unroll_size * 4; + } + + for (int n = 0; n < num; ++n) { + float* data_out_batch = data_out + n * chout * size_channel_out; + const float* data_in_batch = data_in + n * chin * size_channel_in; +#pragma omp parallel for + for (int c = 0; c < chout; c++) { + float* data_out_channel = data_out_batch + c * size_channel_out; + const float* data_in_channel = data_in_batch + c * size_channel_in; + const float* r0 = data_in_channel; + const float* r1 = r0 + win; + for (int h = 0; h < hout; h++) { + float* dr_out = data_out_channel; + auto dr0 = r0; + auto dr1 = r1; + float coef_h = 0.5f; + if (h == 0) { + dr0 = zero_ptr; + dr1 = r0; + r0 = r1; + r1 = r0 + win; + if (exclusive) { + coef_h = 1.f; + } + } else { + r0 = r1 + win; + r1 = r0 + win; + } + if (h * S + K - P > hin) { + dr1 = zero_ptr; + if (exclusive) { + coef_h = 1.f; + } + if (h * S + K - P > hin + 1) { + memset(dr_out, 0, wout * sizeof(float)); + continue; + } + } + float coef_left_most = exclusive ? 
coef_h : coef_h / 2; + float32x4_t vcoef = vdupq_n_f32(coef_h / 2); + float coef_left[4] = { + coef_left_most, coef_h / 2, coef_h / 2, coef_h / 2}; + float32x4_t vcoef_left = vld1q_f32(coef_left); + int cnt_num = w_unroll_size; + if (w_unroll_size > 0) { +#ifdef __aarch64__ + asm volatile( + P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), + [vzero] "w"(vzero), + [vcoef_left] "w"(vcoef_left) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8"); +#else + asm volatile( + P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : [vcoef] "w"(vcoef), + [vzero] "w"(vzero), + [vcoef_left] "w"(vcoef_left) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9"); +#endif + dr0 -= 8; + dr1 -= 8; + } + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? wstart : 0; + float tmp = 0.f; + float coef = coef_h / 2; + if (exclusive && wend - st == 1) { + coef = coef_h; + } + for (int i = 0; i < wend - st; i++) { + tmp += dr0[i] + dr1[i]; + } + *(dr_out++) = tmp * coef; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + wstart += S; + } + data_out_channel += wout; + } + } + } + TargetFree(TARGET(kARM), zero_ptr); +} + void pooling3x3s1p1_max(const float* din, float* dout, int num, @@ -2240,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din, w_unroll_remian = wout - w_unroll_size * 4; } + int remain = w_unroll_remian - 1; + int right = wout * 2 + 1 - win; // if need right pad + for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; const float* data_in_batch = data_in + n * chin * size_channel_in; @@ -2266,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din, } } int cnt_num = w_unroll_size; + int cnt_remain = remain; if (w_unroll_size > 0) { #ifdef __aarch64__ asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX @@ -2289,46 +2589,80 @@ void pooling3x3s2p0_max(const float* din, "v9", "v10", "v11"); -#else - asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [cnt_num] "+r"(cnt_num) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); -#endif dr0 -= 8; dr1 -= 8; dr2 -= 8; - } - // deal with right pad - int rem = win - (w_unroll_size * 4) * S; - int wstart = 0; - for (int j = 0; j < w_unroll_remian; ++j) { - int wend = std::min(wstart + K, rem); - float tmp = dr0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, dr0[i]); - tmp = std::max(tmp, dr1[i]); - tmp = std::max(tmp, dr2[i]); + int rem = win - (w_unroll_size * 4) * S; + int wstart = 0; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, rem); + float tmp = dr0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); + } + *(dr_out++) = tmp; + wstart += S; } - *(dr_out++) = tmp; - wstart += S; +#else + asm volatile( + P3x3S2P0_INIT P3x3S2P0_MAX + "cmp %[remain], #0 @cmp cnt_num\n" + "sub %[dr0], #32 @sub - 8\n" + "sub %[dr1], #32 @sub - 8\n" + "sub %[dr2], #32 @sub - 8\n" + "ble 4f @ble exit1\n" + "2: @mid loop\n" + "vld1.f32 {d0-d1}, 
[%[dr0]]! @load \n" + "vld1.f32 {d2-d3}, [%[dr1]]! @load \n" + "vld1.f32 {d4-d5}, [%[dr2]]! @load \n" + "vmov.f32 s3,s2 @mov \n" + "vmov.f32 s7,s6 @mov \n" + "vmov.f32 s11,s10 @mov \n" + "vmax.f32 q0, q0, q1 @max n" + "sub %[dr0], #8 @add w \n" + "sub %[dr1], #8 @add w \n" + "sub %[dr2], #8 @add w \n" + "vmax.f32 q0, q0, q2 @max \n" + "vpmax.f32 d0, d0, d1 @pmax \n" + "vpmax.f32 d0, d0, d0 @pmax \n" + "subs %[remain], #1 @subs \n" + "vst1.f32 d0[0], [%[dr_out]]! @vst \n" + "bne 2b @bne \n" + "4: @exit\n" + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [remain] "+r"(cnt_remain), + [cnt_num] "+r"(cnt_num) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); + if (right) { + int wstart = (w_unroll_size * 4 + remain) * S; + int wend = std::min(wstart + K, win); + float tmp = dr0[wstart]; // std::numeric_limits::min(); + for (int i = wstart; i < wend; i++) { + tmp = std::max(tmp, std::max(dr0[i], dr1[i])); + tmp = std::max(tmp, dr2[i]); + } + *(dr_out++) = tmp; + } +#endif } r0 = r2; @@ -2368,6 +2702,9 @@ void pooling3x3s2p0_avg(const float* din, w_unroll_remian = wout - w_unroll_size * 4; } + // do overflow process + w_unroll_size -= 1; + w_unroll_remian += 4; auto zero_ptr = static_cast(TargetMalloc(TARGET(kARM), win * sizeof(float))); memset(zero_ptr, 0, win * sizeof(float)); diff --git a/lite/backends/arm/math/pooling.h b/lite/backends/arm/math/pooling.h index 7bbffa8e2f4594da4be589569efc0ef18b8dd0da..572919e3f083f736d8f49b3bae0dd2820fac35c4 100644 --- a/lite/backends/arm/math/pooling.h +++ b/lite/backends/arm/math/pooling.h @@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din, int pad_bottom, int pad_right); -void pooling2x2s2_max(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - int pad_bottom, - int pad_right); - -void pooling2x2s2_avg(const float* din, - float* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - bool exclusive, - int pad_bottom, - int pad_right); +void pooling2x2s2p0_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right); + +void pooling2x2s2p0_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right); + +void pooling2x2s2p1_max(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + int pad_bottom, + int pad_right); + +void pooling2x2s2p1_avg(const float* din, + float* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + bool exclusive, + int pad_bottom, + int pad_right); void pooling3x3s1p1_max(const float* din, float* dout, diff --git a/lite/backends/arm/math/softmax.cc b/lite/backends/arm/math/softmax.cc index 65d41b049123680f26674cc05d3c02172a260b31..b7f82e9f376e8b62195d884e8de19a142d76b316 100644 --- a/lite/backends/arm/math/softmax.cc +++ b/lite/backends/arm/math/softmax.cc @@ -531,7 +531,7 @@ void softmax_inner1_large_axis(const float* din, } float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); - for (j = 4 * j; j < axis_size; ++j) { + for (j = 4 * nn; j < axis_size; ++j) { max_data = std::max(max_data, din_max_ptr[0]); din_max_ptr++; } @@ 
-557,7 +557,7 @@ void softmax_inner1_large_axis(const float* din, float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); - for (j = 4 * j; j < axis_size; ++j) { + for (j = 4 * nn; j < axis_size; ++j) { dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); sum_data += dout_sum_ptr[0]; din_sum_ptr++; diff --git a/lite/backends/fpga/KD/pes/pooling_pe.hpp b/lite/backends/fpga/KD/pes/pooling_pe.hpp index 60755ee1dbf81512bde618389cbf3a88cf93d1ce..2bc4f91f1d8c76b243a0ffb4a083f8d6ab138553 100644 --- a/lite/backends/fpga/KD/pes/pooling_pe.hpp +++ b/lite/backends/fpga/KD/pes/pooling_pe.hpp @@ -50,13 +50,14 @@ class PoolingPE : public PE { PoolingArgs args = {0}; args.mode = param_.type; + auto paddings = *param_.paddings; args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.image.address = input->data(); args.image.channels = input->shape().channel(); args.image.height = input->shape().height(); args.image.width = input->shape().width(); - args.image.pad_height = param_.paddings[0]; - args.image.pad_width = param_.paddings[1]; + args.image.pad_height = paddings[0]; + args.image.pad_width = paddings[2]; args.image.scale_address = input->scale(); args.output.address = output->mutableData(); args.output.scale_address = output->scale(); @@ -69,8 +70,7 @@ class PoolingPE : public PE { param_.poolingArgs = args; // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 - // && - // (k_width > 7 || k_height > 7); + // && (k_width > 7 || k_height > 7); use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && (k_width > 255 || k_height > 255); // use_cpu_ = param_.type == AVERAGE; @@ -86,12 +86,13 @@ class PoolingPE : public PE { float* image_addr = float_input.mutableData(FP32, input->shape()); float_input.copyFrom(input); float16* data_out = output->data(); + auto paddings = *param_.paddings; int image_height = input->shape().height(); int image_width = input->shape().width(); int image_channels = input->shape().channel(); - int image_pad_h = param_.paddings[0]; - int image_pad_w = param_.paddings[1]; + int image_pad_h = paddings[0]; + int image_pad_w = paddings[2]; int kernel_height = param_.kernelSize[1]; int kernel_width = param_.kernelSize[0]; int kernel_step_h = param_.strides[0]; diff --git a/lite/kernels/arm/concat_compute.cc b/lite/kernels/arm/concat_compute.cc index dc78e1b955c29b261b2103479ea00bb836c0a31f..c954483258e45e90ca704d116e43a8d5b385aab6 100644 --- a/lite/kernels/arm/concat_compute.cc +++ b/lite/kernels/arm/concat_compute.cc @@ -71,6 +71,9 @@ void ConcatCompute::Run() { auto* axis_tensor_data = axis_tensor->data(); axis = axis_tensor_data[0]; } + if (axis < 0) { + axis += inputs[0]->dims().size(); + } switch (inputs.front()->precision()) { case PRECISION(kFloat): diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 2a545e70691f030a3a1e3f2a9a9822f5cd8b85b9..54e67de5abbfc88f64a50b07335d2527d9738206 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -73,7 +73,6 @@ void ConvCompute::PrepareForRun() { // VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal && no_dilation) { - // TODO(MyPandaShaoxiang): winograd conv support any pad impl_ = new WinogradConv; // VLOG(3) << "invoking winograd conv"; } else if (param.groups == 1 && kw == 3 && stride == 2 && @@ -122,10 +121,14 @@ void ConvCompute::PrepareForRun() { no_dilation && flag_dw) { impl_ = 
new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; - } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && - ic * oc < 4 * hin * win && kps_equal && no_dilation) { + } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation && + pads_equal) { impl_ = new DirectConv; // VLOG(3) << "Run DirectConv Int8"; + } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation && + pads_equal) { + impl_ = new WinogradConv; + // VLOG(3) << "Run WinogradConv Int8"; } else { impl_ = new GemmLikeConv; // VLOG(3) << "Run GemmLikeConvInt8"; @@ -169,10 +172,14 @@ void ConvCompute::PrepareForRun() { no_dilation && flag_dw) { impl_ = new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; - } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && - ic * oc < 4 * hin * win && kps_equal && no_dilation) { + } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation && + pads_equal) { impl_ = new DirectConv; // VLOG(3) << "Run DirectConv Int8"; + } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation && + pads_equal) { + impl_ = new WinogradConv; + // VLOG(3) << "Run WinogradConv Int8"; } else { impl_ = new GemmLikeConv; // VLOG(3) << "Run GemmLikeConvInt8"; diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index d0880e51de1eff4763c63d2d3fa4bc74cafc859e..6aa93366095a0567f4dee8fe732f76030f81a0aa 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "lite/kernels/arm/conv_winograd.h" -#include #include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/packed_sgemm.h" @@ -166,6 +165,189 @@ void WinogradConv::Run() { } } +template +void WinogradConv::ReInitWhenNeeded() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + int threads = ctx.threads(); + + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + if (last_shape_ == x_dims) { + return; + } + last_shape_ = x_dims; + //! update workspace size + int ic = x_dims[1]; + int ih = x_dims[2]; + int iw = x_dims[3]; + int oc = o_dims[1]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int tile_block = 8; + auto pad = *(param.paddings); + int pad_h0 = pad[0]; + int pad_h1 = pad[1]; + int pad_w0 = pad[2]; + int pad_w1 = pad[3]; + int oc_pad = (oc + 7) / 8 * 8; + int ic_pad = (ic + 7) / 8 * 8; + const int new_input_size = + ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) + + oc_pad * oh * ow * sizeof(int32_t); + int tmp_input_thread_size_byte = + tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t); + int tmp_output_thread_size_byte = + tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t); + const int temp_size = + (tmp_input_thread_size_byte + tmp_output_thread_size_byte + + wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) * + threads; + workspace_size_ = temp_size + new_input_size; + + //! update trans weights impl + // choose_small_ = ow * oh / (tile_block * threads) < 36 ? 
true : false; + // we only support 2x2 now + choose_small_ = true; + float w_fact = 0.25; + if (choose_small_) { + wino_iw = 4; + + if (last_function_ == 0) { + return; + } + last_function_ = 0; + } else { + wino_iw = 6; + if (last_function_ == 1) { + return; + } + last_function_ = 1; + } + /// update scale + for (auto& ws : w_scale_) { + ws *= w_fact; + } + + weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad}); + void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic); + auto weights_data_ = weights_.mutable_data(); + if (!choose_small_) { + } else { + lite::arm::math::weight_trans_c8_4x4_int8( + weights_data_, + param.filter->template data(), + ic, + oc, + trans_tmp_ptr); + } + free(trans_tmp_ptr); +} + +template +void WinogradConv::PrepareForRun() { + auto& param = this->Param(); + w_scale_ = param.weight_scale; + if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) { + LOG(FATAL) << "weights scale size must equal to filter size"; + return; + } + if (w_scale_.size() == 1) { + for (int i = 0; i < param.filter->dims()[0] - 1; ++i) { + w_scale_.push_back(w_scale_[0]); + } + } + float input_scale = param.input_scale; + for (auto& ws : w_scale_) { + ws *= input_scale; + } + if (param.bias) { + bias_.Resize(param.bias->dims()); + auto ptr = bias_.mutable_data(); + auto ptr_in = param.bias->template data(); + for (int i = 0; i < bias_.numel(); ++i) { + ptr[i] = ptr_in[i]; + } + } + if (OutType == PRECISION(kInt8)) { + float output_scale = param.output_scale; + for (auto& ws : w_scale_) { + ws /= output_scale; + } + if (param.bias) { + auto ptr = bias_.mutable_data(); + for (int i = 0; i < bias_.numel(); ++i) { + ptr[i] /= output_scale; + } + } + } + ReInitWhenNeeded(); +} + +template +void WinogradConv::Run() { + auto& param = this->Param(); + auto& ctx = this->ctx_->template As(); + ctx.ExtendWorkspace(workspace_size_); + const auto* i_data = param.x->template data(); + const auto* w_data = weights_.data(); + const auto* b_data = param.bias ? 
bias_.data() : nullptr; + // const float* i_data; + auto x_dims = param.x->dims(); + auto w_dims = param.filter->dims(); + auto o_dims = param.output->dims(); + + int iw = x_dims[3]; // nchw + int ih = x_dims[2]; + int ic = x_dims[1]; + int bs = x_dims[0]; + int oh = o_dims[2]; + int ow = o_dims[3]; + int oc = o_dims[1]; + + // now always choose small + if (OutType == PRECISION(kInt8)) { + auto* o_data = param.output->template mutable_data(); + lite::arm::math::conv_compute_2x2_3x3_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + w_scale_.data(), + param, + &ctx); + } else { + auto* o_data = param.output->template mutable_data(); + lite::arm::math::conv_compute_2x2_3x3_int8(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + w_scale_.data(), + param, + &ctx); + } +#ifdef LITE_WITH_PROFILE + kernel_func_name_ = "conv_compute_2x2_3x3_int8"; +#endif +} +template class WinogradConv; +template class WinogradConv; + } // namespace arm } // namespace kernels } // namespace lite diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h index 1a184ac0ccae1967a2f77110ce2a6fb619cf2e8e..6fb702150db2ce25e3789fc7d8f01b68d550c9e9 100644 --- a/lite/kernels/arm/conv_winograd.h +++ b/lite/kernels/arm/conv_winograd.h @@ -15,11 +15,12 @@ #pragma once #include +#include +#include #include "lite/backends/arm/math/conv_impl.h" #include "lite/core/context.h" #include "lite/core/kernel.h" #include "lite/core/target_wrapper.h" - namespace paddle { namespace lite { namespace kernels { @@ -44,7 +45,34 @@ class WinogradConv : public KernelLite { bool choose_small_{false}; int wino_iw{8}; }; +template +class WinogradConv + : public KernelLite { + public: + WinogradConv() = default; + ~WinogradConv() {} + virtual void PrepareForRun(); + virtual void ReInitWhenNeeded(); + virtual void Run(); +#ifdef LITE_WITH_PROFILE + virtual void SetProfileRuntimeKernelInfo( + paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + } + std::string kernel_func_name_{"NotImplForConvWino"}; +#endif + protected: + using param_t = operators::ConvParam; + Tensor weights_; + Tensor bias_; + DDim last_shape_; + int workspace_size_{0}; + int last_function_{-1}; + bool choose_small_{true}; + int wino_iw{4}; + std::vector w_scale_; +}; } // namespace arm } // namespace kernels } // namespace lite diff --git a/lite/kernels/arm/pool_compute.cc b/lite/kernels/arm/pool_compute.cc index ff6100c4e2c68d7eee0d5d0eeabbb64a1ca699e2..5cfca8f1b7d9a286d24dda5af5664aa381c8e0f1 100644 --- a/lite/kernels/arm/pool_compute.cc +++ b/lite/kernels/arm/pool_compute.cc @@ -58,6 +58,7 @@ void PoolCompute::Run() { bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && (ksize[1] == in_dims[3]) && kps_equal && pads_equal; global_pooling = param.global_pooling || global_pooling; + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[2 * i] = 0; @@ -107,35 +108,65 @@ void PoolCompute::Run() { } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && kps_equal) { if (pooling_type == "max") { - lite::arm::math::pooling2x2s2_max(din, - dout, - out_dims[0], - out_dims[1], - out_dims[2], - out_dims[3], - in_dims[1], - in_dims[2], - in_dims[3], - paddings[1], - paddings[3]); + lite::arm::math::pooling2x2s2p0_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + paddings[1], + paddings[3]); return; } else if (pooling_type == "avg") { - 
lite::arm::math::pooling2x2s2_avg(din, - dout, - out_dims[0], - out_dims[1], - out_dims[2], - out_dims[3], - in_dims[1], - in_dims[2], - in_dims[3], - exclusive, - paddings[1], - paddings[3]); + lite::arm::math::pooling2x2s2p0_avg(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + exclusive, + paddings[1], + paddings[3]); return; } - } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && + } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 && kps_equal) { + if (pooling_type == "max") { + lite::arm::math::pooling2x2s2p1_max(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + paddings[1], + paddings[3]); + return; + } else if (pooling_type == "avg") { + lite::arm::math::pooling2x2s2p1_avg(din, + dout, + out_dims[0], + out_dims[1], + out_dims[2], + out_dims[3], + in_dims[1], + in_dims[2], + in_dims[3], + exclusive, + paddings[1], + paddings[3]); + return; + } + } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s1p1_max(din, dout, @@ -165,7 +196,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s1p0_max(din, dout, @@ -195,7 +226,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s2p0_max(din, dout, @@ -225,7 +256,7 @@ void PoolCompute::Run() { return; } } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && - kps_equal) { + pads_equal && kps_equal) { if (pooling_type == "max") { lite::arm::math::pooling3x3s2p1_max(din, dout, diff --git a/lite/kernels/arm/softmax_compute.cc b/lite/kernels/arm/softmax_compute.cc index 3409d0f5c5bd6e7ce1ea77809f7715b62bb10ca2..79ea23ab3fad3340c63846ea11cc89b371f5c6c9 100644 --- a/lite/kernels/arm/softmax_compute.cc +++ b/lite/kernels/arm/softmax_compute.cc @@ -34,7 +34,7 @@ void SoftmaxCompute::Run() { int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int axis_size = x_dims[axis]; if (inner_num == 1) { - if (axis_size >= 4) { + if (axis_size > 4) { lite::arm::math::softmax_inner1_large_axis( din, dout, outer_num, axis_size); } else { diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc index 02478a23f9634c96864429be73e7c4c22153e21f..3f0e48d24aa020aa9fa709e65ac4bb37e7e28c04 100644 --- a/lite/tests/math/conv_int8_compute_test.cc +++ b/lite/tests/math/conv_int8_compute_test.cc @@ -34,7 +34,7 @@ DEFINE_int32(power_mode, DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(repeats, 1, "repeats times"); -DEFINE_bool(basic_test, true, "do all tests"); +DEFINE_bool(basic_test, false, "do all tests"); DEFINE_bool(check_result, true, "check the result"); DEFINE_int32(batch, 1, "batch size"); @@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias"); typedef paddle::lite::DDim DDim; typedef paddle::lite::Tensor Tensor; typedef paddle::lite::operators::ConvParam ConvParam; +typedef paddle::lite::operators::ActivationParam ActivationParam; using paddle::lite::profile::Timer; DDim compute_out_dim(const DDim& dim_in, @@ -165,7 +166,18 @@ void test_conv_int8(const std::vector& input_dims, 
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias); bias_fp32.CopyDataFrom(*param_int8_out.bias); } - + if (flag_relu) { + ActivationParam act_param; + act_param.has_active = true; + act_param.active_type = (paddle::lite_api::ActivationType) + flag_relu; // 1-relu, 2-relu6, 4-leakyrelu + if (flag_relu) { + param_fp32_out.fuse_relu = true; + param_int8_out.fuse_relu = true; + } + param_fp32_out.activation_param = act_param; + param_int8_out.activation_param = act_param; + } std::vector scale_in{1.f / 127}; std::vector scale_out{weight_dim.count(1, 4) / 127.f}; std::vector scale_w(weight_dim[0], 1.f / 127); @@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { dims.push_back(DDim({batch, cin, h, h})); } } + if (cin == 1 && cout == 1) { + continue; + } test_conv_int8(dims, weights_dim, 1, diff --git a/lite/tests/math/sgemm_c4_compute_test.cc b/lite/tests/math/sgemm_c4_compute_test.cc index 3e5577e03075502bab30aa03a50241b817fa8742..b5beeaffaed6bff8a260c158bdce234fce6c1349 100644 --- a/lite/tests/math/sgemm_c4_compute_test.cc +++ b/lite/tests/math/sgemm_c4_compute_test.cc @@ -179,6 +179,141 @@ bool test_sgemm_c4( #endif return true; } +bool test_sgemm_c8( + int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) { + int m_round = (m + 7) / 8 * 8; + int k_round = (k + 7) / 8 * 8; + int size_a = m * k; + int size_b = n * k; + int size_a_c4 = m_round * k_round; + int size_b_c8 = k_round * n; + + Tensor ta; + Tensor tb; + Tensor ta_c4; + Tensor tb_c8; + Tensor tc; + Tensor tc_basic; + Tensor tc_backup; + Tensor tbias; + + ta.Resize({size_a}); + tb.Resize({size_b}); + ta_c4.Resize({size_a_c4}); + tb_c8.Resize({size_b_c8}); + tc.Resize({m_round * n}); + tc_basic.Resize({m_round * n}); + tbias.Resize({m}); + + ta.set_precision(PRECISION(kInt16)); + tb.set_precision(PRECISION(kInt16)); + ta_c4.set_precision(PRECISION(kInt16)); + tb_c8.set_precision(PRECISION(kInt16)); + tc.set_precision(PRECISION(kInt32)); + tc_basic.set_precision(PRECISION(kInt32)); + tbias.set_precision(PRECISION(kInt32)); + + fill_tensor_rand(ta); + fill_tensor_rand(tb); + fill_tensor_rand(tbias); + fill_tensor_rand(tc); + + auto da = ta.mutable_data(); + auto db = tb.mutable_data(); + auto da_c4 = ta_c4.mutable_data(); + auto db_c8 = tb_c8.mutable_data(); + auto dc_basic = tc_basic.mutable_data(); + auto dbias = tbias.mutable_data(); + + // trans A, B to c4 + basic_trans_mat_to_c8(da, da_c4, k, m, k, true); + basic_trans_mat_to_c8(db, db_c8, n, k, n, false); + + LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k + << ", relu: " << (has_relu ? "true" : "false") + << ", bias: " << (has_bias ? "true" : "false"); + + if (FLAGS_check_result) { + basic_gemm_c8(false, + false, + m, + n, + k, + 1, + da, + k, + db, + n, + 0, + dc_basic, + n, + dbias, + false, + false); + } + Timer t0; + LOG(INFO) << "basic test end"; +#ifdef LITE_WITH_ARM + //! 
compute + double ops = 2.0 * m_round * n * k_round; + std::unique_ptr ctx1( + new paddle::lite::KernelContext); + auto& ctx = ctx1->As(); + ctx.SetRunMode(static_cast(cls), ths); + auto dc = tc.mutable_data(); + for (int j = 0; j < FLAGS_warmup; ++j) { + paddle::lite::arm::math::sgemm_prepack_c8_int16_small( + m, n, k, da_c4, db_c8, dc, &ctx); + } + LOG(INFO) << "basic test end"; + + for (int i = 0; i < FLAGS_repeats; ++i) { + t0.Start(); + paddle::lite::arm::math::sgemm_prepack_c8_int16_small( + m, n, k, da_c4, db_c8, dc, &ctx); + t0.Stop(); + } + LOG(INFO) << "basic test end"; + LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k + << ", power_mode: " << cls << ", threads: " << ths + << ", GOPS: " << ops * 1e-9f + << " GOPS, avg time: " << t0.LapTimes().Avg() + << " ms, min time: " << t0.LapTimes().Min() + << " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg() + << " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min() + << " GOPs"; + + if (FLAGS_check_result) { + double max_ratio = 0; + double max_diff = 0; + tensor_cmp_host(tc_basic, tc, max_ratio, max_diff); + LOG(INFO) << "compare result, max diff: " << max_diff + << ", max ratio: " << max_ratio; + if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) { + Tensor tdiff; + tdiff.set_precision(PRECISION(kInt32)); + tdiff.Resize(tc.dims()); + tensor_diff(tc_basic, tc, tdiff); + LOG(INFO) << "a: "; + print_tensor(ta); + LOG(INFO) << "a_c8: "; + print_tensor(ta_c4); + LOG(INFO) << "b: "; + print_tensor(tb); + LOG(INFO) << "b_c8: "; + print_tensor(tb_c8); + LOG(INFO) << "basic result: "; + print_tensor(tc_basic); + LOG(INFO) << "lite result: "; + print_tensor(tc); + LOG(INFO) << "diff result: "; + print_tensor(tdiff); + return false; + } + } +#endif + return true; +} TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { if (FLAGS_basic_test) { @@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { paddle::lite::DeviceInfo::Init(); #endif LOG(INFO) << "run basic sgemm_c4 test"; - for (auto& m : {1, 3, 8, 32, 397}) { - for (auto& n : {1, 2, 3, 4, 13, 141, 789}) { - for (auto& k : {1, 3, 8, 59, 234}) { - for (auto& has_bias : {false, true}) { - for (auto& has_relu : {false, true}) { + for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) { + for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) { + for (auto& k : {1, 3, 8, 59, 234, 19}) { + for (auto& has_bias : {false}) { + for (auto& has_relu : {false}) { for (auto& th : {1, 2, 4}) { auto flag = test_sgemm_c4( m, n, k, has_bias, has_relu, FLAGS_power_mode, th); @@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { } } } +TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) { + if (FLAGS_basic_test) { +#ifdef LITE_WITH_ARM + paddle::lite::DeviceInfo::Init(); +#endif + LOG(INFO) << "run basic sgemm_c4 test"; + for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) { + for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) { + for (auto& k : {1, 3, 8, 59, 234, 19}) { + for (auto& has_bias : {false}) { + for (auto& has_relu : {false}) { + for (auto& th : {1}) { + auto flag = test_sgemm_c8( + m, n, k, has_bias, has_relu, FLAGS_power_mode, th); + if (flag) { + LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? "true" : "false") + << " passed\n"; + } else { + LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k + << ", bias: " << (has_bias ? "true" : "false") + << ", relu: " << (has_relu ? 
"true" : "false") + << " failed\n"; + } + } + } + } + } + } + } + } +} -TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { +TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) { #ifdef LITE_WITH_ARM paddle::lite::DeviceInfo::Init(); #endif @@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu << " failed!!"; } + flag = test_sgemm_c8(FLAGS_M, + FLAGS_N, + FLAGS_K, + FLAGS_flag_bias, + FLAGS_flag_relu, + FLAGS_power_mode, + FLAGS_threads); + if (!flag) { + LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N + << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias + << ", relu: " << FLAGS_flag_relu << " failed!!"; + } LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu << " passed!!"; diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index e5ef77ca061d31a0b9b735d49cda9bbeda53c294..67e1b8a0e6656fee34158eb8452f32ba2a115c1c 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input, } } } +template +static void basic_trans_mat_to_c8(const type* input, + type* output, + const int ldin, + const int M, + const int K, + bool pack_k) { + const int m_round = (M + 7) / 8 * 8; + int k_round = (K + 7) / 8 * 8; + if (!pack_k) { + k_round = K; + } + const int m_loop = m_round / 8; + type zero_buf[K]; + memset(zero_buf, 0, K * sizeof(type)); + for (int i = 0; i < m_loop; ++i) { + const type* in0 = input + i * 8 * ldin; + const type* in1 = in0 + ldin; + const type* in2 = in1 + ldin; + const type* in3 = in2 + ldin; + const type* in4 = in3 + ldin; + const type* in5 = in4 + ldin; + const type* in6 = in5 + ldin; + const type* in7 = in6 + ldin; + if (8 * (i + 1) - M > 0) { + switch (8 * (i + 1) - M) { + case 7: + in1 = zero_buf; + case 6: + in2 = zero_buf; + case 5: + in3 = zero_buf; + case 4: + in4 = zero_buf; + case 3: + in5 = zero_buf; + case 2: + in6 = zero_buf; + case 1: + in7 = zero_buf; + default: + break; + } + } + for (int j = 0; j < K; ++j) { + *output++ = *in0++; + *output++ = *in1++; + *output++ = *in2++; + *output++ = *in3++; + *output++ = *in4++; + *output++ = *in5++; + *output++ = *in6++; + *output++ = *in7++; + } + for (int j = K; j < k_round; ++j) { + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + *output++ = static_cast(0); + } + } +} template static void basic_gemm_c4(bool trans_a, @@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a, free(tmp_c); } +template +static void basic_gemm_c8(bool trans_a, + bool trans_b, + int m, + int n, + int k, + type2 alpha, + const type* a, + int lda, + const type* b, + int ldb, + type2 beta, + type2* c, + int ldc, + const type2* bias, + bool flag_bias = false, + bool flag_relu = false) { + type2* tmp_c = reinterpret_cast(malloc(m * ldc * sizeof(type2))); + memset(tmp_c, 0, m * ldc * sizeof(type2)); +#pragma omp parallel for + for (int i = 0; i < m; ++i) { + auto bias_data = static_cast(0); + if (flag_bias) { + bias_data = bias[i]; + } + for (int j = 0; j < n; ++j) { + auto sum = static_cast(0); + for (int l = 0; l < k; ++l) { + type av; + type bv; + if (trans_a) { + av = a[l * lda + i]; + } else { + av = a[i * lda + l]; + } + if 
(trans_b) { + bv = b[j * ldb + l]; + } else { + bv = b[l * ldb + j]; + } + sum += av * bv; + } + type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data; + if (flag_relu) { + tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0; + } else { + tmp_c[i * ldc + j] = tmp; + } + } + } + //! trans c to c4 + basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false); + free(tmp_c); +} template static void basic_gemm(bool trans_a, bool trans_b, diff --git a/lite/tests/utils/tensor_utils.h b/lite/tests/utils/tensor_utils.h index 4f8d1ad2aa70dc09ab22d0e22df2180b5da83788..3ab8ac7261df37e9688f3f4ed6efcebc31b9797e 100644 --- a/lite/tests/utils/tensor_utils.h +++ b/lite/tests/utils/tensor_utils.h @@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT fill_tensor_host_const_impl( tensor.mutable_data(), static_cast(value), size); break; + case PRECISION(kInt16): + fill_tensor_host_const_impl( + tensor.mutable_data(), static_cast(value), size); + break; case PRECISION(kInt32): fill_tensor_host_const_impl( tensor.mutable_data(), static_cast(value), size); @@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl(signed char* dio, int64_t size) { } } template <> +void fill_tensor_host_rand_impl(int16_t* dio, int64_t size) { + for (int64_t i = 0; i < size; ++i) { + dio[i] = (rand() % 256 - 128) * 2; // NOLINT + } +} +template <> void fill_tensor_host_rand_impl(unsigned char* dio, int64_t size) { for (int64_t i = 0; i < size; ++i) { @@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT case PRECISION(kInt8): fill_tensor_host_rand_impl(tensor.mutable_data(), size); break; + case PRECISION(kInt16): + fill_tensor_host_rand_impl(tensor.mutable_data(), size); + break; case PRECISION(kInt32): fill_tensor_host_rand_impl(tensor.mutable_data(), size); break; diff --git a/lite/utils/cv/image_resize.cc b/lite/utils/cv/image_resize.cc index 8b1638b5665bf625c1335da760d4df843618b080..1a971bf78b50f149b9d1ce781d943d906ea902e4 100644 --- a/lite/utils/cv/image_resize.cc +++ b/lite/utils/cv/image_resize.cc @@ -678,15 +678,9 @@ void resize(const uint8_t* src, } else if (srcFormat == NV12 || srcFormat == NV21) { nv21_resize(src, dst, srcw, srch, dstw, dsth); return; - num = 1; - int hout = static_cast(0.5 * dsth); - dsth += hout; } else if (srcFormat == BGR || srcFormat == RGB) { bgr_resize(src, dst, srcw, srch, dstw, dsth); return; - w_in = srcw * 3; - w_out = dstw * 3; - num = 3; } else if (srcFormat == BGRA || srcFormat == RGBA) { w_in = srcw * 4; w_out = dstw * 4;
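
The int8 Winograd kernel introduced above folds every quantization factor into the per-output-channel weight scale before Run(): in PrepareForRun/ReInitWhenNeeded the weight scale is multiplied by the input scale and by the F(2,3) transform factor 0.25, and, for int8 output, divided by the output scale (the bias is likewise pre-divided by the output scale). A minimal standalone sketch of that folding is given below; the function name and parameters are illustrative and not part of the patch.

#include <vector>

// Sketch: assembling the effective scales passed to conv_compute_2x2_3x3_int8,
// mirroring WinogradConv<int8>::PrepareForRun()/ReInitWhenNeeded().
std::vector<float> fuse_winograd_int8_scales(std::vector<float> w_scale,
                                             float input_scale,
                                             float output_scale,  // only used for int8 output
                                             bool int8_output) {
  const float wino_factor = 0.25f;  // compensates the F(2,3) weight-transform gain
  for (auto& ws : w_scale) {
    ws *= input_scale;     // fold input quantization
    ws *= wino_factor;     // fold Winograd transform factor
    if (int8_output) {
      ws /= output_scale;  // fold output quantization
    }
  }
  return w_scale;  // handed to the kernel as its `scale` argument
}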
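The new pooling2x2s2p1_avg kernel drives its NEON path with two coefficient vectors: coef_h halves when a row of the window is real input (and stays larger only in exclusive mode when a row falls in the padding), and the leftmost output column gets its own coefficient because its window is one column narrower. A scalar reference of the same arithmetic, useful for checking the vector kernel, might look like the following; the helper name is illustrative and padded elements simply contribute zero.

#include <algorithm>

// Sketch: scalar reference for 2x2, stride-2, pad-1 average pooling.
// With `exclusive` the divisor is the number of valid (non-padded) elements,
// otherwise it stays at the full window size of 4.
float avg_pool_2x2s2p1_at(const float* in, int hin, int win,
                          int oh, int ow, bool exclusive) {
  int hstart = oh * 2 - 1;  // stride 2, pad 1
  int wstart = ow * 2 - 1;
  int hend = std::min(hstart + 2, hin);
  int wend = std::min(wstart + 2, win);
  int h0 = std::max(hstart, 0);
  int w0 = std::max(wstart, 0);
  float sum = 0.f;
  for (int h = h0; h < hend; ++h) {
    for (int w = w0; w < wend; ++w) {
      sum += in[h * win + w];
    }
  }
  int count = exclusive ? (hend - h0) * (wend - w0) : 4;
  return sum / count;
}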