Unverified  Commit 53570d38 authored by HappyAngel, committed by GitHub

[cherry-pick] fix xiaodu crash and profiler (#3906)

* [arm] add 2x2s2p1 pooling (#3705)

* fix pooling bug and improve speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop

* fix conflict in wino

* [arm] add 3x3s1 Winograd int8 (#3767)

* fix: winograd supports unequal padding
test=develop

* feat: add winograd int8 kernel
test=develop

* fix: style fix
test=develop

* fix winograd_int8 ut segment fault. test=develop

* close basic_test, test=develop
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>

* fix xiaodu crash in gemm prepacked

* on the huwen phone, 3x3s2p0 avg pooling crashes randomly; other phones do not show this issue

* [arm] update conv int8 kernel choice (#3834)

* fix conv int8 kernel choice and softmax compute bug

* change axis_size = 4 kernel choice, test=develop

* fix format. test=develop

* fix format. test=develop

* fix build test=develop

* fix build error test=develop

* fix wino_int8 compute error. test=develop

* Update the link to debug, test=develop, test=document_fix (#3870) (#3871)
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>
Co-authored-by: cc <52520497+juncaipeng@users.noreply.github.com>
Parent 1166948a
......@@ -49,4 +49,4 @@ $ ./opt \
## 5. Testing Tools
To help you better understand and use the Lite framework, we provide a [Debug Tool](debug#debug) and a [Profile Tool](debug#profiler) for users with further needs. The Lite Model Debug Tool can be used to check whether corresponding variable values in a model differ between the Lite framework and the PaddlePaddle framework during inference, so that the problematic Op can be located quickly and issues are easier to reproduce and troubleshoot. The Profile Monitor Tool helps you understand the execution time of each Op: it automatically collects how many times each Op runs and its longest, shortest, and average execution time, giving a baseline reference for performance tuning. See the [related topics](debug) for more details.
To help you better understand and use the Lite framework, we provide a [Debug Tool](debug) and a [Profile Tool](debug) for users with further needs. The Lite Model Debug Tool can be used to check whether corresponding variable values in a model differ between the Lite framework and the PaddlePaddle framework during inference, so that the problematic Op can be located quickly and issues are easier to reproduce and troubleshoot. The Profile Monitor Tool helps you understand the execution time of each Op: it automatically collects how many times each Op runs and its longest, shortest, and average execution time, giving a baseline reference for performance tuning. See the [related topics](debug) for more details.
......@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc
conv3x3_winograd_int8.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
......
......@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8(
for (int i = 0; i < ch_out * ch_in * 64; ++i) {
int new_c = i % 64;
int new_oc = i / ch_in / 64 / 4;
int new_ic = i / 64 % (ch_in * 4) % ch_in;
int new_ic = i / 64 % ch_in;
int new_inner = i / ch_in / 64 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4(
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 4;
int new_ic = i / 16 % (ch_in * 4) % ch_in;
int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
// F(2,3)
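// Winograd F(2, 3): each 2x2 output tile is computed from a 4x4 input tile.
// The int8 input is transformed to int16 tiles (input_trans_c8_4x4_int8),
// multiplied against the pre-transformed int16 weights by 16 small c8 GEMMs
// (one per transform coefficient, sgemm_prepack_c8_int16_small), accumulated
// in int32, and mapped back to 2x2 output tiles by
// output_trans_c8_post_2x4_int8. Channels are processed in blocks of 8.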
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx) {
auto act_param = param.activation_param;
const int pad_h0 = (*param.paddings)[0];
const int pad_h1 = (*param.paddings)[1];
const int pad_w0 = (*param.paddings)[2];
const int pad_w1 = (*param.paddings)[3];
int8_t* tmp_work_space =
ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t);
int in_n_stride = chin * hin * win;
int out_n_stride = chout * hout * wout;
int ic_stride = win * hin;
int oc_stride = wout * hout;
int ic_8 = (chin + 7) / 8;
int oc_8 = (chout + 7) / 8;
int tile_w = (wout + 1) / 2;
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w0 + pad_w1;
int h_pad = hin + pad_h0 + pad_h1;
const int zero_len = (w_pad + 3) / 4 * 4;
Dtype zero_ptr[zero_len]; // NOLINT
memset(zero_ptr, 0, zero_len * sizeof(Dtype));
int8_t* input_c8 = tmp_work_space;
int new_h_stride = w_pad * 8;
int new_c_stride = new_h_stride * h_pad;
int ic_8_stride = w_pad * h_pad * 8;
int oc_8_stride = wout * hout * 8;
int tile_block = 8;
int block_count = (size_tile + tile_block - 1) / tile_block;
int threads = ctx->threads();
int16_t* g_tmp_data =
(int16_t*)(tmp_work_space + ic_8 * ic_8_stride + // NOLINT
oc_8 * oc_8_stride * sizeof(int32_t));
int tmp_input_thread_stride = tile_block * ic_8 * 128;
int tmp_output_thread_stride = tile_block * oc_8 * 128;
int tmp_data_thread_stride_size = tmp_input_thread_stride * sizeof(int16_t) +
tmp_output_thread_stride * sizeof(int32_t);
memset(g_tmp_data, 0, tmp_data_thread_stride_size);
int8_t* g_trans_remain_tmp_data =
(int8_t*)(g_tmp_data + // NOLINT
threads * (tmp_input_thread_stride +
tmp_output_thread_stride * sizeof(int32_t) /
sizeof(int16_t)));
int32_t* g_trans_tmp_data =
(int32_t*)(g_trans_remain_tmp_data + threads * 128); // NOLINT
auto act_type = act_param.active_type;
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
// begin compute
for (int ni = 0; ni < num; ++ni) {
// trans input to c8
for (int i = 0; i < ic_8; ++i) {
prepack_input_nxwc8_int8_dw(input + ni * in_n_stride,
input_c8 + i * new_c_stride,
i * 8,
-pad_h0,
hin + pad_h1,
-pad_w0,
win + pad_w1,
chin,
win,
hin);
}
int32_t* output_c8 = (int32_t*)(input_c8 + ic_8 * ic_8_stride); // NOLINT
Dtype* output_ptr = output + ni * out_n_stride;
const int16_t* weight_ptr = weight;
#pragma omp parallel for num_threads(threads)
for (int tbi = 0; tbi < block_count; ++tbi) {
#ifdef ARM_WITH_OMP
int16_t* tmp_data =
g_tmp_data +
omp_get_thread_num() * tmp_data_thread_stride_size / sizeof(int16_t);
int32_t* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 32;
int8_t* trans_remain_tmp_data =
g_trans_remain_tmp_data + omp_get_thread_num() * 128;
#else
int16_t* tmp_data = g_tmp_data;
int32_t* trans_tmp_data = g_trans_tmp_data;
int8_t* trans_remain_tmp_data = g_trans_remain_tmp_data;
#endif
int tile_index = tbi * tile_block;
int tile_remain = size_tile - tile_index;
int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
// input trans
int c_gi_stride = tile_count * oc_8 * 8;
int b_gi_stride = tile_count * ic_8 * 8;
//*
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int src_x = tw_index + tw_index;
int src_y = th_index + th_index;
int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
int16_t* dst_ptr = tmp_data + ti * 8;
const int8_t* src_ptr = input_c8 + (src_y * w_pad + src_x) * 8;
if (ex == 4 && ey == 4) {
// trans input
for (int ci = 0; ci < ic_8; ++ci) {
const int8_t* src_ci = src_ptr + ci * ic_8_stride;
int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
input_trans_c8_4x4_int8(
src_ci, 8, w_pad * 8, dst_ci, b_gi_stride, b_gi_stride * 4);
}
} else {
// trans remain input
int x_size = ex;
for (int ci = 0; ci < ic_8; ++ci) {
const int8_t* src_ci = src_ptr + ci * ic_8_stride;
// pad
memset(trans_remain_tmp_data, 0, 128 * sizeof(int8_t));
if (x_size > 0) {
for (int yi = 0; yi < ey; ++yi) {
int8_t* dst_yi = trans_remain_tmp_data + yi * 32;
const int8_t* src_yi = src_ci + w_pad * yi * 8;
memcpy(dst_yi, src_yi, x_size * sizeof(int8_t) * 8);
}
}
// trans
int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
input_trans_c8_4x4_int8(trans_remain_tmp_data,
8,
32,
dst_ci,
b_gi_stride,
b_gi_stride * 4);
} // for ci_4
}
}
//*/
// input trans end
// *begin compute dot
// *
//*
int32_t* dst_temp_data =
(int32_t*)(tmp_data + tmp_input_thread_stride); // NOLINT
int16_t* b_ptr = tmp_data;
int w_gi_stride = ic_8 * oc_8 * 64;
for (int gi = 0; gi < 16; ++gi) {
int32_t* origin_C = dst_temp_data + gi * c_gi_stride;
int16_t* origin_B = b_ptr + gi * b_gi_stride;
const int16_t* origin_A = weight + gi * w_gi_stride;
sgemm_prepack_c8_int16_small(
oc_8 * 8, tile_count, ic_8 * 8, origin_A, origin_B, origin_C, ctx);
}
//*/
//*
// output trans
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int dst_x = tw_index * 2;
int dst_y = th_index * 2;
int ex = dst_x + 2 > wout ? wout - dst_x : 2;
int ey = dst_y + 2 > hout ? hout - dst_y : 2;
int32_t* src_ptr = dst_temp_data + ti * 8;
int32_t* trans_remain_tmp_i32_data =
(int32_t*)(trans_remain_tmp_data); // NOLINT
int32_t* dst_ptr = output_c8 + (dst_y * wout + dst_x) * 8;
if (ex == 2 && ey == 2) {
// trans output
for (int ci = 0; ci < oc_8; ++ci) {
int cur_ind = ci * 8;
int32_t* src_ci = src_ptr + ci * tile_count * 8;
int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
output_trans_c8_post_2x4_int8(
src_ci, c_gi_stride, c_gi_stride * 4, dst_ci, 8, wout * 8);
}
} else {
for (int ci = 0; ci < oc_8; ++ci) {
int cur_ind = ci * 8;
// trans output
int32_t* src_ci = src_ptr + ci * tile_count * 8;
output_trans_c8_post_2x4_int8(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_i32_data,
8,
16);
// copy to dest
int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
for (int i = 0; i < ey; ++i) {
memcpy(dst_ci + i * wout * 8,
trans_remain_tmp_i32_data + i * 16,
ex * sizeof(int32_t) * 8);
}
}
}
}
//*/
} // for block_count
const float* bias_local_ptr = bias;
for (int ci = 0; ci < oc_8; ++ci) {
float bias_local[8] = {bias_local_ptr[0],
bias_local_ptr[1],
bias_local_ptr[2],
bias_local_ptr[3],
bias_local_ptr[4],
bias_local_ptr[5],
bias_local_ptr[6],
bias_local_ptr[7]};
write_int32_nchwc8_to_nchw(output_c8 + ci * oc_8_stride,
output_ptr,
ci * 8,
ci * 8 + 8,
0,
hout,
0,
wout,
chout,
hout,
wout,
flag_act > 0,
bias_local,
param.bias,
zero_ptr,
scale + ci * 8);
bias_local_ptr += 8;
}
} // for num
} // conv compute
template void conv_compute_2x2_3x3_int8<int8_t>(
const int8_t* input,
int8_t* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template void conv_compute_2x2_3x3_int8<float>(
const int8_t* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
// BT=[1, 0, -1, 0,
// 0, 1, 1, 0,
// 0, -1, 1, 0,
// 0, 1, 0, -1]
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride) {
int8x8_t src00 = vld1_s8(src);
int8x8_t src01 = vld1_s8(src + src_stride);
int8x8_t src02 = vld1_s8(src + src_stride + src_stride);
int8x8_t src03 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src10 = vld1_s8(src);
int8x8_t src11 = vld1_s8(src + src_stride);
int8x8_t src12 = vld1_s8(src + src_stride + src_stride);
int8x8_t src13 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src20 = vld1_s8(src);
int8x8_t src21 = vld1_s8(src + src_stride);
int8x8_t src22 = vld1_s8(src + src_stride + src_stride);
int8x8_t src23 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src30 = vld1_s8(src);
int8x8_t src31 = vld1_s8(src + src_stride);
int8x8_t src32 = vld1_s8(src + src_stride + src_stride);
int8x8_t src33 = vld1_s8(src + src_stride + src_stride + src_stride);
int16x8_t dst00 = vsubl_s8(src00, src02);
int16x8_t dst10 = vaddl_s8(src01, src02);
int16x8_t dst20 = vsubl_s8(src02, src01);
int16x8_t dst30 = vsubl_s8(src01, src03);
int16x8_t dst01 = vsubl_s8(src10, src12);
int16x8_t dst11 = vaddl_s8(src11, src12);
int16x8_t dst21 = vsubl_s8(src12, src11);
int16x8_t dst31 = vsubl_s8(src11, src13);
int16x8_t dst02 = vsubl_s8(src20, src22);
int16x8_t dst12 = vaddl_s8(src21, src22);
int16x8_t dst22 = vsubl_s8(src22, src21);
int16x8_t dst32 = vsubl_s8(src21, src23);
int16x8_t dst03 = vsubl_s8(src30, src32);
int16x8_t dst13 = vaddl_s8(src31, src32);
int16x8_t dst23 = vsubl_s8(src32, src31);
int16x8_t dst33 = vsubl_s8(src31, src33);
int16x8_t dest00 = vsubq_s16(dst00, dst02);
int16x8_t dest10 = vaddq_s16(dst01, dst02);
int16x8_t dest20 = vsubq_s16(dst02, dst01);
int16x8_t dest30 = vsubq_s16(dst01, dst03);
int16x8_t dest01 = vsubq_s16(dst10, dst12);
int16x8_t dest11 = vaddq_s16(dst11, dst12);
int16x8_t dest21 = vsubq_s16(dst12, dst11);
int16x8_t dest31 = vsubq_s16(dst11, dst13);
int16x8_t dest02 = vsubq_s16(dst20, dst22);
int16x8_t dest12 = vaddq_s16(dst21, dst22);
int16x8_t dest22 = vsubq_s16(dst22, dst21);
int16x8_t dest32 = vsubq_s16(dst21, dst23);
int16x8_t dest03 = vsubq_s16(dst30, dst32);
int16x8_t dest13 = vaddq_s16(dst31, dst32);
int16x8_t dest23 = vsubq_s16(dst32, dst31);
int16x8_t dest33 = vsubq_s16(dst31, dst33);
vst1q_s16(dest, dest00);
vst1q_s16(dest + dest_stride, dest10);
vst1q_s16(dest + dest_stride + dest_stride, dest20);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest30);
dest += dest_h_stride;
vst1q_s16(dest, dest01);
vst1q_s16(dest + dest_stride, dest11);
vst1q_s16(dest + dest_stride + dest_stride, dest21);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest31);
dest += dest_h_stride;
vst1q_s16(dest, dest02);
vst1q_s16(dest + dest_stride, dest12);
vst1q_s16(dest + dest_stride + dest_stride, dest22);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest32);
dest += dest_h_stride;
vst1q_s16(dest, dest03);
vst1q_s16(dest + dest_stride, dest13);
vst1q_s16(dest + dest_stride + dest_stride, dest23);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest33);
}
// AT=[1, 1, 1, 0,
// 0, 1, -1, -1]
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride) {
int32x4_t src400 = vld1q_s32(src);
int32x4_t src800 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src401 = vld1q_s32(src);
int32x4_t src801 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src402 = vld1q_s32(src);
int32x4_t src802 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src403 = vld1q_s32(src);
int32x4_t src803 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src410 = vld1q_s32(src);
int32x4_t src810 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src411 = vld1q_s32(src);
int32x4_t src811 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src412 = vld1q_s32(src);
int32x4_t src812 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src413 = vld1q_s32(src);
int32x4_t src813 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src420 = vld1q_s32(src);
int32x4_t src820 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src421 = vld1q_s32(src);
int32x4_t src821 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src422 = vld1q_s32(src);
int32x4_t src822 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src423 = vld1q_s32(src);
int32x4_t src823 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src430 = vld1q_s32(src);
int32x4_t src830 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src431 = vld1q_s32(src);
int32x4_t src831 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src432 = vld1q_s32(src);
int32x4_t src832 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src433 = vld1q_s32(src);
int32x4_t src833 = vld1q_s32(src + 4);
int32x4_t dst400 = vaddq_s32(vaddq_s32(src400, src401), src402);
int32x4_t dst410 = vsubq_s32(vsubq_s32(src401, src402), src403);
int32x4_t dst401 = vaddq_s32(vaddq_s32(src410, src411), src412);
int32x4_t dst411 = vsubq_s32(vsubq_s32(src411, src412), src413);
int32x4_t dst402 = vaddq_s32(vaddq_s32(src420, src421), src422);
int32x4_t dst412 = vsubq_s32(vsubq_s32(src421, src422), src423);
int32x4_t dst403 = vaddq_s32(vaddq_s32(src430, src431), src432);
int32x4_t dst413 = vsubq_s32(vsubq_s32(src431, src432), src433);
int32x4_t dst800 = vaddq_s32(vaddq_s32(src800, src801), src802);
int32x4_t dst810 = vsubq_s32(vsubq_s32(src801, src802), src803);
int32x4_t dst801 = vaddq_s32(vaddq_s32(src810, src811), src812);
int32x4_t dst811 = vsubq_s32(vsubq_s32(src811, src812), src813);
int32x4_t dst802 = vaddq_s32(vaddq_s32(src820, src821), src822);
int32x4_t dst812 = vsubq_s32(vsubq_s32(src821, src822), src823);
int32x4_t dst803 = vaddq_s32(vaddq_s32(src830, src831), src832);
int32x4_t dst813 = vsubq_s32(vsubq_s32(src831, src832), src833);
int32x4_t dest400 = vaddq_s32(vaddq_s32(dst400, dst401), dst402);
int32x4_t dest410 = vsubq_s32(vsubq_s32(dst401, dst402), dst403);
int32x4_t dest401 = vaddq_s32(vaddq_s32(dst410, dst411), dst412);
int32x4_t dest411 = vsubq_s32(vsubq_s32(dst411, dst412), dst413);
int32x4_t dest800 = vaddq_s32(vaddq_s32(dst800, dst801), dst802);
int32x4_t dest810 = vsubq_s32(vsubq_s32(dst801, dst802), dst803);
int32x4_t dest801 = vaddq_s32(vaddq_s32(dst810, dst811), dst812);
int32x4_t dest811 = vsubq_s32(vsubq_s32(dst811, dst812), dst813);
vst1q_s32(dest, dest400);
vst1q_s32(dest + 4, dest800);
dest += dest_stride;
vst1q_s32(dest, dest410);
vst1q_s32(dest + 4, dest810);
dest += dest_h_stride - dest_stride;
vst1q_s32(dest, dest401);
vst1q_s32(dest + 4, dest801);
dest += dest_stride;
vst1q_s32(dest, dest411);
vst1q_s32(dest + 4, dest811);
}
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* din, int ch_in, int ch_out, void* workspace) {
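// coeff holds 2x the standard F(2, 3) weight-transform matrix
// G = [[1, 0, 0], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0, 0, 1]],
// scaled so the whole transform stays in int16 integer arithmetic.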
const int16_t coeff[4][3] = {{2, 0, 0}, {1, 1, 1}, {1, -1, 1}, {0, 0, 2}};
int16_t* ptr_out = static_cast<int16_t*>(workspace);
for (int i = 0; i < ch_out; i++) {
for (int j = 0; j < ch_in; j++) {
const int8_t* kernel0 =
static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
int16_t* ptr_channel = ptr_out + (i * ch_in + j) * 16;
//! transform kernel, transposed
const int8_t* k0 = kernel0;
const int8_t* k1 = kernel0 + 3;
const int8_t* k2 = kernel0 + 6;
//! h
int16_t tmp[4][3];
for (int i = 0; i < 4; i++) {
tmp[i][0] =
k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
tmp[i][1] =
k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
tmp[i][2] =
k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
}
//! v
for (int j = 0; j < 4; j++) {
int16_t* tmpp = &tmp[j][0];
for (int i = 0; i < 4; i++) {
ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
tmpp[1] * coeff[i][1] +
tmpp[2] * coeff[i][2];
}
}
}
}
int oc_pad = (ch_out + 7) / 8 * 8;
int ic_pad = (ch_in + 7) / 8 * 8;
int c_stride = ic_pad * oc_pad;
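// Repack the transformed weights from [oc][ic][16] order into
// [16][oc/8][ic][8]: one ic_pad * oc_pad plane per transform coefficient
// (c_stride), with 8 output channels interleaved innermost to match the
// c8 GEMM layout.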
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 8;
int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 8;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 8 + new_ic * 8 + new_inner;
dest[dest_ind] = ptr_out[i];
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
int w_stride = we - ws;
int valid_w = (we > width ? width : we) - ws;
int cnt = valid_w / 4;
int remain = valid_w & 3;
float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4);
......@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
w_bias1,
flag_relu);
}
if (we > width) {
if (remain > 0) {
int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset;
for (int j = ws + cnt * 4; j < width; ++j) {
for (int j = 0; j < remain; ++j) {
if (flag_bias) {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
......
......@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template <typename Dtype>
void im2col(const Dtype* data_im,
......
......@@ -1922,19 +1922,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
Dtype* tmp1 = nullptr;
Dtype* tmp2 = nullptr;
Dtype* tmp3 = nullptr;
float32_t scale_local[4];
float32_t scale_local[4] = {0, 0, 0, 0};
float32_t bias_local[4] = {0, 0, 0, 0};
if (is_bias) {
if (y + 4 <= M) {
bias_local[0] = bias[y];
bias_local[1] = bias[y + 1];
bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
} else {
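// Tail rows (M - y < 4): the switch falls through intentionally so only
// the valid bias entries are copied; the same pattern is used for scale
// below.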
switch (M - y) {
case 3:
bias_local[2] = bias[y + 2];
case 2:
bias_local[1] = bias[y + 1];
case 1:
bias_local[0] = bias[y + 0];
default:
break;
}
}
}
if (scale) {
if (y + 4 <= M) {
scale_local[0] = scale[y];
scale_local[1] = scale[y + 1];
scale_local[2] = scale[y + 2];
scale_local[3] = scale[y + 3];
} else {
switch (M - y) {
case 3:
scale_local[2] = scale[y + 2];
case 2:
scale_local[1] = scale[y + 1];
case 1:
scale_local[0] = scale[y + 0];
default:
break;
}
}
}
if (y + MBLOCK_INT8_OTH > M) {
switch (y + MBLOCK_INT8_OTH - M) {
......
......@@ -1679,6 +1679,912 @@ void sgemm_prepack_c4_small(int M,
}
}
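// int16 * int16 -> int32 GEMM used by the Winograd int8 kernel.
// A_packed is laid out in 8-row panels of length k_round (lda = 8 * k_round),
// B stores 8 interleaved rows per column block with a byte stride of
// ldb_byte = 8 * N * sizeof(int16_t) between K blocks, and C is written as
// 8-row int32 blocks. Columns of B are processed 8 / 4 / 1 at a time on
// aarch64 and 4 / 1 at a time on armv7.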
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx) {
const int m_round = (M + 7) / 8 * 8;
const int k_round = (K + 7) / 8 * 8;
const int mloop = m_round >> 3;
const int lda = 8 * k_round;
const int ldb_byte = 8 * N * sizeof(int16_t);
const int kcnt = k_round >> 3;
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int m = 0; m < mloop; ++m) {
const int16_t* b = B;
int n = N;
#ifdef __aarch64__
for (; n > 7; n -= 8) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smull v20.4s, v0.4h, v4.h[0] \n"
"smull v21.4s, v0.4h, v5.h[0] \n"
"smull v22.4s, v0.4h, v6.h[0] \n"
"smull v23.4s, v0.4h, v7.h[0] \n"
"ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3
"smull2 v24.4s, v0.8h, v4.h[0] \n"
"smull2 v25.4s, v0.8h, v5.h[0] \n"
"smull2 v26.4s, v0.8h, v6.h[0] \n"
"smull2 v27.4s, v0.8h, v7.h[0] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[1] \n"
"smlal v21.4s, v1.4h, v5.h[1] \n"
"smlal v22.4s, v1.4h, v6.h[1] \n"
"smlal v23.4s, v1.4h, v7.h[1] \n"
"smlal2 v24.4s, v1.8h, v4.h[1] \n"
"smlal2 v25.4s, v1.8h, v5.h[1] \n"
"smlal2 v26.4s, v1.8h, v6.h[1] \n"
"smlal2 v27.4s, v1.8h, v7.h[1] \n"
"smull v12.4s, v0.4h, v8.h[0] \n"
"smull v13.4s, v0.4h, v9.h[0] \n"
"smull v14.4s, v0.4h, v10.h[0] \n"
"smull v15.4s, v0.4h, v11.h[0] \n"
"smull2 v16.4s, v0.8h, v8.h[0] \n"
"smull2 v17.4s, v0.8h, v9.h[0] \n"
"smull2 v18.4s, v0.8h, v10.h[0] \n"
"smull2 v19.4s, v0.8h, v11.h[0] \n"
"smlal v12.4s, v1.4h, v8.h[1] \n"
"smlal v13.4s, v1.4h, v9.h[1] \n"
"smlal v14.4s, v1.4h, v10.h[1] \n"
"smlal v15.4s, v1.4h, v11.h[1] \n"
"smlal2 v16.4s, v1.8h, v8.h[1] \n"
"smlal2 v17.4s, v1.8h, v9.h[1] \n"
"smlal2 v18.4s, v1.8h, v10.h[1] \n"
"smlal2 v19.4s, v1.8h, v11.h[1] \n"
"smlal v20.4s, v2.4h, v4.h[2] \n"
"smlal v21.4s, v2.4h, v5.h[2] \n"
"smlal v22.4s, v2.4h, v6.h[2] \n"
"smlal v23.4s, v2.4h, v7.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[2] \n"
"smlal2 v25.4s, v2.8h, v5.h[2] \n"
"smlal2 v26.4s, v2.8h, v6.h[2] \n"
"smlal2 v27.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v2.4h, v8.h[2] \n"
"smlal v13.4s, v2.4h, v9.h[2] \n"
"smlal v14.4s, v2.4h, v10.h[2] \n"
"smlal v15.4s, v2.4h, v11.h[2] \n"
"smlal2 v16.4s, v2.8h, v8.h[2] \n"
"smlal2 v17.4s, v2.8h, v9.h[2] \n"
"smlal2 v18.4s, v2.8h, v10.h[2] \n"
"smlal2 v19.4s, v2.8h, v11.h[2] \n"
"smlal v20.4s, v3.4h, v4.h[3] \n"
"smlal v21.4s, v3.4h, v5.h[3] \n"
"smlal v22.4s, v3.4h, v6.h[3] \n"
"smlal v23.4s, v3.4h, v7.h[3] \n"
"smlal2 v24.4s, v3.8h, v4.h[3] \n"
"smlal2 v25.4s, v3.8h, v5.h[3] \n"
"smlal2 v26.4s, v3.8h, v6.h[3] \n"
"smlal2 v27.4s, v3.8h, v7.h[3] \n"
"smlal v12.4s, v3.4h, v8.h[3] \n"
"smlal v13.4s, v3.4h, v9.h[3] \n"
"smlal v14.4s, v3.4h, v10.h[3] \n"
"smlal v15.4s, v3.4h, v11.h[3] \n"
"smlal2 v16.4s, v3.8h, v8.h[3] \n"
"smlal2 v17.4s, v3.8h, v9.h[3] \n"
"smlal2 v18.4s, v3.8h, v10.h[3] \n"
"smlal2 v19.4s, v3.8h, v11.h[3] \n"
"smlal v20.4s, v0.4h, v4.h[4] \n"
"smlal v21.4s, v0.4h, v5.h[4] \n"
"smlal v22.4s, v0.4h, v6.h[4] \n"
"smlal v23.4s, v0.4h, v7.h[4] \n"
"smlal2 v24.4s, v0.8h, v4.h[4] \n"
"smlal2 v25.4s, v0.8h, v5.h[4] \n"
"smlal2 v26.4s, v0.8h, v6.h[4] \n"
"smlal2 v27.4s, v0.8h, v7.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[5] \n"
"smlal v21.4s, v1.4h, v5.h[5] \n"
"smlal v22.4s, v1.4h, v6.h[5] \n"
"smlal v23.4s, v1.4h, v7.h[5] \n"
"smlal2 v24.4s, v1.8h, v4.h[5] \n"
"smlal2 v25.4s, v1.8h, v5.h[5] \n"
"smlal2 v26.4s, v1.8h, v6.h[5] \n"
"smlal2 v27.4s, v1.8h, v7.h[5] \n"
"smlal v12.4s, v0.4h, v8.h[4] \n"
"smlal v13.4s, v0.4h, v9.h[4] \n"
"smlal v14.4s, v0.4h, v10.h[4] \n"
"smlal v15.4s, v0.4h, v11.h[4] \n"
"smlal2 v16.4s, v0.8h, v8.h[4] \n"
"smlal2 v17.4s, v0.8h, v9.h[4] \n"
"smlal2 v18.4s, v0.8h, v10.h[4] \n"
"smlal2 v19.4s, v0.8h, v11.h[4] \n"
"smlal v12.4s, v1.4h, v8.h[5] \n"
"smlal v13.4s, v1.4h, v9.h[5] \n"
"smlal v14.4s, v1.4h, v10.h[5] \n"
"smlal v15.4s, v1.4h, v11.h[5] \n"
"smlal2 v16.4s, v1.8h, v8.h[5] \n"
"smlal2 v17.4s, v1.8h, v9.h[5] \n"
"smlal2 v18.4s, v1.8h, v10.h[5] \n"
"smlal2 v19.4s, v1.8h, v11.h[5] \n"
"smlal v20.4s, v2.4h, v4.h[6] \n"
"smlal v21.4s, v2.4h, v5.h[6] \n"
"smlal v22.4s, v2.4h, v6.h[6] \n"
"smlal v23.4s, v2.4h, v7.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[6] \n"
"smlal2 v25.4s, v2.8h, v5.h[6] \n"
"smlal2 v26.4s, v2.8h, v6.h[6] \n"
"smlal2 v27.4s, v2.8h, v7.h[6] \n"
"sub %[b], %[b], #128 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v20.4s, v3.4h, v4.h[7] \n"
"smlal v21.4s, v3.4h, v5.h[7] \n"
"smlal v22.4s, v3.4h, v6.h[7] \n"
"smlal v23.4s, v3.4h, v7.h[7] \n"
"smlal2 v24.4s, v3.8h, v4.h[7] \n"
"smlal2 v25.4s, v3.8h, v5.h[7] \n"
"smlal2 v26.4s, v3.8h, v6.h[7] \n"
"smlal2 v27.4s, v3.8h, v7.h[7] \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smlal v12.4s, v2.4h, v8.h[6] \n"
"smlal v13.4s, v2.4h, v9.h[6] \n"
"smlal v14.4s, v2.4h, v10.h[6] \n"
"smlal v15.4s, v2.4h, v11.h[6] \n"
"smlal2 v16.4s, v2.8h, v8.h[6] \n"
"smlal2 v17.4s, v2.8h, v9.h[6] \n"
"smlal2 v18.4s, v2.8h, v10.h[6] \n"
"smlal2 v19.4s, v2.8h, v11.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"smlal v12.4s, v3.4h, v8.h[7] \n"
"smlal v13.4s, v3.4h, v9.h[7] \n"
"smlal v14.4s, v3.4h, v10.h[7] \n"
"smlal v15.4s, v3.4h, v11.h[7] \n"
"smlal2 v16.4s, v3.8h, v8.h[7] \n"
"smlal2 v17.4s, v3.8h, v9.h[7] \n"
"smlal2 v18.4s, v3.8h, v10.h[7] \n"
"smlal2 v19.4s, v3.8h, v11.h[7] \n"
"beq 2f \n"
"1:\n"
"smlal v20.4s, v0.4h, v4.h[0] \n"
"smlal v21.4s, v0.4h, v5.h[0] \n"
"smlal v22.4s, v0.4h, v6.h[0] \n"
"smlal v23.4s, v0.4h, v7.h[0] \n"
"ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3
"smlal2 v24.4s, v0.8h, v4.h[0] \n"
"smlal2 v25.4s, v0.8h, v5.h[0] \n"
"smlal2 v26.4s, v0.8h, v6.h[0] \n"
"smlal2 v27.4s, v0.8h, v7.h[0] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[1] \n"
"smlal v21.4s, v1.4h, v5.h[1] \n"
"smlal v22.4s, v1.4h, v6.h[1] \n"
"smlal v23.4s, v1.4h, v7.h[1] \n"
"smlal2 v24.4s, v1.8h, v4.h[1] \n"
"smlal2 v25.4s, v1.8h, v5.h[1] \n"
"smlal2 v26.4s, v1.8h, v6.h[1] \n"
"smlal2 v27.4s, v1.8h, v7.h[1] \n"
"smlal v12.4s, v0.4h, v8.h[0] \n"
"smlal v13.4s, v0.4h, v9.h[0] \n"
"smlal v14.4s, v0.4h, v10.h[0] \n"
"smlal v15.4s, v0.4h, v11.h[0] \n"
"smlal2 v16.4s, v0.8h, v8.h[0] \n"
"smlal2 v17.4s, v0.8h, v9.h[0] \n"
"smlal2 v18.4s, v0.8h, v10.h[0] \n"
"smlal2 v19.4s, v0.8h, v11.h[0] \n"
"smlal v12.4s, v1.4h, v8.h[1] \n"
"smlal v13.4s, v1.4h, v9.h[1] \n"
"smlal v14.4s, v1.4h, v10.h[1] \n"
"smlal v15.4s, v1.4h, v11.h[1] \n"
"smlal2 v16.4s, v1.8h, v8.h[1] \n"
"smlal2 v17.4s, v1.8h, v9.h[1] \n"
"smlal2 v18.4s, v1.8h, v10.h[1] \n"
"smlal2 v19.4s, v1.8h, v11.h[1] \n"
"smlal v20.4s, v2.4h, v4.h[2] \n"
"smlal v21.4s, v2.4h, v5.h[2] \n"
"smlal v22.4s, v2.4h, v6.h[2] \n"
"smlal v23.4s, v2.4h, v7.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[2] \n"
"smlal2 v25.4s, v2.8h, v5.h[2] \n"
"smlal2 v26.4s, v2.8h, v6.h[2] \n"
"smlal2 v27.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v2.4h, v8.h[2] \n"
"smlal v13.4s, v2.4h, v9.h[2] \n"
"smlal v14.4s, v2.4h, v10.h[2] \n"
"smlal v15.4s, v2.4h, v11.h[2] \n"
"smlal2 v16.4s, v2.8h, v8.h[2] \n"
"smlal2 v17.4s, v2.8h, v9.h[2] \n"
"smlal2 v18.4s, v2.8h, v10.h[2] \n"
"smlal2 v19.4s, v2.8h, v11.h[2] \n"
"smlal v20.4s, v3.4h, v4.h[3] \n"
"smlal v21.4s, v3.4h, v5.h[3] \n"
"smlal v22.4s, v3.4h, v6.h[3] \n"
"smlal v23.4s, v3.4h, v7.h[3] \n"
"smlal2 v24.4s, v3.8h, v4.h[3] \n"
"smlal2 v25.4s, v3.8h, v5.h[3] \n"
"smlal2 v26.4s, v3.8h, v6.h[3] \n"
"smlal2 v27.4s, v3.8h, v7.h[3] \n"
"smlal v12.4s, v3.4h, v8.h[3] \n"
"smlal v13.4s, v3.4h, v9.h[3] \n"
"smlal v14.4s, v3.4h, v10.h[3] \n"
"smlal v15.4s, v3.4h, v11.h[3] \n"
"smlal2 v16.4s, v3.8h, v8.h[3] \n"
"smlal2 v17.4s, v3.8h, v9.h[3] \n"
"smlal2 v18.4s, v3.8h, v10.h[3] \n"
"smlal2 v19.4s, v3.8h, v11.h[3] \n"
"smlal v20.4s, v0.4h, v4.h[4] \n"
"smlal v21.4s, v0.4h, v5.h[4] \n"
"smlal v22.4s, v0.4h, v6.h[4] \n"
"smlal v23.4s, v0.4h, v7.h[4] \n"
"smlal2 v24.4s, v0.8h, v4.h[4] \n"
"smlal2 v25.4s, v0.8h, v5.h[4] \n"
"smlal2 v26.4s, v0.8h, v6.h[4] \n"
"smlal2 v27.4s, v0.8h, v7.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[5] \n"
"smlal v21.4s, v1.4h, v5.h[5] \n"
"smlal v22.4s, v1.4h, v6.h[5] \n"
"smlal v23.4s, v1.4h, v7.h[5] \n"
"smlal2 v24.4s, v1.8h, v4.h[5] \n"
"smlal2 v25.4s, v1.8h, v5.h[5] \n"
"smlal2 v26.4s, v1.8h, v6.h[5] \n"
"smlal2 v27.4s, v1.8h, v7.h[5] \n"
"smlal v12.4s, v0.4h, v8.h[4] \n"
"smlal v13.4s, v0.4h, v9.h[4] \n"
"smlal v14.4s, v0.4h, v10.h[4] \n"
"smlal v15.4s, v0.4h, v11.h[4] \n"
"smlal2 v16.4s, v0.8h, v8.h[4] \n"
"smlal2 v17.4s, v0.8h, v9.h[4] \n"
"smlal2 v18.4s, v0.8h, v10.h[4] \n"
"smlal2 v19.4s, v0.8h, v11.h[4] \n"
"smlal v12.4s, v1.4h, v8.h[5] \n"
"smlal v13.4s, v1.4h, v9.h[5] \n"
"smlal v14.4s, v1.4h, v10.h[5] \n"
"smlal v15.4s, v1.4h, v11.h[5] \n"
"smlal2 v16.4s, v1.8h, v8.h[5] \n"
"smlal2 v17.4s, v1.8h, v9.h[5] \n"
"smlal2 v18.4s, v1.8h, v10.h[5] \n"
"smlal2 v19.4s, v1.8h, v11.h[5] \n"
"smlal v20.4s, v2.4h, v4.h[6] \n"
"smlal v21.4s, v2.4h, v5.h[6] \n"
"smlal v22.4s, v2.4h, v6.h[6] \n"
"smlal v23.4s, v2.4h, v7.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[6] \n"
"smlal2 v25.4s, v2.8h, v5.h[6] \n"
"smlal2 v26.4s, v2.8h, v6.h[6] \n"
"smlal2 v27.4s, v2.8h, v7.h[6] \n"
"sub %[b], %[b], #128 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v20.4s, v3.4h, v4.h[7] \n"
"smlal v21.4s, v3.4h, v5.h[7] \n"
"smlal v22.4s, v3.4h, v6.h[7] \n"
"smlal v23.4s, v3.4h, v7.h[7] \n"
"smlal2 v24.4s, v3.8h, v4.h[7] \n"
"smlal2 v25.4s, v3.8h, v5.h[7] \n"
"smlal2 v26.4s, v3.8h, v6.h[7] \n"
"smlal2 v27.4s, v3.8h, v7.h[7] \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smlal v12.4s, v2.4h, v8.h[6] \n"
"smlal v13.4s, v2.4h, v9.h[6] \n"
"smlal v14.4s, v2.4h, v10.h[6] \n"
"smlal v15.4s, v2.4h, v11.h[6] \n"
"smlal2 v16.4s, v2.8h, v8.h[6] \n"
"smlal2 v17.4s, v2.8h, v9.h[6] \n"
"smlal2 v18.4s, v2.8h, v10.h[6] \n"
"smlal2 v19.4s, v2.8h, v11.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"smlal v12.4s, v3.4h, v8.h[7] \n"
"smlal v13.4s, v3.4h, v9.h[7] \n"
"smlal v14.4s, v3.4h, v10.h[7] \n"
"smlal v15.4s, v3.4h, v11.h[7] \n"
"smlal2 v16.4s, v3.8h, v8.h[7] \n"
"smlal2 v17.4s, v3.8h, v9.h[7] \n"
"smlal2 v18.4s, v3.8h, v10.h[7] \n"
"smlal2 v19.4s, v3.8h, v11.h[7] \n"
"bne 1b \n"
"2: \n"
"stp q20, q24, [%[c]], #32 \n"
"stp q21, q25, [%[c]], #32 \n"
"stp q22, q26, [%[c]], #32 \n"
"stp q23, q27, [%[c]], #32 \n"
"stp q12, q16, [%[c]], #32 \n"
"stp q13, q17, [%[c]], #32 \n"
"stp q14, q18, [%[c]], #32 \n"
"stp q15, q19, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"
);
// clang-format on
b += 64;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smull v8.4s, v0.4h, v4.h[0] \n"
"smull v9.4s, v0.4h, v5.h[0] \n"
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n"
"smull2 v10.4s, v0.8h, v4.h[0] \n"
"smull2 v11.4s, v0.8h, v5.h[0] \n"
"smlal v8.4s, v1.4h, v4.h[1] \n"
"smlal v9.4s, v1.4h, v5.h[1] \n"
"smlal2 v10.4s, v1.8h, v4.h[1] \n"
"smlal2 v11.4s, v1.8h, v5.h[1] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smull v12.4s, v0.4h, v6.h[0] \n"
"smull v13.4s, v0.4h, v7.h[0] \n"
"smull2 v14.4s, v0.8h, v6.h[0] \n"
"smull2 v15.4s, v0.8h, v7.h[0] \n"
"smlal v12.4s, v1.4h, v6.h[1] \n"
"smlal v13.4s, v1.4h, v7.h[1] \n"
"smlal2 v14.4s, v1.8h, v6.h[1] \n"
"smlal2 v15.4s, v1.8h, v7.h[1] \n"
"smlal v8.4s, v2.4h, v4.h[2] \n"
"smlal v9.4s, v2.4h, v5.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[2] \n"
"smlal2 v11.4s, v2.8h, v5.h[2] \n"
"smlal v8.4s, v3.4h, v4.h[3] \n"
"smlal v9.4s, v3.4h, v5.h[3] \n"
"smlal2 v10.4s, v3.8h, v4.h[3] \n"
"smlal2 v11.4s, v3.8h, v5.h[3] \n"
"smlal v12.4s, v2.4h, v6.h[2] \n"
"smlal v13.4s, v2.4h, v7.h[2] \n"
"smlal2 v14.4s, v2.8h, v6.h[2] \n"
"smlal2 v15.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v3.4h, v6.h[3] \n"
"smlal v13.4s, v3.4h, v7.h[3] \n"
"smlal2 v14.4s, v3.8h, v6.h[3] \n"
"smlal2 v15.4s, v3.8h, v7.h[3] \n"
"smlal v8.4s, v0.4h, v4.h[4] \n"
"smlal v9.4s, v0.4h, v5.h[4] \n"
"smlal2 v10.4s, v0.8h, v4.h[4] \n"
"smlal2 v11.4s, v0.8h, v5.h[4] \n"
"smlal v8.4s, v1.4h, v4.h[5] \n"
"smlal v9.4s, v1.4h, v5.h[5] \n"
"smlal2 v10.4s, v1.8h, v4.h[5] \n"
"smlal2 v11.4s, v1.8h, v5.h[5] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v12.4s, v0.4h, v6.h[4] \n"
"smlal v13.4s, v0.4h, v7.h[4] \n"
"smlal2 v14.4s, v0.8h, v6.h[4] \n"
"smlal2 v15.4s, v0.8h, v7.h[4] \n"
"smlal v12.4s, v1.4h, v6.h[5] \n"
"smlal v13.4s, v1.4h, v7.h[5] \n"
"smlal2 v14.4s, v1.8h, v6.h[5] \n"
"smlal2 v15.4s, v1.8h, v7.h[5] \n"
"smlal v8.4s, v2.4h, v4.h[6] \n"
"smlal v9.4s, v2.4h, v5.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[6] \n"
"smlal2 v11.4s, v2.8h, v5.h[6] \n"
"smlal v8.4s, v3.4h, v4.h[7] \n"
"smlal v9.4s, v3.4h, v5.h[7] \n"
"smlal2 v10.4s, v3.8h, v4.h[7] \n"
"smlal2 v11.4s, v3.8h, v5.h[7] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v12.4s, v2.4h, v6.h[6] \n"
"smlal v13.4s, v2.4h, v7.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smlal2 v14.4s, v2.8h, v6.h[6] \n"
"smlal2 v15.4s, v2.8h, v7.h[6] \n"
"smlal v12.4s, v3.4h, v6.h[7] \n"
"smlal v13.4s, v3.4h, v7.h[7] \n"
"smlal2 v14.4s, v3.8h, v6.h[7] \n"
"smlal2 v15.4s, v3.8h, v7.h[7] \n"
"beq 2f \n"
"1: \n"
"smlal v8.4s, v0.4h, v4.h[0] \n"
"smlal v9.4s, v0.4h, v5.h[0] \n"
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n"
"smlal2 v10.4s, v0.8h, v4.h[0] \n"
"smlal2 v11.4s, v0.8h, v5.h[0] \n"
"smlal v8.4s, v1.4h, v4.h[1] \n"
"smlal v9.4s, v1.4h, v5.h[1] \n"
"smlal2 v10.4s, v1.8h, v4.h[1] \n"
"smlal2 v11.4s, v1.8h, v5.h[1] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v12.4s, v0.4h, v6.h[0] \n"
"smlal v13.4s, v0.4h, v7.h[0] \n"
"smlal2 v14.4s, v0.8h, v6.h[0] \n"
"smlal2 v15.4s, v0.8h, v7.h[0] \n"
"smlal v12.4s, v1.4h, v6.h[1] \n"
"smlal v13.4s, v1.4h, v7.h[1] \n"
"smlal2 v14.4s, v1.8h, v6.h[1] \n"
"smlal2 v15.4s, v1.8h, v7.h[1] \n"
"smlal v8.4s, v2.4h, v4.h[2] \n"
"smlal v9.4s, v2.4h, v5.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[2] \n"
"smlal2 v11.4s, v2.8h, v5.h[2] \n"
"smlal v8.4s, v3.4h, v4.h[3] \n"
"smlal v9.4s, v3.4h, v5.h[3] \n"
"smlal2 v10.4s, v3.8h, v4.h[3] \n"
"smlal2 v11.4s, v3.8h, v5.h[3] \n"
"smlal v12.4s, v2.4h, v6.h[2] \n"
"smlal v13.4s, v2.4h, v7.h[2] \n"
"smlal2 v14.4s, v2.8h, v6.h[2] \n"
"smlal2 v15.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v3.4h, v6.h[3] \n"
"smlal v13.4s, v3.4h, v7.h[3] \n"
"smlal2 v14.4s, v3.8h, v6.h[3] \n"
"smlal2 v15.4s, v3.8h, v7.h[3] \n"
"smlal v8.4s, v0.4h, v4.h[4] \n"
"smlal v9.4s, v0.4h, v5.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v0.8h, v4.h[4] \n"
"smlal2 v11.4s, v0.8h, v5.h[4] \n"
"smlal v8.4s, v1.4h, v4.h[5] \n"
"smlal v9.4s, v1.4h, v5.h[5] \n"
"smlal2 v10.4s, v1.8h, v4.h[5] \n"
"smlal2 v11.4s, v1.8h, v5.h[5] \n"
"smlal v12.4s, v0.4h, v6.h[4] \n"
"smlal v13.4s, v0.4h, v7.h[4] \n"
"smlal2 v14.4s, v0.8h, v6.h[4] \n"
"smlal2 v15.4s, v0.8h, v7.h[4] \n"
"smlal v12.4s, v1.4h, v6.h[5] \n"
"smlal v13.4s, v1.4h, v7.h[5] \n"
"smlal2 v14.4s, v1.8h, v6.h[5] \n"
"smlal2 v15.4s, v1.8h, v7.h[5] \n"
"smlal v8.4s, v2.4h, v4.h[6] \n"
"smlal v9.4s, v2.4h, v5.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[6] \n"
"smlal2 v11.4s, v2.8h, v5.h[6] \n"
"smlal v8.4s, v3.4h, v4.h[7] \n"
"smlal v9.4s, v3.4h, v5.h[7] \n"
"smlal2 v10.4s, v3.8h, v4.h[7] \n"
"smlal2 v11.4s, v3.8h, v5.h[7] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v12.4s, v2.4h, v6.h[6] \n"
"smlal v13.4s, v2.4h, v7.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smlal2 v14.4s, v2.8h, v6.h[6] \n"
"smlal2 v15.4s, v2.8h, v7.h[6] \n"
"smlal v12.4s, v3.4h, v6.h[7] \n"
"smlal v13.4s, v3.4h, v7.h[7] \n"
"smlal2 v14.4s, v3.8h, v6.h[7] \n"
"smlal2 v15.4s, v3.8h, v7.h[7] \n"
"bne 1b \n"
"2: \n"
"stp q8, q10, [%[c]], #32 \n"
"stp q9, q11, [%[c]], #32 \n"
"stp q12, q14, [%[c]], #32 \n"
"stp q13, q15, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4","v5", "v6", "v7", "v8", "v9",
"v10", "v11","v12", "v13", "v14", "v15", "cc", "memory"
);
// clang-format on
b += 32;
}
for (; n > 0; --n) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"ld1 {v4.8h}, [%[b]], #16 \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smull v5.4s, v0.4h, v4.h[0] \n"
"smull2 v6.4s, v0.8h, v4.h[0] \n"
"ld1 {v10.8h, v11.8h}, [%[a]], #32 \n"
"smlal v5.4s, v1.4h, v4.h[1] \n"
"smlal2 v6.4s, v1.8h, v4.h[1] \n"
"ld1 {v12.8h, v13.8h}, [%[a]], #32 \n"
"smlal v5.4s, v2.4h, v4.h[2] \n"
"smlal2 v6.4s, v2.8h, v4.h[2] \n"
"smlal v5.4s, v3.4h, v4.h[3] \n"
"smlal2 v6.4s, v3.8h, v4.h[3] \n"
"sub %[b], %[b], #16 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v5.4s, v10.4h, v4.h[4] \n"
"smlal2 v6.4s, v10.8h, v4.h[4] \n"
"smlal v5.4s, v11.4h, v4.h[5] \n"
"smlal2 v6.4s, v11.8h, v4.h[5] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal v5.4s, v12.4h, v4.h[6] \n"
"smlal2 v6.4s, v12.8h, v4.h[6] \n"
"smlal v5.4s, v13.4h, v4.h[7] \n"
"smlal2 v6.4s, v13.8h, v4.h[7] \n"
"beq 2f \n"
"1: \n"
"ld1 {v4.8h}, [%[b]], #16 \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v5.4s, v0.4h, v4.h[0] \n"
"smlal2 v6.4s, v0.8h, v4.h[0] \n"
"ld1 {v10.8h, v11.8h}, [%[a]], #32 \n"
"smlal v5.4s, v1.4h, v4.h[1] \n"
"smlal2 v6.4s, v1.8h, v4.h[1] \n"
"ld1 {v12.8h, v13.8h}, [%[a]], #32 \n"
"smlal v5.4s, v2.4h, v4.h[2] \n"
"smlal2 v6.4s, v2.8h, v4.h[2] \n"
"smlal v5.4s, v3.4h, v4.h[3] \n"
"smlal2 v6.4s, v3.8h, v4.h[3] \n"
"sub %[b], %[b], #16 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v5.4s, v10.4h, v4.h[4] \n"
"smlal2 v6.4s, v10.8h, v4.h[4] \n"
"smlal v5.4s, v11.4h, v4.h[5] \n"
"smlal2 v6.4s, v11.8h, v4.h[5] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal v5.4s, v12.4h, v4.h[6] \n"
"smlal2 v6.4s, v12.8h, v4.h[6] \n"
"smlal v5.4s, v13.4h, v4.h[7] \n"
"smlal2 v6.4s, v13.8h, v4.h[7] \n"
"bne 1b \n"
"2: \n"
"st1 {v5.4s, v6.4s}, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4","v5", "v6", "cc", "memory"
);
// clang-format on
b += 8;
}
#else
for (; n > 3; n -= 4) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile (
"vld1.16 {d0-d3}, [%[b]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vld1.16 {d4-d7}, [%[b]]! \n"
"vmull.s16 q8, d8, d0[0] \n"
"vmull.s16 q9, d8, d2[0] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmull.s16 q10, d9, d0[0] \n"
"vmull.s16 q11, d9, d2[0] \n"
"vmlal.s16 q8, d10, d0[1] \n"
"vmlal.s16 q9, d10, d2[1] \n"
"vmlal.s16 q10, d11, d0[1] \n"
"vmlal.s16 q11, d11, d2[1] \n"
"vmull.s16 q12, d8, d4[0] \n"
"vmull.s16 q13, d8, d6[0] \n"
"vmull.s16 q14, d9, d4[0] \n"
"vmull.s16 q15, d9, d6[0] \n"
"vmlal.s16 q12, d10, d4[1] \n"
"vmlal.s16 q13, d10, d6[1] \n"
"vmlal.s16 q14, d11, d4[1] \n"
"vmlal.s16 q15, d11, d6[1] \n"
"vmlal.s16 q8, d12, d0[2] \n"
"vmlal.s16 q9, d12, d2[2] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q10, d13, d0[2] \n"
"vmlal.s16 q11, d13, d2[2] \n"
"vmlal.s16 q8, d14, d0[3] \n"
"vmlal.s16 q9, d14, d2[3] \n"
"vmlal.s16 q10, d15, d0[3] \n"
"vmlal.s16 q11, d15, d2[3] \n"
"vmlal.s16 q12, d12, d4[2] \n"
"vmlal.s16 q13, d12, d6[2] \n"
"vmlal.s16 q14, d13, d4[2] \n"
"vmlal.s16 q15, d13, d6[2] \n"
"vmlal.s16 q12, d14, d4[3] \n"
"vmlal.s16 q13, d14, d6[3] \n"
"vmlal.s16 q14, d15, d4[3] \n"
"vmlal.s16 q15, d15, d6[3] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[0] \n"
"vmlal.s16 q9, d8, d3[0] \n"
"vmlal.s16 q10, d9, d1[0] \n"
"vmlal.s16 q11, d9, d3[0] \n"
"vmlal.s16 q8, d10, d1[1] \n"
"vmlal.s16 q9, d10, d3[1] \n"
"vmlal.s16 q10, d11, d1[1] \n"
"vmlal.s16 q11, d11, d3[1] \n"
"vmlal.s16 q8, d12, d1[2] \n"
"vmlal.s16 q9, d12, d3[2] \n"
"vmlal.s16 q10, d13, d1[2] \n"
"vmlal.s16 q11, d13, d3[2] \n"
"vmlal.s16 q8, d14, d1[3] \n"
"vmlal.s16 q9, d14, d3[3] \n"
"vmlal.s16 q10, d15, d1[3] \n"
"vmlal.s16 q11, d15, d3[3] \n"
"vld1.16 {d0-d3}, [%[b]]! \n"
"vmlal.s16 q12, d8, d5[0] \n"
"vmlal.s16 q13, d8, d7[0] \n"
"vmlal.s16 q14, d9, d5[0] \n"
"vmlal.s16 q15, d9, d7[0] \n"
"vmlal.s16 q12, d10, d5[1] \n"
"vmlal.s16 q13, d10, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmlal.s16 q14, d11, d5[1] \n"
"vmlal.s16 q15, d11, d7[1] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q12, d12, d5[2] \n"
"vmlal.s16 q13, d12, d7[2] \n"
"vmlal.s16 q14, d13, d5[2] \n"
"vmlal.s16 q15, d13, d7[2] \n"
"vmlal.s16 q12, d14, d5[3] \n"
"vmlal.s16 q13, d14, d7[3] \n"
"vmlal.s16 q14, d15, d5[3] \n"
"vmlal.s16 q15, d15, d7[3] \n"
"beq 2f \n"
"1: \n"
"vld1.16 {d4-d7}, [%[b]]! \n"
"vmlal.s16 q8, d8, d0[0] \n"
"vmlal.s16 q9, d8, d2[0] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q10, d9, d0[0] \n"
"vmlal.s16 q11, d9, d2[0] \n"
"vmlal.s16 q8, d10, d0[1] \n"
"vmlal.s16 q9, d10, d2[1] \n"
"vmlal.s16 q10, d11, d0[1] \n"
"vmlal.s16 q11, d11, d2[1] \n"
"vmlal.s16 q12, d8, d4[0] \n"
"vmlal.s16 q13, d8, d6[0] \n"
"vmlal.s16 q14, d9, d4[0] \n"
"vmlal.s16 q15, d9, d6[0] \n"
"vmlal.s16 q12, d10, d4[1] \n"
"vmlal.s16 q13, d10, d6[1] \n"
"vmlal.s16 q14, d11, d4[1] \n"
"vmlal.s16 q15, d11, d6[1] \n"
"vmlal.s16 q8, d12, d0[2] \n"
"vmlal.s16 q9, d12, d2[2] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q10, d13, d0[2] \n"
"vmlal.s16 q11, d13, d2[2] \n"
"vmlal.s16 q8, d14, d0[3] \n"
"vmlal.s16 q9, d14, d2[3] \n"
"vmlal.s16 q10, d15, d0[3] \n"
"vmlal.s16 q11, d15, d2[3] \n"
"vmlal.s16 q12, d12, d4[2] \n"
"vmlal.s16 q13, d12, d6[2] \n"
"vmlal.s16 q14, d13, d4[2] \n"
"vmlal.s16 q15, d13, d6[2] \n"
"vmlal.s16 q12, d14, d4[3] \n"
"vmlal.s16 q13, d14, d6[3] \n"
"vmlal.s16 q14, d15, d4[3] \n"
"vmlal.s16 q15, d15, d6[3] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[0] \n"
"vmlal.s16 q9, d8, d3[0] \n"
"vmlal.s16 q10, d9, d1[0] \n"
"vmlal.s16 q11, d9, d3[0] \n"
"vmlal.s16 q8, d10, d1[1] \n"
"vmlal.s16 q9, d10, d3[1] \n"
"vmlal.s16 q10, d11, d1[1] \n"
"vmlal.s16 q11, d11, d3[1] \n"
"vmlal.s16 q8, d12, d1[2] \n"
"vmlal.s16 q9, d12, d3[2] \n"
"vmlal.s16 q10, d13, d1[2] \n"
"vmlal.s16 q11, d13, d3[2] \n"
"vmlal.s16 q8, d14, d1[3] \n"
"vmlal.s16 q9, d14, d3[3] \n"
"vmlal.s16 q10, d15, d1[3] \n"
"vmlal.s16 q11, d15, d3[3] \n"
"vld1.16 {d0-d3}, [%[b]]! \n"
"vmlal.s16 q12, d8, d5[0] \n"
"vmlal.s16 q13, d8, d7[0] \n"
"vmlal.s16 q14, d9, d5[0] \n"
"vmlal.s16 q15, d9, d7[0] \n"
"vmlal.s16 q12, d10, d5[1] \n"
"vmlal.s16 q13, d10, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmlal.s16 q14, d11, d5[1] \n"
"vmlal.s16 q15, d11, d7[1] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q12, d12, d5[2] \n"
"vmlal.s16 q13, d12, d7[2] \n"
"vmlal.s16 q14, d13, d5[2] \n"
"vmlal.s16 q15, d13, d7[2] \n"
"vmlal.s16 q12, d14, d5[3] \n"
"vmlal.s16 q13, d14, d7[3] \n"
"vmlal.s16 q14, d15, d5[3] \n"
"vmlal.s16 q15, d15, d7[3] \n"
"bne 1b \n"
"2: \n"
"vst1.32 {d16-d17}, [%[c]]! \n"
"vst1.32 {d20-d21}, [%[c]]! \n"
"vst1.32 {d18-d19}, [%[c]]! \n"
"vst1.32 {d22-d23}, [%[c]]! \n"
"vst1.32 {d24-d25}, [%[c]]! \n"
"vst1.32 {d28-d29}, [%[c]]! \n"
"vst1.32 {d26-d27}, [%[c]]! \n"
"vst1.32 {d30-d31}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory"
);
// clang-format on
b += 32;
}
for (; n > 0; --n) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile (
"vld1.16 {d0-d1}, [%[b]]! \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmull.s16 q8, d4, d0[0] \n"
"vmull.s16 q9, d5, d0[0] \n"
"sub %[b], %[b], #16 \n"
"vmlal.s16 q8, d6, d0[1] \n"
"vmlal.s16 q9, d7, d0[1] \n"
"add %[b], %[b], %[ldb] \n"
"subs %[cnt], %[cnt], #1 \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d0[2] \n"
"vmlal.s16 q9, d9, d0[2] \n"
"vmlal.s16 q8, d10, d0[3] \n"
"vmlal.s16 q9, d11, d0[3] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d1[0] \n"
"vmlal.s16 q9, d5, d1[0] \n"
"vmlal.s16 q8, d6, d1[1] \n"
"vmlal.s16 q9, d7, d1[1] \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[2] \n"
"vmlal.s16 q9, d9, d1[2] \n"
"vmlal.s16 q8, d10, d1[3] \n"
"vmlal.s16 q9, d11, d1[3] \n"
"beq 2f \n"
"1:\n"
"vld1.16 {d0-d1}, [%[b]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d0[0] \n"
"vmlal.s16 q9, d5, d0[0] \n"
"sub %[b], %[b], #16 \n"
"vmlal.s16 q8, d6, d0[1] \n"
"vmlal.s16 q9, d7, d0[1] \n"
"add %[b], %[b], %[ldb] \n"
"subs %[cnt], %[cnt], #1 \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d0[2] \n"
"vmlal.s16 q9, d9, d0[2] \n"
"vmlal.s16 q8, d10, d0[3] \n"
"vmlal.s16 q9, d11, d0[3] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d1[0] \n"
"vmlal.s16 q9, d5, d1[0] \n"
"vmlal.s16 q8, d6, d1[1] \n"
"vmlal.s16 q9, d7, d1[1] \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[2] \n"
"vmlal.s16 q9, d9, d1[2] \n"
"vmlal.s16 q8, d10, d1[3] \n"
"vmlal.s16 q9, d11, d1[3] \n"
"bne 1b \n"
"2: \n"
"vst1.32 {d16-d19}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8",
"q9", "cc", "memory"
);
// clang-format on
b += 8;
}
#endif
A_packed += lda;
}
}
void sgemm_prepack_c4(int M,
int N,
int K,
......
......@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M,
const float* B,
float* C,
ARMContext* ctx);
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -206,6 +206,20 @@ void pooling_basic(const float* din,
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/
#define P2x2S2P1_MAX \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fmax v4.4s, v0.4s, v6.4s\n" /* max */ \
"fmax v5.4s, v2.4s, v8.4s\n" /* max */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fmax v6.4s, v4.4s, v5.4s\n" /* max reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_MAX \
"1: \n" \
"fmax v4.4s, v0.4s, v1.4s\n" /* max */ \
......@@ -217,6 +231,21 @@ void pooling_basic(const float* din,
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P2x2S2P1_AVG \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fadd v4.4s, v0.4s, v6.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"fadd v5.4s, v2.4s, v8.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fadd v6.4s, v4.4s, v5.4s\n" /* add reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"fmul v4.4s, v6.4s, %[vcoef_left].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_AVG \
"1: \n" /* load bias to q2, q3*/ \
"fadd v4.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
......@@ -228,6 +257,7 @@ void pooling_basic(const float* din,
"fmul v4.4s, v6.4s, %[vcoef].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P3x3S1_INIT \
"ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \
"ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \
......@@ -518,16 +548,45 @@ void pooling_basic(const float* din,
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n"
#define P2x2S2P1_MAX \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vmax.f32 q8, q0, q4 @ max \n" \
"vmax.f32 q9, q2, q5 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vmax.f32 q5, q9, q8 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vst1.f32 {d10-d11}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne \n"
#define P2x2S2P0_MAX \
"1: @ main loop\n" \
"vmax.f32 q4, q0, q1 @ max \n" \
"vmax.f32 q5, q2, q3 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vmax.f32 q6, q4, q5 @ max reduce\n" \
"vmax.f32 q8, q4, q5 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vst1.f32 {d12-d13}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne "
"vst1.f32 {d16-d17}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne \n"
#define P2x2S2P1_AVG \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vadd.f32 q9, q0, q4 @ max \n" \
"vadd.f32 q8, q2, q5 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vadd.f32 q5, q9, q8 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vmul.f32 q4, q5, %q[vcoef_left] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne\n"
#define P2x2S2P0_AVG \
"1: @ main loop\n" \
......@@ -535,9 +594,9 @@ void pooling_basic(const float* din,
"vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \
"vadd.f32 q6, q4, q5 @ add reduce \n" \
"vadd.f32 q8, q4, q5 @ add reduce \n" \
"subs %[cnt_num], #1 @ subs \n" \
"vmul.f32 q4, q6, %q[vcoef] @ mul coef \n" \
"vmul.f32 q4, q8, %q[vcoef] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne \n"
......@@ -1037,7 +1096,7 @@ void pooling1x1s2p0_max(const float* din,
TargetFree(TARGET(kARM), write_ptr);
}
void pooling2x2s2_max(const float* din,
void pooling2x2s2p0_max(const float* din,
float* dout,
int num,
int chout,
......@@ -1095,7 +1154,7 @@ void pooling2x2s2_max(const float* din,
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
:
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif
dr0 -= 8;
dr1 -= 8;
......@@ -1121,7 +1180,7 @@ void pooling2x2s2_max(const float* din,
}
}
void pooling2x2s2_avg(const float* din,
void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
......@@ -1158,12 +1217,14 @@ void pooling2x2s2_avg(const float* din,
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
vcoef = vdupq_n_f32(0.25f);
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
if (h * S + K - P > hin) {
dr1 = zero_ptr;
vcoef = vdupq_n_f32(0.5f);
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
......@@ -1184,7 +1245,7 @@ void pooling2x2s2_avg(const float* din,
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif
dr0 -= 8;
dr1 -= 8;
......@@ -1194,8 +1255,14 @@ void pooling2x2s2_avg(const float* din,
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float coef = 0.5f / (wend - wstart);
float coef = 0.25f;
float tmp = 0.f;
if (wend - wstart == 1 && pad_right == 0) {
coef *= 2;
}
if (h * S + K - P > hin && pad_bottom == 0) {
coef *= 2;
}
for (int i = wstart; i < wend; i++) {
tmp += dr0[i] + dr1[i];
}
......@@ -1212,6 +1279,235 @@ void pooling2x2s2_avg(const float* din,
TargetFree(TARGET(kARM), zero_ptr);
}
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
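// With padding = 1, the first output column only covers one valid input
// column: P2x2S2P1_MAX "ext"-shifts one lane of vzero in front of the
// odd-indexed elements and then rewinds dr0/dr1 by 4 bytes so the P0 loop
// continues with the correct offset. vzero is the lowest float so the
// padded lane can never win the max.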
float32x4_t vzero = vdupq_n_f32(std::numeric_limits<float>::lowest());
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
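// Keep at least one group of 4 outputs for the scalar tail: with p1 padding
// the rightmost windows may read past the input row, which the vectorized
// loop cannot handle.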
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
if (h == 0) {
dr0 = r0;
dr1 = r0;
r0 = r1;
r1 = r0 + win;
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = dr0;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = wend == st ? 0.f : dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
}
*(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
}
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
float32x4_t vzero = vdupq_n_f32(0.f);
memset(zero_ptr, 0, win * sizeof(float));
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
float coef_h = 0.5f;
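// coef_h is the vertical averaging factor: 0.5 when both rows are valid,
// 1.0 when a padded row is excluded from the divisor (exclusive mode)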
if (h == 0) {
dr0 = zero_ptr;
dr1 = r0;
r0 = r1;
r1 = r0 + win;
if (exclusive) {
coef_h = 1.f;
}
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = zero_ptr;
if (exclusive) {
coef_h = 1.f;
}
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
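// the left-most window covers only one valid column because of the left pad,
// so lane 0 of vcoef_left carries its own coefficient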
float coef_left_most = exclusive ? coef_h : coef_h / 2;
float32x4_t vcoef = vdupq_n_f32(coef_h / 2);
float coef_left[4] = {
coef_left_most, coef_h / 2, coef_h / 2, coef_h / 2};
float32x4_t vcoef_left = vld1q_f32(coef_left);
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = 0.f;
float coef = coef_h / 2;
if (exclusive && wend - st == 1) {
coef = coef_h;
}
for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i];
}
*(dr_out++) = tmp * coef;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
TargetFree(TARGET(kARM), zero_ptr);
}
void pooling3x3s1p1_max(const float* din,
float* dout,
int num,
......@@ -2240,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din,
w_unroll_remian = wout - w_unroll_size * 4;
}
int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
......@@ -2266,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din,
}
}
int cnt_num = w_unroll_size;
int cnt_remain = remain;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
......@@ -2289,12 +2589,53 @@ void pooling3x3s2p0_max(const float* din,
"v9",
"v10",
"v11");
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
wstart += S;
}
#else
asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
asm volatile(
P3x3S2P0_INIT P3x3S2P0_MAX
"cmp %[remain], #0 @cmp cnt_num\n"
"sub %[dr0], #32 @sub - 8\n"
"sub %[dr1], #32 @sub - 8\n"
"sub %[dr2], #32 @sub - 8\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load \n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load \n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load \n"
"vmov.f32 s3,s2 @mov \n"
"vmov.f32 s7,s6 @mov \n"
"vmov.f32 s11,s10 @mov \n"
"vmax.f32 q0, q0, q1 @max n"
"sub %[dr0], #8 @add w \n"
"sub %[dr1], #8 @add w \n"
"sub %[dr2], #8 @add w \n"
"vmax.f32 q0, q0, q2 @max \n"
"vpmax.f32 d0, d0, d1 @pmax \n"
"vpmax.f32 d0, d0, d0 @pmax \n"
"subs %[remain], #1 @subs \n"
"vst1.f32 d0[0], [%[dr_out]]! @vst \n"
"bne 2b @bne \n"
"4: @exit\n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr2] "+r"(dr2),
[dr_out] "+r"(dr_out),
[remain] "+r"(cnt_remain),
[cnt_num] "+r"(cnt_num)
:
: "cc",
......@@ -2311,24 +2652,17 @@ void pooling3x3s2p0_max(const float* din,
"q9",
"q10",
"q11");
#endif
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
}
// deal with right pad
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
if (right) {
int wstart = (w_unroll_size * 4 + remain) * S;
int wend = std::min(wstart + K, win);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
wstart += S;
}
#endif
}
r0 = r2;
......@@ -2368,6 +2702,9 @@ void pooling3x3s2p0_avg(const float* din,
w_unroll_remian = wout - w_unroll_size * 4;
}
// do overflow process
w_unroll_size -= 1;
w_unroll_remian += 4;
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float));
......
......@@ -76,7 +76,7 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom,
int pad_right);
void pooling2x2s2_max(const float* din,
void pooling2x2s2p0_max(const float* din,
float* dout,
int num,
int chout,
......@@ -88,7 +88,32 @@ void pooling2x2s2_max(const float* din,
int pad_bottom,
int pad_right);
void pooling2x2s2_avg(const float* din,
void pooling2x2s2p0_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
......
......@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din,
}
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) {
for (j = 4 * nn; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++;
}
......@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din,
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) {
for (j = 4 * nn; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0];
din_sum_ptr++;
......
......@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0};
args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>();
args.image.channels = input->shape().channel();
args.image.height = input->shape().height();
args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0];
args.image.pad_width = param_.paddings[1];
args.image.pad_height = paddings[0];
args.image.pad_width = paddings[2];
args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale();
......@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// &&
// (k_width > 7 || k_height > 7);
// && (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE;
......@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input);
float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height();
int image_width = input->shape().width();
int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0];
int image_pad_w = param_.paddings[1];
int image_pad_h = paddings[0];
int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0];
......
......@@ -71,6 +71,9 @@ void ConcatCompute::Run() {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
if (axis < 0) {
axis += inputs[0]->dims().size();
}
switch (inputs.front()->precision()) {
case PRECISION(kFloat):
......
......@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation) {
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
......@@ -122,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) {
} else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run WinogradConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run GemmLikeConvInt8";
......@@ -169,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) {
} else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run WinogradConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run GemmLikeConvInt8";
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "lite/kernels/arm/conv_winograd.h"
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h"
......@@ -166,6 +165,189 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
int threads = ctx.threads();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
//! update workspace size
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3];
int oc = o_dims[1];
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
auto pad = *(param.paddings);
int pad_h0 = pad[0];
int pad_h1 = pad[1];
int pad_w0 = pad[2];
int pad_w1 = pad[3];
int oc_pad = (oc + 7) / 8 * 8;
int ic_pad = (ic + 7) / 8 * 8;
const int new_input_size =
ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
oc_pad * oh * ow * sizeof(int32_t);
int tmp_input_thread_size_byte =
tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t);
int tmp_output_thread_size_byte =
tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t);
const int temp_size =
(tmp_input_thread_size_byte + tmp_output_thread_size_byte +
wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) *
threads;
workspace_size_ = temp_size + new_input_size;
//! update trans weights impl
// choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
// we only support 2x2 now
choose_small_ = true;
float w_fact = 0.25;
if (choose_small_) {
wino_iw = 4;
if (last_function_ == 0) {
return;
}
last_function_ = 0;
} else {
wino_iw = 6;
if (last_function_ == 1) {
return;
}
last_function_ = 1;
}
/// update scale
for (auto& ws : w_scale_) {
ws *= w_fact;
}
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<int16_t>();
if (!choose_small_) {
} else {
lite::arm::math::weight_trans_c8_4x4_int8(
weights_data_,
param.filter->template data<int8_t>(),
ic,
oc,
trans_tmp_ptr);
}
free(trans_tmp_ptr);
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::PrepareForRun() {
auto& param = this->Param<param_t>();
w_scale_ = param.weight_scale;
if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
LOG(FATAL) << "weights scale size must equal to filter size";
return;
}
if (w_scale_.size() == 1) {
for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
w_scale_.push_back(w_scale_[0]);
}
}
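  // fold the input scale into the per-channel weight scales so the int32
  // accumulators can be dequantized with one multiplier per output channel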
float input_scale = param.input_scale;
for (auto& ws : w_scale_) {
ws *= input_scale;
}
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
auto ptr_in = param.bias->template data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i];
}
}
if (OutType == PRECISION(kInt8)) {
float output_scale = param.output_scale;
for (auto& ws : w_scale_) {
ws /= output_scale;
}
if (param.bias) {
auto ptr = bias_.mutable_data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] /= output_scale;
}
}
}
ReInitWhenNeeded();
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
const auto* i_data = param.x->template data<int8_t>();
const auto* w_data = weights_.data<int16_t>();
const auto* b_data = param.bias ? bias_.data<float>() : nullptr;
// const float* i_data;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
// now always choose small
if (OutType == PRECISION(kInt8)) {
auto* o_data = param.output->template mutable_data<int8_t>();
lite::arm::math::conv_compute_2x2_3x3_int8<int8_t>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
} else {
auto* o_data = param.output->template mutable_data<float>();
lite::arm::math::conv_compute_2x2_3x3_int8<float>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
}
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_compute_2x2_3x3_int8";
#endif
}
template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -15,11 +15,12 @@
#pragma once
#include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -44,7 +45,34 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
bool choose_small_{false};
int wino_iw{8};
};
template <PrecisionType OutType>
class WinogradConv<PRECISION(kInt8), OutType>
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
WinogradConv() = default;
~WinogradConv() {}
virtual void PrepareForRun();
virtual void ReInitWhenNeeded();
virtual void Run();
#ifdef LITE_WITH_PROFILE
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvWino"};
#endif
protected:
using param_t = operators::ConvParam;
Tensor weights_;
Tensor bias_;
DDim last_shape_;
int workspace_size_{0};
int last_function_{-1};
bool choose_small_{true};
int wino_iw{4};
std::vector<float> w_scale_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0;
......@@ -107,7 +108,7 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din,
lite::arm::math::pooling2x2s2p0_max(din,
dout,
out_dims[0],
out_dims[1],
......@@ -120,7 +121,7 @@ void PoolCompute::Run() {
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din,
lite::arm::math::pooling2x2s2p0_avg(din,
dout,
out_dims[0],
out_dims[1],
......@@ -134,8 +135,38 @@ void PoolCompute::Run() {
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din,
dout,
......@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din,
dout,
......@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din,
dout,
......@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) {
pads_equal && kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din,
dout,
......
......@@ -34,7 +34,7 @@ void SoftmaxCompute::Run() {
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int axis_size = x_dims[axis];
if (inner_num == 1) {
if (axis_size >= 4) {
if (axis_size > 4) {
lite::arm::math::softmax_inner1_large_axis(
din, dout, outer_num, axis_size);
} else {
......
......@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
......@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in,
......@@ -165,7 +166,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias);
}
if (flag_relu) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_relu; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_relu) {
param_fp32_out.fuse_relu = true;
param_int8_out.fuse_relu = true;
}
param_fp32_out.activation_param = act_param;
param_int8_out.activation_param = act_param;
}
std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
std::vector<float> scale_w(weight_dim[0], 1.f / 127);
......@@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(dims,
weights_dim,
1,
......
......@@ -179,6 +179,141 @@ bool test_sgemm_c4(
#endif
return true;
}
bool test_sgemm_c8(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
int m_round = (m + 7) / 8 * 8;
int k_round = (k + 7) / 8 * 8;
int size_a = m * k;
int size_b = n * k;
int size_a_c4 = m_round * k_round;
int size_b_c8 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c8;
Tensor tc;
Tensor tc_basic;
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c8.Resize({size_b_c8});
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
ta.set_precision(PRECISION(kInt16));
tb.set_precision(PRECISION(kInt16));
ta_c4.set_precision(PRECISION(kInt16));
tb_c8.set_precision(PRECISION(kInt16));
tc.set_precision(PRECISION(kInt32));
tc_basic.set_precision(PRECISION(kInt32));
tbias.set_precision(PRECISION(kInt32));
fill_tensor_rand(ta);
fill_tensor_rand(tb);
fill_tensor_rand(tbias);
fill_tensor_rand(tc);
auto da = ta.mutable_data<int16_t>();
auto db = tb.mutable_data<int16_t>();
auto da_c4 = ta_c4.mutable_data<int16_t>();
auto db_c8 = tb_c8.mutable_data<int16_t>();
auto dc_basic = tc_basic.mutable_data<int32_t>();
auto dbias = tbias.mutable_data<int32_t>();
  // trans A, B to c8
basic_trans_mat_to_c8(da, da_c4, k, m, k, true);
basic_trans_mat_to_c8(db, db_c8, n, k, n, false);
LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c8(false,
false,
m,
n,
k,
1,
da,
k,
db,
n,
0,
dc_basic,
n,
dbias,
false,
false);
}
Timer t0;
LOG(INFO) << "basic test end";
#ifdef LITE_WITH_ARM
//! compute
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<int32_t>();
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
}
LOG(INFO) << "basic test end";
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
t0.Stop();
}
LOG(INFO) << "basic test end";
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.LapTimes().Avg()
<< " ms, min time: " << t0.LapTimes().Min()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kInt32));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c8: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c8: ";
print_tensor(tb_c8);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) {
......@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) {
for (auto& k : {1, 3, 8, 59, 234}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false}) {
for (auto& has_relu : {false}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
......@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
}
}
}
TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false}) {
for (auto& has_relu : {false}) {
for (auto& th : {1}) {
auto flag = test_sgemm_c8(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
......@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
flag = test_sgemm_c8(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!";
......
......@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input,
}
}
}
template <typename type>
static void basic_trans_mat_to_c8(const type* input,
type* output,
const int ldin,
const int M,
const int K,
bool pack_k) {
const int m_round = (M + 7) / 8 * 8;
int k_round = (K + 7) / 8 * 8;
if (!pack_k) {
k_round = K;
}
const int m_loop = m_round / 8;
type zero_buf[K];
memset(zero_buf, 0, K * sizeof(type));
for (int i = 0; i < m_loop; ++i) {
const type* in0 = input + i * 8 * ldin;
const type* in1 = in0 + ldin;
const type* in2 = in1 + ldin;
const type* in3 = in2 + ldin;
const type* in4 = in3 + ldin;
const type* in5 = in4 + ldin;
const type* in6 = in5 + ldin;
const type* in7 = in6 + ldin;
if (8 * (i + 1) - M > 0) {
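      // rows beyond M fall through the cases below on purpose, redirecting
      // every missing row to the shared zero buffer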
switch (8 * (i + 1) - M) {
case 7:
in1 = zero_buf;
case 6:
in2 = zero_buf;
case 5:
in3 = zero_buf;
case 4:
in4 = zero_buf;
case 3:
in5 = zero_buf;
case 2:
in6 = zero_buf;
case 1:
in7 = zero_buf;
default:
break;
}
}
for (int j = 0; j < K; ++j) {
*output++ = *in0++;
*output++ = *in1++;
*output++ = *in2++;
*output++ = *in3++;
*output++ = *in4++;
*output++ = *in5++;
*output++ = *in6++;
*output++ = *in7++;
}
for (int j = K; j < k_round; ++j) {
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
}
}
}
template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a,
......@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a,
free(tmp_c);
}
template <typename type, typename type2>
static void basic_gemm_c8(bool trans_a,
bool trans_b,
int m,
int n,
int k,
type2 alpha,
const type* a,
int lda,
const type* b,
int ldb,
type2 beta,
type2* c,
int ldc,
const type2* bias,
bool flag_bias = false,
bool flag_relu = false) {
type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
memset(tmp_c, 0, m * ldc * sizeof(type2));
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
auto sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (trans_b) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
if (flag_relu) {
tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
tmp_c[i * ldc + j] = tmp;
}
}
}
  //! trans c to c8
basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false);
free(tmp_c);
}
template <typename type, typename type2>
static void basic_gemm(bool trans_a,
bool trans_b,
......
......@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break;
case PRECISION(kInt16):
fill_tensor_host_const_impl(
tensor.mutable_data<int16_t>(), static_cast<int16_t>(value), size);
break;
case PRECISION(kInt32):
fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size);
......@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
}
}
template <>
void fill_tensor_host_rand_impl<int16_t>(int16_t* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = (rand() % 256 - 128) * 2; // NOLINT
}
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) {
for (int64_t i = 0; i < size; ++i) {
......@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT
case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break;
case PRECISION(kInt16):
fill_tensor_host_rand_impl(tensor.mutable_data<int16_t>(), size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break;
......
......@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth);
return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth);
return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4;
w_out = dstw * 4;
......