From e055452844098358e22cd0a169dee499026e9716 Mon Sep 17 00:00:00 2001 From: TianXiaogang Date: Thu, 28 Nov 2019 11:49:45 +0800 Subject: [PATCH] add winograd c4 implement (#2494) (#2508) * add winograd c4 implement (#2494) fix: fix conv_block prepack_input_nxwc4 bug * fix: optimize sgemm_c4 in armv7 change condition of choose winograd kernel * fix: change conv choose kernel condition * fix winograd reinitwhenneed (#2511) --- lite/backends/arm/math/CMakeLists.txt | 1 + .../arm/math/conv3x3_winograd_fp32_c4.cc | 564 ++++++++++++++++++ lite/backends/arm/math/conv_block_utils.h | 23 +- lite/backends/arm/math/conv_impl.h | 16 + lite/backends/arm/math/packed_sgemm_c4.h | 10 + lite/kernels/arm/conv_compute.cc | 7 +- lite/kernels/arm/conv_winograd.cc | 177 ++++-- lite/kernels/arm/conv_winograd.h | 1 + 8 files changed, 728 insertions(+), 71 deletions(-) create mode 100644 lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc diff --git a/lite/backends/arm/math/CMakeLists.txt b/lite/backends/arm/math/CMakeLists.txt index a38afd5503..076c791daa 100644 --- a/lite/backends/arm/math/CMakeLists.txt +++ b/lite/backends/arm/math/CMakeLists.txt @@ -79,6 +79,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) conv5x5s1_depthwise_int8.cc conv5x5s1_depthwise_fp32.cc conv5x5s2_depthwise_fp32.cc + conv3x3_winograd_fp32_c4.cc conv_winograd_3x3.cc conv_impl.cc softmax.cc diff --git a/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc new file mode 100644 index 0000000000..5834461b8f --- /dev/null +++ b/lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc @@ -0,0 +1,564 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/arm/math/conv_block_utils.h" +#include "lite/backends/arm/math/conv_impl.h" +#include "lite/backends/arm/math/packed_sgemm_c4.h" +#ifdef ARM_WITH_OMP +#include +#endif +#include + +namespace paddle { +namespace lite { +namespace arm { +namespace math { +void input_trans_c4(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4(const float* src, + int src_stride, + float* dest, + int dest_stride); +void output_trans_c4_post(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu); +void weight_trans_c4( + float* dest, const float* src, int ic, int oc, void* workspace); + +/* +*The following function conv_compute_6x6_3x3 is base on +*MNN[https://github.com/alibaba/MNN] +* +*Copyright © 2018, Alibaba Group Holding Limited +*/ +void conv_compute_6x6_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx) { + const int pad_h = (*param.paddings)[0]; + const int pad_w = (*param.paddings)[2]; + float* tmp_work_space = + ctx->workspace_data() + ctx->llc_size() / sizeof(float); + + int in_n_stride = chin * hin * win; + int out_n_stride = chout * hout * wout; + int ic_stride = win * hin; + int oc_stride = wout * hout; + int ic_4 = (chin + 3) / 4; + int oc_4 = (chout + 3) / 4; + + int tile_w = (wout + 5) / 6; + int tile_h = (hout + 5) / 6; + int size_tile = tile_h * tile_w; + float zero_ptr[8]; + memset(zero_ptr, 0, 8 * sizeof(float)); + + int w_pad = win + pad_w * 2; + int h_pad = hin + pad_h * 2; + float* input_c4 = tmp_work_space; + int new_h_stride = w_pad * 4; + int new_c_stride = new_h_stride * h_pad; + + int ic_4_stride = w_pad * h_pad * 4; + int oc_4_stride = wout * hout * 4; + + int tile_block = 8; +#ifdef __aarch64__ + tile_block = 16; +#endif + int block_count = (size_tile + tile_block - 1) / tile_block; + + int threads = ctx->threads(); + float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride; + int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 256; + memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float)); + float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride; + float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 256; + + // begin compute + for (int ni = 0; ni < num; ++ni) { + // trans input to c4 + for (int i = 0; i < ic_4; ++i) { + prepack_input_nxwc4_dw(input + ni * in_n_stride, + input_c4 + i * new_c_stride, + i * 4, + -pad_h, + hin + pad_h, + -pad_w, + win + pad_w, + chin, + win, + hin, + zero_ptr); + } + float* output_ptr = output + ni * out_n_stride; + + const float* weight_ptr = weight; + const float* bias_ptr = bias; +#pragma omp parallel for num_threads(threads) + for (int tbi = 0; tbi < block_count; ++tbi) { +#ifdef ARM_WITH_OMP + float* tmp_data = + g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride; + float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 256; + float* trans_remain_tmp_data = + g_trans_remain_tmp_data + omp_get_thread_num() * 256; +#else + float* tmp_data = g_tmp_data; + float* trans_tmp_data = g_trans_tmp_data; + float* trans_remain_tmp_data = g_trans_remain_tmp_data; +#endif + int tile_index = tbi * tile_block; + int tile_remain = size_tile - tile_index; + int tile_count = tile_remain > tile_block ? tile_block : tile_remain; + + // input trans + int c_gi_stride = tile_count * oc_4 * 4; + int b_gi_stride = tile_count * ic_4 * 4; + //* + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int src_x = tw_index * 6; + int src_y = th_index * 6; + int ex = src_x + 8 > w_pad ? w_pad - src_x : 8; + int ey = src_y + 8 > h_pad ? h_pad - src_y : 8; + + float* dst_ptr = tmp_data + ti * 4; + const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4; + + if (ex == 8 && ey == 8) { + // trans input + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + for (int i = 0; i < 8; ++i) { + const float* ci_ptr = src_ci + i * w_pad * 4; + input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + } + float* dst_ci = dst_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + input_trans_c4(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); + } + } + } else { + // trans remain input + int x_size = ex; + for (int ci = 0; ci < ic_4; ++ci) { + const float* src_ci = src_ptr + ci * ic_4_stride; + // pad + memset(trans_remain_tmp_data, 0, 256 * sizeof(float)); + if (x_size > 0) { + for (int yi = 0; yi < ey; ++yi) { + float* dst_yi = trans_remain_tmp_data + yi * 32; + const float* src_yi = src_ci + w_pad * yi * 4; + memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4); + } + } + + // trans + for (int i = 0; i < 8; ++i) { + float* ci_ptr = trans_remain_tmp_data + i * 32; + input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32); + } + float* dst_ci = dst_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + input_trans_c4(trans_tmp_data + i * 32, + 4, + dst_ci + i * b_gi_stride * 8, + b_gi_stride); + } + } // for ci_4 + } + } + //*/ + // input trans end + // *begin compute dot + // * + //* + float* dst_temp_data = tmp_data + tile_block * ic_4 * 256; + float* b_ptr = tmp_data; + int w_gi_stride = ic_4 * oc_4 * 16; + for (int gi = 0; gi < 64; ++gi) { + float* origin_C = dst_temp_data + gi * c_gi_stride; + float* origin_B = b_ptr + gi * b_gi_stride; + const float* origin_A = weight + gi * w_gi_stride; + sgemm_prepack_c4_small(oc_4 * 4, + tile_count, + ic_4 * 4, + origin_A, + origin_B, + origin_C, + nullptr, + false, + false, + ctx); + } + //*/ + //* + // output trans + float bias_value[4]; + memset(bias_value, 0, 4 * sizeof(float)); + + for (int ti = 0; ti < tile_count; ++ti) { + int index = tile_index + ti; + + int tw_index = index % tile_w; + int th_index = index / tile_w; + + int dst_x = tw_index * 6; + int dst_y = th_index * 6; + + int ex = dst_x + 6 > wout ? wout - dst_x : 6; + int ey = dst_y + 6 > hout ? hout - dst_y : 6; + + float* dst_ptr = output + (dst_y * wout + dst_x) * 4; + float* src_ptr = dst_temp_data + ti * 4; + + if (ex == 6) { + // trans output + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + output_trans_c4(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); + } + for (int i = 0; i < ey; ++i) { + output_trans_c4_post(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); + } + write_to_output_c4_fp32(trans_remain_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } else { + for (int ci = 0; ci < oc_4; ++ci) { + if (param.bias) { + bias_value[0] = bias[ci * 4]; + bias_value[1] = bias[ci * 4 + 1]; + bias_value[2] = bias[ci * 4 + 2]; + bias_value[3] = bias[ci * 4 + 3]; + } + // trans output + float* dst_ci = dst_ptr + ci * oc_4_stride; + float* src_ci = src_ptr + ci * tile_count * 4; + for (int i = 0; i < 8; ++i) { + output_trans_c4(src_ci + i * c_gi_stride * 8, + c_gi_stride, + trans_tmp_data + i * 4, + 32); + } + for (int i = 0; i < ey; ++i) { + output_trans_c4_post(trans_tmp_data + i * 32, + 4, + trans_remain_tmp_data + i * 24, + 4, + bias_value, + param.fuse_relu); + } + // copy to dest + memset(trans_tmp_data, 0, 144 * sizeof(float)); + for (int i = 0; i < ey; ++i) { + memcpy(trans_tmp_data + i * ex * 4, + trans_remain_tmp_data + i * 24, + ex * sizeof(float) * 4); + } + write_to_output_c4_fp32(trans_tmp_data, + output_ptr, + ci * 4, + ci * 4 + 4, + dst_y, + dst_y + ey, + dst_x, + dst_x + ex, + chout, + hout, + wout, + false, + zero_ptr); + } + } + } + //*/ + } // for block_count + } // for num +} // conv_compute + +void output_trans_c4(const float* src, + int src_stride, + float* dest, + int dest_stride) { + const float32x4_t src0 = vld1q_f32(src); + const float32x4_t src1 = vld1q_f32(src + src_stride); + const float32x4_t src2 = vld1q_f32(src + src_stride * 2); + const float32x4_t src3 = vld1q_f32(src + src_stride * 3); + const float32x4_t src4 = vld1q_f32(src + src_stride * 4); + const float32x4_t src5 = vld1q_f32(src + src_stride * 5); + const float32x4_t src6 = vld1q_f32(src + src_stride * 6); + const float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t tmp024a = vaddq_f32(src1, src2); + float32x4_t tmp135a = vsubq_f32(src1, src2); + float32x4_t tmp024b = vaddq_f32(src3, src4); + float32x4_t tmp135b = vsubq_f32(src3, src4); + float32x4_t tmp024c = vaddq_f32(src5, src6); + float32x4_t tmp135c = vsubq_f32(src5, src6); + + float32x4_t dest0 = + vaddq_f32(vaddq_f32(vaddq_f32(src0, tmp024a), tmp024b), tmp024c); + float32x4_t dest2 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 4)), + vmulq_n_f32(tmp024c, 0.25f)); + float32x4_t dest4 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 16)), + vmulq_n_f32(tmp024c, 0.0625f)); + + float32x4_t dest1 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 2)), + vmulq_n_f32(tmp135c, 0.5f)); + float32x4_t dest3 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 8)), + vmulq_n_f32(tmp135c, 0.125f)); + float32x4_t dest5 = + vaddq_f32(src7, + vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 32)), + vmulq_n_f32(tmp135c, 0.03125f))); + + vst1q_f32(dest, dest0); + vst1q_f32(dest + dest_stride, dest1); + vst1q_f32(dest + dest_stride * 2, dest2); + vst1q_f32(dest + dest_stride * 3, dest3); + vst1q_f32(dest + dest_stride * 4, dest4); + vst1q_f32(dest + dest_stride * 5, dest5); +} +void output_trans_c4_post(const float* src, + int src_stride, + float* dest, + int dest_stride, + float* bias_value, + bool has_relu = false) { + const float32x4_t src0 = vld1q_f32(src); + const float32x4_t src1 = vld1q_f32(src + src_stride); + const float32x4_t src2 = vld1q_f32(src + src_stride * 2); + const float32x4_t src3 = vld1q_f32(src + src_stride * 3); + const float32x4_t src4 = vld1q_f32(src + src_stride * 4); + const float32x4_t src5 = vld1q_f32(src + src_stride * 5); + const float32x4_t src6 = vld1q_f32(src + src_stride * 6); + const float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t tmp024a = vaddq_f32(src1, src2); + float32x4_t tmp135a = vsubq_f32(src1, src2); + float32x4_t tmp024b = vaddq_f32(src3, src4); + float32x4_t tmp135b = vsubq_f32(src3, src4); + float32x4_t tmp024c = vaddq_f32(src5, src6); + float32x4_t tmp135c = vsubq_f32(src5, src6); + + float32x4_t dest0 = + vaddq_f32(vaddq_f32(vaddq_f32(src0, tmp024a), tmp024b), tmp024c); + float32x4_t dest2 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 4)), + vmulq_n_f32(tmp024c, 0.25f)); + float32x4_t dest4 = vaddq_f32(vaddq_f32(tmp024a, vmulq_n_f32(tmp024b, 16)), + vmulq_n_f32(tmp024c, 0.0625f)); + + float32x4_t dest1 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 2)), + vmulq_n_f32(tmp135c, 0.5f)); + float32x4_t dest3 = vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 8)), + vmulq_n_f32(tmp135c, 0.125f)); + float32x4_t dest5 = + vaddq_f32(src7, + vaddq_f32(vaddq_f32(tmp135a, vmulq_n_f32(tmp135b, 32)), + vmulq_n_f32(tmp135c, 0.03125f))); + + if (bias_value) { + float32x4_t bias = vld1q_f32(bias_value); + dest0 = vaddq_f32(dest0, bias); + dest1 = vaddq_f32(dest1, bias); + dest2 = vaddq_f32(dest2, bias); + dest3 = vaddq_f32(dest3, bias); + dest4 = vaddq_f32(dest4, bias); + dest5 = vaddq_f32(dest5, bias); + } + + if (has_relu) { + float32x4_t zeros = vdupq_n_f32(0); + dest0 = vmaxq_f32(dest0, zeros); + dest1 = vmaxq_f32(dest1, zeros); + dest2 = vmaxq_f32(dest2, zeros); + dest3 = vmaxq_f32(dest3, zeros); + dest4 = vmaxq_f32(dest4, zeros); + dest5 = vmaxq_f32(dest5, zeros); + } + + vst1q_f32(dest, dest0); + vst1q_f32(dest + dest_stride, dest1); + vst1q_f32(dest + dest_stride * 2, dest2); + vst1q_f32(dest + dest_stride * 3, dest3); + vst1q_f32(dest + dest_stride * 4, dest4); + vst1q_f32(dest + dest_stride * 5, dest5); +} + +void input_trans_c4(const float* src, + int src_stride, + float* dest, + int dest_stride) { + float32x4_t src0 = vld1q_f32(src); + float32x4_t src1 = vld1q_f32(src + src_stride); + float32x4_t src2 = vld1q_f32(src + src_stride * 2); + float32x4_t src3 = vld1q_f32(src + src_stride * 3); + float32x4_t src4 = vld1q_f32(src + src_stride * 4); + float32x4_t src5 = vld1q_f32(src + src_stride * 5); + float32x4_t src6 = vld1q_f32(src + src_stride * 6); + float32x4_t src7 = vld1q_f32(src + src_stride * 7); + + float32x4_t dst0 = vaddq_f32(vsubq_f32(src0, src6), + vmulq_n_f32(vsubq_f32(src4, src2), 5.25)); + float32x4_t dst7 = vaddq_f32(vsubq_f32(src7, src1), + vmulq_n_f32(vsubq_f32(src3, src5), 5.25)); + + float32x4_t tmp12a = + vsubq_f32(vaddq_f32(src2, src6), vmulq_n_f32(src4, 4.25)); + float32x4_t tmp12b = + vsubq_f32(vaddq_f32(src1, src5), vmulq_n_f32(src3, 4.25)); + float32x4_t dst1 = vaddq_f32(tmp12a, tmp12b); + float32x4_t dst2 = vsubq_f32(tmp12a, tmp12b); + + float32x4_t tmp34a = vsubq_f32(vaddq_f32(src6, vmulq_n_f32(src2, 0.25)), + vmulq_n_f32(src4, 1.25)); + float32x4_t tmp34b = + vaddq_f32(vsubq_f32(vmulq_n_f32(src1, 0.5), vmulq_n_f32(src3, 2.5)), + vmulq_n_f32(src5, 2)); + float32x4_t dst3 = vaddq_f32(tmp34a, tmp34b); + float32x4_t dst4 = vsubq_f32(tmp34a, tmp34b); + + float32x4_t tmp56a = + vaddq_f32(src6, vmulq_n_f32(vsubq_f32(src2, vmulq_n_f32(src4, 1.25)), 4)); + float32x4_t tmp56b = + vaddq_f32(vsubq_f32(vmulq_n_f32(src1, 2), vmulq_n_f32(src3, 2.5)), + vmulq_n_f32(src5, 0.5)); + float32x4_t dst5 = vaddq_f32(tmp56a, tmp56b); + float32x4_t dst6 = vsubq_f32(tmp56a, tmp56b); + + vst1q_f32(dest, dst0); + vst1q_f32(dest + dest_stride, dst1); + vst1q_f32(dest + dest_stride * 2, dst2); + vst1q_f32(dest + dest_stride * 3, dst3); + vst1q_f32(dest + dest_stride * 4, dst4); + vst1q_f32(dest + dest_stride * 5, dst5); + vst1q_f32(dest + dest_stride * 6, dst6); + vst1q_f32(dest + dest_stride * 7, dst7); +} +void weight_trans_c4( + float* dest, const float* din, int ch_in, int ch_out, void* workspace) { + const float coeff[8][3] = {{1.0f, 0.0f, 0.0f}, + {-2.0f / 9, -2.0f / 9, -2.0f / 9}, + {-2.0f / 9, 2.0f / 9, -2.0f / 9}, + {1.0f / 90, 1.0f / 45, 2.0f / 45}, + {1.0f / 90, -1.0f / 45, 2.0f / 45}, + {32.0f / 45, 16.0f / 45, 8.0f / 45}, + {32.0f / 45, -16.0f / 45, 8.0f / 45}, + {0.0f, 0.0f, 1.0f}}; + + float* ptr_out = static_cast(workspace); + + for (int i = 0; i < ch_out; i++) { + for (int j = 0; j < ch_in; j++) { + const float* kernel0 = + static_cast(din) + (i * ch_in + j) * 9; + float* ptr_channel = ptr_out + (i * ch_in + j) * 64; + + //! transform kernel, transposed + const float* k0 = kernel0; + const float* k1 = kernel0 + 3; + const float* k2 = kernel0 + 6; + + //! h + float tmp[8][3]; + for (int i = 0; i < 8; i++) { + tmp[i][0] = + k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2]; + tmp[i][1] = + k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2]; + tmp[i][2] = + k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2]; + } + + //! v + for (int j = 0; j < 8; j++) { + float* tmpp = &tmp[j][0]; + for (int i = 0; i < 8; i++) { + ptr_channel[j * 8 + i] = tmpp[0] * coeff[i][0] + + tmpp[1] * coeff[i][1] + + tmpp[2] * coeff[i][2]; + } + } + } + } + + int oc_pad = (ch_out + 3) / 4 * 4; + int ic_pad = (ch_in + 3) / 4 * 4; + int c_stride = ic_pad * oc_pad; + for (int i = 0; i < ch_out * ch_in * 64; ++i) { + int new_c = i % 64; + int new_oc = i / ch_in / 64 / 4; + int new_ic = i / 64 % (ch_in * 4) % ch_in; + int new_inner = i / ch_in / 64 % 4; + int dest_ind = + new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; + dest[dest_ind] = ptr_out[i]; + } +} + +} // namespace math +} // namespace arm +} // namespace lite +} // namespace paddle diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index 24b99692cc..e4279d9a72 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -254,6 +254,7 @@ inline void prepack_input_nxwc4_dw(const float* din, LOG(FATAL) << "prepack_dw_input, valid height must > zero"; } float32x4_t vzero = vdupq_n_f32(0.f); + auto out_data = dout; int size_w = we - ws; int w0 = ws < 0 ? 0 : ws; @@ -269,6 +270,7 @@ inline void prepack_input_nxwc4_dw(const float* din, bool flag_ext_l = left_remain > 0; int left_sl = 4 - left_remain; + int left_valid_sl = left_sl > width ? width : left_sl; uint32x4_t vmask_padl; bool flag_mask_l = false; if (flag_ext_l) { @@ -290,6 +292,7 @@ inline void prepack_input_nxwc4_dw(const float* din, } int size_c = width * height; for (int h = hs; h < he; ++h) { + dout = out_data + (h - hs) * 4 * size_w; auto ptr_c0 = din + cs * size_c + h * width; auto ptr_c1 = ptr_c0 + size_c; auto ptr_c2 = ptr_c1 + size_c; @@ -351,10 +354,10 @@ inline void prepack_input_nxwc4_dw(const float* din, } transpose_4x4(vc0, vc1, vc2, vc3, dout); dout += 16; - ptr_c0 += left_sl; - ptr_c1 += left_sl; - ptr_c2 += left_sl; - ptr_c3 += left_sl; + ptr_c0 += left_valid_sl; + ptr_c1 += left_valid_sl; + ptr_c2 += left_valid_sl; + ptr_c3 += left_valid_sl; } /// valid for (int i = 0; i < cnt_valid; ++i) { @@ -986,7 +989,9 @@ inline bool write_to_output_c4_fp32(const float* din, int size_h = (he > height ? height : he) - hs; // size_h == hei_n - int cnt = (width - ws) / w4; + int valid_we = we > width ? width : we; + int cnt = (valid_we - ws) / w4; + int remain = valid_we - ws - cnt * w4; for (int i = 0; i < size_h; i++) { int size_w = i * width; @@ -1087,12 +1092,12 @@ inline bool write_to_output_c4_fp32(const float* din, #endif } } - if (we > width) { + if (remain > 0) { int offset = i * w_round * c4 + c4 * w4 * cnt; din_hei_ptr = ptr_din + offset; - int j = we - w4; + int j = 0; if (flag_relu) { - for (; j < width; ++j) { + for (; j < remain; ++j) { *(doutc0_ptr++) = LITEMAX(din_hei_ptr[0], 0.f); *(doutc1_ptr++) = LITEMAX(din_hei_ptr[1], 0.f); *(doutc2_ptr++) = LITEMAX(din_hei_ptr[2], 0.f); @@ -1100,7 +1105,7 @@ inline bool write_to_output_c4_fp32(const float* din, din_hei_ptr += w4; } } else { - for (; j < width; ++j) { + for (; j < remain; ++j) { *(doutc0_ptr++) = din_hei_ptr[0]; *(doutc1_ptr++) = din_hei_ptr[1]; *(doutc2_ptr++) = din_hei_ptr[2]; diff --git a/lite/backends/arm/math/conv_impl.h b/lite/backends/arm/math/conv_impl.h index c5baa31e14..f4d00039aa 100644 --- a/lite/backends/arm/math/conv_impl.h +++ b/lite/backends/arm/math/conv_impl.h @@ -314,7 +314,23 @@ void fill_bias_int8(int* tensor, const int* bias, int channel, int channel_size); +// new winograd +void weight_trans_c4( + float* dest, const float* src, int ic, int oc, void* workspace); +void conv_compute_6x6_3x3(const float* input, + float* output, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const float* weight, + const float* bias, + const operators::ConvParam& param, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/backends/arm/math/packed_sgemm_c4.h b/lite/backends/arm/math/packed_sgemm_c4.h index 0b88de36d7..21e5af6343 100644 --- a/lite/backends/arm/math/packed_sgemm_c4.h +++ b/lite/backends/arm/math/packed_sgemm_c4.h @@ -37,6 +37,16 @@ void sgemm_prepack_c4(int M, bool has_bias, bool has_relu, ARMContext* ctx); +void sgemm_prepack_c4_small(int M, + int N, + int K, + const float* A_packed, + const float* B, + float* C, + const float* bias, + bool has_bias, + bool has_relu, + ARMContext* ctx); } // namespace math } // namespace arm } // namespace lite diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index 799e8e2122..8fed33bf69 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -40,6 +40,7 @@ void ConvCompute::PrepareForRun() { int kw = w_dims[3]; int pad = paddings[0]; int stride = param.strides[0]; + int threads = ctx.threads(); bool pads_equal = ((paddings[0] == paddings[1]) && (paddings[2] == paddings[3])); @@ -67,7 +68,11 @@ void ConvCompute::PrepareForRun() { VLOG(3) << "invoking dw conv"; } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal && no_dilation) { - if (ic >= 32 && oc >= 32 && hout > 16 && wout > 16) { + bool use_winograd = + (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 && + pads_equal) || + (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal); + if (use_winograd) { /// winograd conv impl impl_ = new WinogradConv; VLOG(3) << "invoking winograd conv"; diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index d1b8d8a48e..d02cabf277 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -26,6 +26,7 @@ template <> void WinogradConv::ReInitWhenNeeded() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); + int threads = ctx.threads(); auto x_dims = param.x->dims(); auto w_dims = param.filter->dims(); @@ -36,77 +37,97 @@ void WinogradConv::ReInitWhenNeeded() { } int ic = x_dims[1]; - int ow = o_dims[3]; - int oh = o_dims[2]; + int ih = x_dims[2]; + int iw = x_dims[3]; int oc = o_dims[1]; - int tile_w = (ow + 5) / 6; - int tile_h = (oh + 5) / 6; - int size_tile = tile_h * tile_w; - int size_trans_channel = 8 * 8 * size_tile; - int max_ch = ic > oc ? ic : oc; - - const int n_wino = size_tile; - workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float); + int oh = o_dims[2]; + int ow = o_dims[3]; + int tile_block = 8; +#ifdef __aarch64__ + tile_block = 16; +#endif + int parallel_threads = + (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block; + if (threads <= 2 && parallel_threads >= threads) { + if (last_kernel_is_c4_ == 1) { + return; + } + last_kernel_is_c4_ = 1; + auto pad = *(param.paddings); + int pad_h = pad[0]; + int pad_w = pad[2]; + int oc_pad = (oc + 3) / 4 * 4; + int ic_pad = (ic + 3) / 4 * 4; + const int new_input_size = + (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2); + const int temp_size = + (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads; + ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float)); + + weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad}); + ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float)); + void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic); + auto weights_data_ = weights_.mutable_data(); + lite::arm::math::weight_trans_c4( + weights_data_, param.filter->data(), ic, oc, trans_tmp_ptr); + free(trans_tmp_ptr); + } else { + if (last_kernel_is_c4_ == 0) { + return; + } + last_kernel_is_c4_ = 0; + int tile_w = (ow + 5) / 6; + int tile_h = (oh + 5) / 6; + + int size_tile = tile_h * tile_w; + int size_trans_channel = 8 * 8 * size_tile; + int max_ch = ic > oc ? ic : oc; + + const int n_wino = size_tile; + ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) * + sizeof(float)); + + const int m_wino = oc; + int hblock = lite::arm::math::get_hblock(&ctx); + int m_round = hblock * ((m_wino + hblock - 1) / hblock); + weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic}); + ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) * + sizeof(float)); + auto weights_wino = + static_cast(malloc(sizeof(float) * 8 * 8 * oc * ic)); + void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic); + lite::arm::math::winograd_transform_weights( + weights_wino, param.filter->data(), oc, ic, trans_tmp_ptr); + auto weights_trans = weights_.mutable_data(); + for (int i = 0; i < 64; ++i) { + float* packed_weights = weights_trans + i * m_round * ic; + const float* weights_wino_ptr = weights_wino + i * oc * ic; + lite::arm::math::prepackA(packed_weights, + weights_wino_ptr, + 1.f, + ic, + 0, + m_wino, + 0, + ic, + false, + &ctx); + } + free(trans_tmp_ptr); + free(weights_wino); + } last_shape_ = x_dims; } template <> void WinogradConv::PrepareForRun() { - auto& param = this->Param(); - auto& ctx = this->ctx_->template As(); - - auto x_dims = param.x->dims(); - auto w_dims = param.filter->dims(); - auto o_dims = param.output->dims(); - last_shape_ = x_dims; - - int ic = x_dims[1]; - int ow = o_dims[3]; - int oh = o_dims[2]; - int oc = o_dims[1]; - int tile_w = (ow + 5) / 6; - int tile_h = (oh + 5) / 6; - int size_tile = tile_h * tile_w; - int size_trans_channel = 8 * 8 * size_tile; - int max_ch = ic > oc ? ic : oc; - - const int m_wino = oc; - const int n_wino = size_tile; - int hblock = lite::arm::math::get_hblock(&ctx); - int m_round = hblock * ((m_wino + hblock - 1) / hblock); - weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic}); - workspace_size_ = (size_trans_channel * max_ch * 2 + n_wino) * sizeof(float); - auto weights_wino = - static_cast(malloc(sizeof(float) * 8 * 8 * oc * ic)); - void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic); - lite::arm::math::winograd_transform_weights( - weights_wino, param.filter->data(), oc, ic, trans_tmp_ptr); - auto weights_trans = weights_.mutable_data(); - for (int i = 0; i < 64; ++i) { - float* packed_weights = weights_trans + i * m_round * ic; - const float* weights_wino_ptr = weights_wino + i * oc * ic; - lite::arm::math::prepackA(packed_weights, - weights_wino_ptr, - 1.f, - ic, - 0, - m_wino, - 0, - ic, - false, - &ctx); - } - free(trans_tmp_ptr); - free(weights_wino); + ReInitWhenNeeded(); } template <> void WinogradConv::Run() { auto& param = this->Param(); auto& ctx = this->ctx_->template As(); - // extend workspace - ctx.ExtendWorkspace(workspace_size_); - const auto* i_data = param.x->data(); const auto* w_data = weights_.data(); const auto* b_data = param.bias ? param.bias->data() : nullptr; @@ -124,8 +145,42 @@ void WinogradConv::Run() { int ow = o_dims[3]; int oc = o_dims[1]; - lite::arm::math::conv_winograd3x3( - i_data, o_data, bs, oc, oh, ow, ic, ih, iw, w_data, b_data, param, &ctx); + int tile_block = 8; +#ifdef __aarch64__ + tile_block = 16; +#endif + int threads = ctx.threads(); + int parallel_threads = + (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block; + if (threads <= 2 && parallel_threads >= threads) { + lite::arm::math::conv_compute_6x6_3x3(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } else { + lite::arm::math::conv_winograd3x3(i_data, + o_data, + bs, + oc, + oh, + ow, + ic, + ih, + iw, + w_data, + b_data, + param, + &ctx); + } } } // namespace arm diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h index 33f0edc017..40ea54b291 100644 --- a/lite/kernels/arm/conv_winograd.h +++ b/lite/kernels/arm/conv_winograd.h @@ -40,6 +40,7 @@ class WinogradConv : public KernelLite { Tensor weights_; DDim last_shape_; int workspace_size_{0}; + int last_kernel_is_c4_{-1}; }; } // namespace arm -- GitLab