Unverified commit 53570d38, authored by HappyAngel, committed by GitHub

[cherry-pick] fix xiaodu crash and profiler (#3906)

* [arm]add 2x2s2p1 pooling  (#3705)

* fix pooling bug and speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop

* fix conflict in wino

* [arm] add 3x3s1 Winograd int8 (#3767)

* fix: winograd support unsame pad
test=develop

* feat: add winograd int8 kernel
test=develop

* fix: style fix
test=develop

* fix winograd_int8 ut segment fault. test=develop

* close basic_test, test=develop
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>

* fix xiaodu crash in gemm prepacked

* on the huwen phone, 3x3s2p0 avg pooling crashes randomly; other phones do not show this issue

* [arm] update conv int8 kernel choose (#3834)

* fix conv int8 kernel selection and softmax compute bug

* change kernel selection for axis_size = 4, test=develop

* fix format. test=develop

* fix format. test=develop

* fix build test=develop

* fix build error test=develop

* fix wino_int8 compute error. test=develop

* Update the link to debug, test=develop, test=document_fix (#3870) (#3871)
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>
Co-authored-by: cc <52520497+juncaipeng@users.noreply.github.com>
Parent 1166948a
...@@ -49,4 +49,4 @@ $ ./opt \
## 5. Testing Tools
To help you understand and use the Lite framework better, we provide the [Debug Tool](debug) and the [Profile Tool](debug) for users with further needs. The Lite Model Debug Tool checks whether corresponding variable values differ between the Lite framework and the PaddlePaddle framework during inference, which helps quickly locate the problematic Op and makes issues easier to reproduce and troubleshoot. The Profile Monitor Tool reports the execution time of each Op; it automatically collects the number of executions and the longest, shortest, and average execution times as a baseline for performance tuning. See the [related topics](debug) for more details.
...@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc
conv3x3_winograd_int8.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
......
...@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8(
for (int i = 0; i < ch_out * ch_in * 64; ++i) {
int new_c = i % 64;
int new_oc = i / ch_in / 64 / 4;
-    int new_ic = i / 64 % (ch_in * 4) % ch_in;
+    int new_ic = i / 64 % ch_in;
int new_inner = i / ch_in / 64 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
...@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4(
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 4;
-    int new_ic = i / 16 % (ch_in * 4) % ch_in;
+    int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......
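The old and new index expressions above are equivalent, since `(x % (ch_in * 4)) % ch_in == x % ch_in` whenever the outer modulus is a multiple of the inner one; the change simply drops the redundant step. For reference, a minimal sketch of the remapping these loops perform (hypothetical helper; the destination layout [64][oc_pad/4][ic_pad][4] is inferred from the stride arithmetic, not stated in the patch):

// Map a flat index over [oc][ic][64] transformed weights to the c4-blocked
// destination index used by weight_trans_c4_8x8 (assumed layout).
inline int wino_c4_dest_index(int i, int ch_in, int ic_pad, int c_stride) {
  int new_c = i % 64;                  // coordinate inside the 8x8 tile
  int new_oc = i / ch_in / 64 / 4;     // output-channel block of 4
  int new_ic = i / 64 % ch_in;         // input channel
  int new_inner = i / ch_in / 64 % 4;  // lane inside the block of 4
  return new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
}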
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
#include <arm_neon.h>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
// F(2,3)
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx) {
auto act_param = param.activation_param;
const int pad_h0 = (*param.paddings)[0];
const int pad_h1 = (*param.paddings)[1];
const int pad_w0 = (*param.paddings)[2];
const int pad_w1 = (*param.paddings)[3];
int8_t* tmp_work_space =
ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t);
int in_n_stride = chin * hin * win;
int out_n_stride = chout * hout * wout;
int ic_stride = win * hin;
int oc_stride = wout * hout;
int ic_8 = (chin + 7) / 8;
int oc_8 = (chout + 7) / 8;
int tile_w = (wout + 1) / 2;
int tile_h = (hout + 1) / 2;
int size_tile = tile_h * tile_w;
int w_pad = win + pad_w0 + pad_w1;
int h_pad = hin + pad_h0 + pad_h1;
const int zero_len = (w_pad + 3) / 4 * 4;
Dtype zero_ptr[zero_len]; // NOLINT
memset(zero_ptr, 0, zero_len * sizeof(Dtype));
int8_t* input_c8 = tmp_work_space;
int new_h_stride = w_pad * 8;
int new_c_stride = new_h_stride * h_pad;
int ic_8_stride = w_pad * h_pad * 8;
int oc_8_stride = wout * hout * 8;
int tile_block = 8;
int block_count = (size_tile + tile_block - 1) / tile_block;
int threads = ctx->threads();
int16_t* g_tmp_data =
(int16_t*)(tmp_work_space + ic_8 * ic_8_stride + // NOLINT
oc_8 * oc_8_stride * sizeof(int32_t));
int tmp_input_thread_stride = tile_block * ic_8 * 128;
int tmp_output_thread_stride = tile_block * oc_8 * 128;
int tmp_data_thread_stride_size = tmp_input_thread_stride * sizeof(int16_t) +
tmp_output_thread_stride * sizeof(int32_t);
memset(g_tmp_data, 0, tmp_data_thread_stride_size);
int8_t* g_trans_remain_tmp_data =
(int8_t*)(g_tmp_data + // NOLINT
threads * (tmp_input_thread_stride +
tmp_output_thread_stride * sizeof(int32_t) /
sizeof(int16_t)));
int32_t* g_trans_tmp_data =
(int32_t*)(g_trans_remain_tmp_data + threads * 128); // NOLINT
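  // Workspace layout carved out of ctx->workspace_data (contiguous regions):
  // [input_c8: ic_8 * ic_8_stride int8] [output_c8: oc_8 * oc_8_stride int32]
  // [g_tmp_data: per thread, tile_block * ic_8 * 128 int16 gemm inputs followed
  //  by tile_block * oc_8 * 128 int32 gemm outputs]
  // [g_trans_remain_tmp_data: 128 int8 per thread]
  // [g_trans_tmp_data: 32 int32 per thread].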
auto act_type = act_param.active_type;
int flag_act = 0; // relu: 1, relu6: 2, leaky relu: 3
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
if (act_param.has_active) {
if (act_type == lite_api::ActivationType::kRelu) {
flag_act = 1;
} else if (act_type == lite_api::ActivationType::kRelu6) {
flag_act = 2;
float local_alpha = act_param.Relu_clipped_coef;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
} else if (act_type == lite_api::ActivationType::kLeakyRelu) {
flag_act = 3;
float local_alpha = act_param.Leaky_relu_alpha;
alpha[0] = local_alpha;
alpha[1] = local_alpha;
alpha[2] = local_alpha;
alpha[3] = local_alpha;
}
}
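  // Overall flow per batch: (1) pack the int8 input into NC8HW8 with padding
  // applied (prepack_input_nxwc8_int8_dw); (2) for each block of up to
  // tile_block tiles, run the 4x4 Winograd input transform, then 16 c8 int16
  // GEMMs (one per transform coordinate) producing int32 results, then the
  // 2x2 output transform into the NC8HW8 int32 buffer; (3) finally
  // write_int32_nchwc8_to_nchw converts back to NCHW and applies bias,
  // per-channel scale and the activation selected above.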
// begin compute
for (int ni = 0; ni < num; ++ni) {
// trans input to c8
for (int i = 0; i < ic_8; ++i) {
prepack_input_nxwc8_int8_dw(input + ni * in_n_stride,
input_c8 + i * new_c_stride,
i * 8,
-pad_h0,
hin + pad_h1,
-pad_w0,
win + pad_w1,
chin,
win,
hin);
}
int32_t* output_c8 = (int32_t*)(input_c8 + ic_8 * ic_8_stride); // NOLINT
Dtype* output_ptr = output + ni * out_n_stride;
const int16_t* weight_ptr = weight;
#pragma omp parallel for num_threads(threads)
for (int tbi = 0; tbi < block_count; ++tbi) {
#ifdef ARM_WITH_OMP
int16_t* tmp_data =
g_tmp_data +
omp_get_thread_num() * tmp_data_thread_stride_size / sizeof(int16_t);
int32_t* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 32;
int8_t* trans_remain_tmp_data =
g_trans_remain_tmp_data + omp_get_thread_num() * 128;
#else
int16_t* tmp_data = g_tmp_data;
int32_t* trans_tmp_data = g_trans_tmp_data;
int8_t* trans_remain_tmp_data = g_trans_remain_tmp_data;
#endif
int tile_index = tbi * tile_block;
int tile_remain = size_tile - tile_index;
int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
// input trans
int c_gi_stride = tile_count * oc_8 * 8;
int b_gi_stride = tile_count * ic_8 * 8;
//*
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int src_x = tw_index + tw_index;
int src_y = th_index + th_index;
int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
int16_t* dst_ptr = tmp_data + ti * 8;
const int8_t* src_ptr = input_c8 + (src_y * w_pad + src_x) * 8;
if (ex == 4 && ey == 4) {
// trans input
for (int ci = 0; ci < ic_8; ++ci) {
const int8_t* src_ci = src_ptr + ci * ic_8_stride;
int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
input_trans_c8_4x4_int8(
src_ci, 8, w_pad * 8, dst_ci, b_gi_stride, b_gi_stride * 4);
}
} else {
// trans remain input
int x_size = ex;
for (int ci = 0; ci < ic_8; ++ci) {
const int8_t* src_ci = src_ptr + ci * ic_8_stride;
// pad
memset(trans_remain_tmp_data, 0, 128 * sizeof(int8_t));
if (x_size > 0) {
for (int yi = 0; yi < ey; ++yi) {
int8_t* dst_yi = trans_remain_tmp_data + yi * 32;
const int8_t* src_yi = src_ci + w_pad * yi * 8;
memcpy(dst_yi, src_yi, x_size * sizeof(int8_t) * 8);
}
}
// trans
int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
input_trans_c8_4x4_int8(trans_remain_tmp_data,
8,
32,
dst_ci,
b_gi_stride,
b_gi_stride * 4);
} // for ci_4
}
}
//*/
// input trans end
// *begin compute dot
// *
//*
int32_t* dst_temp_data =
(int32_t*)(tmp_data + tmp_input_thread_stride); // NOLINT
int16_t* b_ptr = tmp_data;
int w_gi_stride = ic_8 * oc_8 * 64;
for (int gi = 0; gi < 16; ++gi) {
int32_t* origin_C = dst_temp_data + gi * c_gi_stride;
int16_t* origin_B = b_ptr + gi * b_gi_stride;
const int16_t* origin_A = weight + gi * w_gi_stride;
sgemm_prepack_c8_int16_small(
oc_8 * 8, tile_count, ic_8 * 8, origin_A, origin_B, origin_C, ctx);
}
//*/
//*
// output trans
for (int ti = 0; ti < tile_count; ++ti) {
int index = tile_index + ti;
int tw_index = index % tile_w;
int th_index = index / tile_w;
int dst_x = tw_index * 2;
int dst_y = th_index * 2;
int ex = dst_x + 2 > wout ? wout - dst_x : 2;
int ey = dst_y + 2 > hout ? hout - dst_y : 2;
int32_t* src_ptr = dst_temp_data + ti * 8;
int32_t* trans_remain_tmp_i32_data =
(int32_t*)(trans_remain_tmp_data); // NOLINT
int32_t* dst_ptr = output_c8 + (dst_y * wout + dst_x) * 8;
if (ex == 2 && ey == 2) {
// trans output
for (int ci = 0; ci < oc_8; ++ci) {
int cur_ind = ci * 8;
int32_t* src_ci = src_ptr + ci * tile_count * 8;
int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
output_trans_c8_post_2x4_int8(
src_ci, c_gi_stride, c_gi_stride * 4, dst_ci, 8, wout * 8);
}
} else {
for (int ci = 0; ci < oc_8; ++ci) {
int cur_ind = ci * 8;
// trans output
int32_t* src_ci = src_ptr + ci * tile_count * 8;
output_trans_c8_post_2x4_int8(src_ci,
c_gi_stride,
c_gi_stride * 4,
trans_remain_tmp_i32_data,
8,
16);
// copy to dest
int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
for (int i = 0; i < ey; ++i) {
memcpy(dst_ci + i * wout * 8,
trans_remain_tmp_i32_data + i * 16,
ex * sizeof(int32_t) * 8);
}
}
}
}
//*/
} // for block_count
const float* bias_local_ptr = bias;
for (int ci = 0; ci < oc_8; ++ci) {
float bias_local[8] = {bias_local_ptr[0],
bias_local_ptr[1],
bias_local_ptr[2],
bias_local_ptr[3],
bias_local_ptr[4],
bias_local_ptr[5],
bias_local_ptr[6],
bias_local_ptr[7]};
write_int32_nchwc8_to_nchw(output_c8 + ci * oc_8_stride,
output_ptr,
ci * 8,
ci * 8 + 8,
0,
hout,
0,
wout,
chout,
hout,
wout,
flag_act > 0,
bias_local,
param.bias,
zero_ptr,
scale + ci * 8);
bias_local_ptr += 8;
}
} // for num
} // conv compute
template void conv_compute_2x2_3x3_int8<int8_t>(
const int8_t* input,
int8_t* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template void conv_compute_2x2_3x3_int8<float>(
const int8_t* input,
float* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
// BT=[1, 0, -1, 0,
// 0, 1, 1, 0,
// 0, -1, 1, 0,
// 0, 1, 0, -1]
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride) {
int8x8_t src00 = vld1_s8(src);
int8x8_t src01 = vld1_s8(src + src_stride);
int8x8_t src02 = vld1_s8(src + src_stride + src_stride);
int8x8_t src03 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src10 = vld1_s8(src);
int8x8_t src11 = vld1_s8(src + src_stride);
int8x8_t src12 = vld1_s8(src + src_stride + src_stride);
int8x8_t src13 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src20 = vld1_s8(src);
int8x8_t src21 = vld1_s8(src + src_stride);
int8x8_t src22 = vld1_s8(src + src_stride + src_stride);
int8x8_t src23 = vld1_s8(src + src_stride + src_stride + src_stride);
src += src_h_stride;
int8x8_t src30 = vld1_s8(src);
int8x8_t src31 = vld1_s8(src + src_stride);
int8x8_t src32 = vld1_s8(src + src_stride + src_stride);
int8x8_t src33 = vld1_s8(src + src_stride + src_stride + src_stride);
int16x8_t dst00 = vsubl_s8(src00, src02);
int16x8_t dst10 = vaddl_s8(src01, src02);
int16x8_t dst20 = vsubl_s8(src02, src01);
int16x8_t dst30 = vsubl_s8(src01, src03);
int16x8_t dst01 = vsubl_s8(src10, src12);
int16x8_t dst11 = vaddl_s8(src11, src12);
int16x8_t dst21 = vsubl_s8(src12, src11);
int16x8_t dst31 = vsubl_s8(src11, src13);
int16x8_t dst02 = vsubl_s8(src20, src22);
int16x8_t dst12 = vaddl_s8(src21, src22);
int16x8_t dst22 = vsubl_s8(src22, src21);
int16x8_t dst32 = vsubl_s8(src21, src23);
int16x8_t dst03 = vsubl_s8(src30, src32);
int16x8_t dst13 = vaddl_s8(src31, src32);
int16x8_t dst23 = vsubl_s8(src32, src31);
int16x8_t dst33 = vsubl_s8(src31, src33);
int16x8_t dest00 = vsubq_s16(dst00, dst02);
int16x8_t dest10 = vaddq_s16(dst01, dst02);
int16x8_t dest20 = vsubq_s16(dst02, dst01);
int16x8_t dest30 = vsubq_s16(dst01, dst03);
int16x8_t dest01 = vsubq_s16(dst10, dst12);
int16x8_t dest11 = vaddq_s16(dst11, dst12);
int16x8_t dest21 = vsubq_s16(dst12, dst11);
int16x8_t dest31 = vsubq_s16(dst11, dst13);
int16x8_t dest02 = vsubq_s16(dst20, dst22);
int16x8_t dest12 = vaddq_s16(dst21, dst22);
int16x8_t dest22 = vsubq_s16(dst22, dst21);
int16x8_t dest32 = vsubq_s16(dst21, dst23);
int16x8_t dest03 = vsubq_s16(dst30, dst32);
int16x8_t dest13 = vaddq_s16(dst31, dst32);
int16x8_t dest23 = vsubq_s16(dst32, dst31);
int16x8_t dest33 = vsubq_s16(dst31, dst33);
vst1q_s16(dest, dest00);
vst1q_s16(dest + dest_stride, dest10);
vst1q_s16(dest + dest_stride + dest_stride, dest20);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest30);
dest += dest_h_stride;
vst1q_s16(dest, dest01);
vst1q_s16(dest + dest_stride, dest11);
vst1q_s16(dest + dest_stride + dest_stride, dest21);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest31);
dest += dest_h_stride;
vst1q_s16(dest, dest02);
vst1q_s16(dest + dest_stride, dest12);
vst1q_s16(dest + dest_stride + dest_stride, dest22);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest32);
dest += dest_h_stride;
vst1q_s16(dest, dest03);
vst1q_s16(dest + dest_stride, dest13);
vst1q_s16(dest + dest_stride + dest_stride, dest23);
vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest33);
}
// AT=[1, 1, 1, 0,
// 0, 1, -1, -1]
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride) {
int32x4_t src400 = vld1q_s32(src);
int32x4_t src800 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src401 = vld1q_s32(src);
int32x4_t src801 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src402 = vld1q_s32(src);
int32x4_t src802 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src403 = vld1q_s32(src);
int32x4_t src803 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src410 = vld1q_s32(src);
int32x4_t src810 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src411 = vld1q_s32(src);
int32x4_t src811 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src412 = vld1q_s32(src);
int32x4_t src812 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src413 = vld1q_s32(src);
int32x4_t src813 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src420 = vld1q_s32(src);
int32x4_t src820 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src421 = vld1q_s32(src);
int32x4_t src821 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src422 = vld1q_s32(src);
int32x4_t src822 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src423 = vld1q_s32(src);
int32x4_t src823 = vld1q_s32(src + 4);
src += src_h_stride - 3 * src_stride;
int32x4_t src430 = vld1q_s32(src);
int32x4_t src830 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src431 = vld1q_s32(src);
int32x4_t src831 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src432 = vld1q_s32(src);
int32x4_t src832 = vld1q_s32(src + 4);
src += src_stride;
int32x4_t src433 = vld1q_s32(src);
int32x4_t src833 = vld1q_s32(src + 4);
int32x4_t dst400 = vaddq_s32(vaddq_s32(src400, src401), src402);
int32x4_t dst410 = vsubq_s32(vsubq_s32(src401, src402), src403);
int32x4_t dst401 = vaddq_s32(vaddq_s32(src410, src411), src412);
int32x4_t dst411 = vsubq_s32(vsubq_s32(src411, src412), src413);
int32x4_t dst402 = vaddq_s32(vaddq_s32(src420, src421), src422);
int32x4_t dst412 = vsubq_s32(vsubq_s32(src421, src422), src423);
int32x4_t dst403 = vaddq_s32(vaddq_s32(src430, src431), src432);
int32x4_t dst413 = vsubq_s32(vsubq_s32(src431, src432), src433);
int32x4_t dst800 = vaddq_s32(vaddq_s32(src800, src801), src802);
int32x4_t dst810 = vsubq_s32(vsubq_s32(src801, src802), src803);
int32x4_t dst801 = vaddq_s32(vaddq_s32(src810, src811), src812);
int32x4_t dst811 = vsubq_s32(vsubq_s32(src811, src812), src813);
int32x4_t dst802 = vaddq_s32(vaddq_s32(src820, src821), src822);
int32x4_t dst812 = vsubq_s32(vsubq_s32(src821, src822), src823);
int32x4_t dst803 = vaddq_s32(vaddq_s32(src830, src831), src832);
int32x4_t dst813 = vsubq_s32(vsubq_s32(src831, src832), src833);
int32x4_t dest400 = vaddq_s32(vaddq_s32(dst400, dst401), dst402);
int32x4_t dest410 = vsubq_s32(vsubq_s32(dst401, dst402), dst403);
int32x4_t dest401 = vaddq_s32(vaddq_s32(dst410, dst411), dst412);
int32x4_t dest411 = vsubq_s32(vsubq_s32(dst411, dst412), dst413);
int32x4_t dest800 = vaddq_s32(vaddq_s32(dst800, dst801), dst802);
int32x4_t dest810 = vsubq_s32(vsubq_s32(dst801, dst802), dst803);
int32x4_t dest801 = vaddq_s32(vaddq_s32(dst810, dst811), dst812);
int32x4_t dest811 = vsubq_s32(vsubq_s32(dst811, dst812), dst813);
vst1q_s32(dest, dest400);
vst1q_s32(dest + 4, dest800);
dest += dest_stride;
vst1q_s32(dest, dest410);
vst1q_s32(dest + 4, dest810);
dest += dest_h_stride - dest_stride;
vst1q_s32(dest, dest401);
vst1q_s32(dest + 4, dest801);
dest += dest_stride;
vst1q_s32(dest, dest411);
vst1q_s32(dest + 4, dest811);
}
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* din, int ch_in, int ch_out, void* workspace) {
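  // The coeff table below is the Winograd F(2x2, 3x3) kernel-transform matrix
  // G = [[1, 0, 0], [1/2, 1/2, 1/2], [1/2, -1/2, 1/2], [0, 0, 1]] scaled by 2,
  // which keeps the transform in int16 arithmetic; the extra constant factor
  // is assumed to be absorbed by the output dequantization scale.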
const int16_t coeff[4][3] = {{2, 0, 0}, {1, 1, 1}, {1, -1, 1}, {0, 0, 2}};
int16_t* ptr_out = static_cast<int16_t*>(workspace);
for (int i = 0; i < ch_out; i++) {
for (int j = 0; j < ch_in; j++) {
const int8_t* kernel0 =
static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
int16_t* ptr_channel = ptr_out + (i * ch_in + j) * 16;
//! transform kernel, transposed
const int8_t* k0 = kernel0;
const int8_t* k1 = kernel0 + 3;
const int8_t* k2 = kernel0 + 6;
//! h
int16_t tmp[4][3];
for (int i = 0; i < 4; i++) {
tmp[i][0] =
k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
tmp[i][1] =
k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
tmp[i][2] =
k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
}
//! v
for (int j = 0; j < 4; j++) {
int16_t* tmpp = &tmp[j][0];
for (int i = 0; i < 4; i++) {
ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
tmpp[1] * coeff[i][1] +
tmpp[2] * coeff[i][2];
}
}
}
}
int oc_pad = (ch_out + 7) / 8 * 8;
int ic_pad = (ch_in + 7) / 8 * 8;
int c_stride = ic_pad * oc_pad;
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 8;
int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 8;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 8 + new_ic * 8 + new_inner;
dest[dest_ind] = ptr_out[i];
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
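Taken together, the new kernel is used in two phases: weight_trans_c8_4x4_int8 re-packs the 3x3 int8 weights once into the Winograd domain, and conv_compute_2x2_3x3_int8 consumes the packed weights on every forward pass. A rough, hypothetical driver is sketched below; buffer sizing follows the padding used inside the kernels, while the exact namespace qualifications and how `param` and `ctx` are obtained are assumptions, not part of this patch:

#include <vector>

// Hypothetical driver; all tensor pointers plus `param`/`ctx` come from the caller.
void run_wino_int8(const int8_t* weight_int8, const int8_t* input_int8,
                   float* output_fp32, const float* bias, const float* scale,
                   int num, int chout, int hout, int wout,
                   int chin, int hin, int win,
                   const paddle::lite::operators::ConvParam& param,
                   paddle::lite::ARMContext* ctx) {
  int ic_pad = (chin + 7) / 8 * 8;
  int oc_pad = (chout + 7) / 8 * 8;
  std::vector<int16_t> wino_weight(16 * ic_pad * oc_pad, 0);
  std::vector<int16_t> scratch(chout * chin * 16);  // transform workspace
  paddle::lite::arm::math::weight_trans_c8_4x4_int8(
      wino_weight.data(), weight_int8, chin, chout, scratch.data());
  paddle::lite::arm::math::conv_compute_2x2_3x3_int8<float>(
      input_int8, output_fp32, num, chout, hout, wout, chin, hin, win,
      wino_weight.data(), bias, scale, param, ctx);
}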
...@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
int w_stride = we - ws;
int valid_w = (we > width ? width : we) - ws;
int cnt = valid_w / 4;
int remain = valid_w & 3;
float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4);
...@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
w_bias1,
flag_relu);
}
-  if (we > width) {
+  if (remain > 0) {
int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset;
-    for (int j = ws + cnt * 4; j < width; ++j) {
+    for (int j = 0; j < remain; ++j) {
if (flag_bias) {
*(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
......
...@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template <typename Dtype>
void im2col(const Dtype* data_im,
......
...@@ -1922,19 +1922,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
Dtype* tmp1 = nullptr;
Dtype* tmp2 = nullptr;
Dtype* tmp3 = nullptr;
-  float32_t scale_local[4];
+  float32_t scale_local[4] = {0, 0, 0, 0};
float32_t bias_local[4] = {0, 0, 0, 0};
if (is_bias) {
-    bias_local[0] = bias[y];
-    bias_local[1] = bias[y + 1];
-    bias_local[2] = bias[y + 2];
-    bias_local[3] = bias[y + 3];
+    if (y + 4 <= M) {
+      bias_local[0] = bias[y];
+      bias_local[1] = bias[y + 1];
+      bias_local[2] = bias[y + 2];
+      bias_local[3] = bias[y + 3];
+    } else {
+      switch (M - y) {
+        case 3:
+          bias_local[2] = bias[y + 2];
+        case 2:
+          bias_local[1] = bias[y + 1];
+        case 1:
+          bias_local[0] = bias[y + 0];
+        default:
+          break;
+      }
+    }
}
if (scale) {
-    scale_local[0] = scale[y];
-    scale_local[1] = scale[y + 1];
-    scale_local[2] = scale[y + 2];
-    scale_local[3] = scale[y + 3];
+    if (y + 4 <= M) {
+      scale_local[0] = scale[y];
+      scale_local[1] = scale[y + 1];
+      scale_local[2] = scale[y + 2];
+      scale_local[3] = scale[y + 3];
+    } else {
+      switch (M - y) {
+        case 3:
+          scale_local[2] = scale[y + 2];
+        case 2:
+          scale_local[1] = scale[y + 1];
+        case 1:
+          scale_local[0] = scale[y + 0];
+        default:
+          break;
+      }
+    }
}
if (y + MBLOCK_INT8_OTH > M) {
switch (y + MBLOCK_INT8_OTH - M) {
......
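This is the gemm-prepacked crash fix called out in the commit message: the last row block can cover fewer than four rows, so unconditionally reading bias[y + 1]..bias[y + 3] (and the scale counterparts) runs past the end of the arrays. A minimal sketch of the guarded tail load the patch introduces, with hypothetical names:

#include <algorithm>
#include <cstring>

// Load up to four per-row values without reading past the end of src (length M).
inline void load_tail4(float dst[4], const float* src, int y, int M) {
  std::memset(dst, 0, 4 * sizeof(float));
  const int valid = std::min(4, M - y);  // rows remaining in the last block
  for (int i = 0; i < valid; ++i) {
    dst[i] = src[y + i];
  }
}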
...@@ -1679,6 +1679,912 @@ void sgemm_prepack_c4_small(int M,
}
}
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx) {
const int m_round = (M + 7) / 8 * 8;
const int k_round = (K + 7) / 8 * 8;
const int mloop = m_round >> 3;
const int lda = 8 * k_round;
const int ldb_byte = 8 * N * sizeof(int16_t);
const int kcnt = k_round >> 3;
#ifdef __aarch64__
float32x4_t vzero = vdupq_n_f32(0.f);
#endif
for (int m = 0; m < mloop; ++m) {
const int16_t* b = B;
int n = N;
#ifdef __aarch64__
for (; n > 7; n -= 8) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smull v20.4s, v0.4h, v4.h[0] \n"
"smull v21.4s, v0.4h, v5.h[0] \n"
"smull v22.4s, v0.4h, v6.h[0] \n"
"smull v23.4s, v0.4h, v7.h[0] \n"
"ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3
"smull2 v24.4s, v0.8h, v4.h[0] \n"
"smull2 v25.4s, v0.8h, v5.h[0] \n"
"smull2 v26.4s, v0.8h, v6.h[0] \n"
"smull2 v27.4s, v0.8h, v7.h[0] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[1] \n"
"smlal v21.4s, v1.4h, v5.h[1] \n"
"smlal v22.4s, v1.4h, v6.h[1] \n"
"smlal v23.4s, v1.4h, v7.h[1] \n"
"smlal2 v24.4s, v1.8h, v4.h[1] \n"
"smlal2 v25.4s, v1.8h, v5.h[1] \n"
"smlal2 v26.4s, v1.8h, v6.h[1] \n"
"smlal2 v27.4s, v1.8h, v7.h[1] \n"
"smull v12.4s, v0.4h, v8.h[0] \n"
"smull v13.4s, v0.4h, v9.h[0] \n"
"smull v14.4s, v0.4h, v10.h[0] \n"
"smull v15.4s, v0.4h, v11.h[0] \n"
"smull2 v16.4s, v0.8h, v8.h[0] \n"
"smull2 v17.4s, v0.8h, v9.h[0] \n"
"smull2 v18.4s, v0.8h, v10.h[0] \n"
"smull2 v19.4s, v0.8h, v11.h[0] \n"
"smlal v12.4s, v1.4h, v8.h[1] \n"
"smlal v13.4s, v1.4h, v9.h[1] \n"
"smlal v14.4s, v1.4h, v10.h[1] \n"
"smlal v15.4s, v1.4h, v11.h[1] \n"
"smlal2 v16.4s, v1.8h, v8.h[1] \n"
"smlal2 v17.4s, v1.8h, v9.h[1] \n"
"smlal2 v18.4s, v1.8h, v10.h[1] \n"
"smlal2 v19.4s, v1.8h, v11.h[1] \n"
"smlal v20.4s, v2.4h, v4.h[2] \n"
"smlal v21.4s, v2.4h, v5.h[2] \n"
"smlal v22.4s, v2.4h, v6.h[2] \n"
"smlal v23.4s, v2.4h, v7.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[2] \n"
"smlal2 v25.4s, v2.8h, v5.h[2] \n"
"smlal2 v26.4s, v2.8h, v6.h[2] \n"
"smlal2 v27.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v2.4h, v8.h[2] \n"
"smlal v13.4s, v2.4h, v9.h[2] \n"
"smlal v14.4s, v2.4h, v10.h[2] \n"
"smlal v15.4s, v2.4h, v11.h[2] \n"
"smlal2 v16.4s, v2.8h, v8.h[2] \n"
"smlal2 v17.4s, v2.8h, v9.h[2] \n"
"smlal2 v18.4s, v2.8h, v10.h[2] \n"
"smlal2 v19.4s, v2.8h, v11.h[2] \n"
"smlal v20.4s, v3.4h, v4.h[3] \n"
"smlal v21.4s, v3.4h, v5.h[3] \n"
"smlal v22.4s, v3.4h, v6.h[3] \n"
"smlal v23.4s, v3.4h, v7.h[3] \n"
"smlal2 v24.4s, v3.8h, v4.h[3] \n"
"smlal2 v25.4s, v3.8h, v5.h[3] \n"
"smlal2 v26.4s, v3.8h, v6.h[3] \n"
"smlal2 v27.4s, v3.8h, v7.h[3] \n"
"smlal v12.4s, v3.4h, v8.h[3] \n"
"smlal v13.4s, v3.4h, v9.h[3] \n"
"smlal v14.4s, v3.4h, v10.h[3] \n"
"smlal v15.4s, v3.4h, v11.h[3] \n"
"smlal2 v16.4s, v3.8h, v8.h[3] \n"
"smlal2 v17.4s, v3.8h, v9.h[3] \n"
"smlal2 v18.4s, v3.8h, v10.h[3] \n"
"smlal2 v19.4s, v3.8h, v11.h[3] \n"
"smlal v20.4s, v0.4h, v4.h[4] \n"
"smlal v21.4s, v0.4h, v5.h[4] \n"
"smlal v22.4s, v0.4h, v6.h[4] \n"
"smlal v23.4s, v0.4h, v7.h[4] \n"
"smlal2 v24.4s, v0.8h, v4.h[4] \n"
"smlal2 v25.4s, v0.8h, v5.h[4] \n"
"smlal2 v26.4s, v0.8h, v6.h[4] \n"
"smlal2 v27.4s, v0.8h, v7.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[5] \n"
"smlal v21.4s, v1.4h, v5.h[5] \n"
"smlal v22.4s, v1.4h, v6.h[5] \n"
"smlal v23.4s, v1.4h, v7.h[5] \n"
"smlal2 v24.4s, v1.8h, v4.h[5] \n"
"smlal2 v25.4s, v1.8h, v5.h[5] \n"
"smlal2 v26.4s, v1.8h, v6.h[5] \n"
"smlal2 v27.4s, v1.8h, v7.h[5] \n"
"smlal v12.4s, v0.4h, v8.h[4] \n"
"smlal v13.4s, v0.4h, v9.h[4] \n"
"smlal v14.4s, v0.4h, v10.h[4] \n"
"smlal v15.4s, v0.4h, v11.h[4] \n"
"smlal2 v16.4s, v0.8h, v8.h[4] \n"
"smlal2 v17.4s, v0.8h, v9.h[4] \n"
"smlal2 v18.4s, v0.8h, v10.h[4] \n"
"smlal2 v19.4s, v0.8h, v11.h[4] \n"
"smlal v12.4s, v1.4h, v8.h[5] \n"
"smlal v13.4s, v1.4h, v9.h[5] \n"
"smlal v14.4s, v1.4h, v10.h[5] \n"
"smlal v15.4s, v1.4h, v11.h[5] \n"
"smlal2 v16.4s, v1.8h, v8.h[5] \n"
"smlal2 v17.4s, v1.8h, v9.h[5] \n"
"smlal2 v18.4s, v1.8h, v10.h[5] \n"
"smlal2 v19.4s, v1.8h, v11.h[5] \n"
"smlal v20.4s, v2.4h, v4.h[6] \n"
"smlal v21.4s, v2.4h, v5.h[6] \n"
"smlal v22.4s, v2.4h, v6.h[6] \n"
"smlal v23.4s, v2.4h, v7.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[6] \n"
"smlal2 v25.4s, v2.8h, v5.h[6] \n"
"smlal2 v26.4s, v2.8h, v6.h[6] \n"
"smlal2 v27.4s, v2.8h, v7.h[6] \n"
"sub %[b], %[b], #128 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v20.4s, v3.4h, v4.h[7] \n"
"smlal v21.4s, v3.4h, v5.h[7] \n"
"smlal v22.4s, v3.4h, v6.h[7] \n"
"smlal v23.4s, v3.4h, v7.h[7] \n"
"smlal2 v24.4s, v3.8h, v4.h[7] \n"
"smlal2 v25.4s, v3.8h, v5.h[7] \n"
"smlal2 v26.4s, v3.8h, v6.h[7] \n"
"smlal2 v27.4s, v3.8h, v7.h[7] \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smlal v12.4s, v2.4h, v8.h[6] \n"
"smlal v13.4s, v2.4h, v9.h[6] \n"
"smlal v14.4s, v2.4h, v10.h[6] \n"
"smlal v15.4s, v2.4h, v11.h[6] \n"
"smlal2 v16.4s, v2.8h, v8.h[6] \n"
"smlal2 v17.4s, v2.8h, v9.h[6] \n"
"smlal2 v18.4s, v2.8h, v10.h[6] \n"
"smlal2 v19.4s, v2.8h, v11.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"smlal v12.4s, v3.4h, v8.h[7] \n"
"smlal v13.4s, v3.4h, v9.h[7] \n"
"smlal v14.4s, v3.4h, v10.h[7] \n"
"smlal v15.4s, v3.4h, v11.h[7] \n"
"smlal2 v16.4s, v3.8h, v8.h[7] \n"
"smlal2 v17.4s, v3.8h, v9.h[7] \n"
"smlal2 v18.4s, v3.8h, v10.h[7] \n"
"smlal2 v19.4s, v3.8h, v11.h[7] \n"
"beq 2f \n"
"1:\n"
"smlal v20.4s, v0.4h, v4.h[0] \n"
"smlal v21.4s, v0.4h, v5.h[0] \n"
"smlal v22.4s, v0.4h, v6.h[0] \n"
"smlal v23.4s, v0.4h, v7.h[0] \n"
"ld1 {v8.8h, v9.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v10.8h, v11.8h}, [%[b]], #32 \n" //load b2, b3
"smlal2 v24.4s, v0.8h, v4.h[0] \n"
"smlal2 v25.4s, v0.8h, v5.h[0] \n"
"smlal2 v26.4s, v0.8h, v6.h[0] \n"
"smlal2 v27.4s, v0.8h, v7.h[0] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[1] \n"
"smlal v21.4s, v1.4h, v5.h[1] \n"
"smlal v22.4s, v1.4h, v6.h[1] \n"
"smlal v23.4s, v1.4h, v7.h[1] \n"
"smlal2 v24.4s, v1.8h, v4.h[1] \n"
"smlal2 v25.4s, v1.8h, v5.h[1] \n"
"smlal2 v26.4s, v1.8h, v6.h[1] \n"
"smlal2 v27.4s, v1.8h, v7.h[1] \n"
"smlal v12.4s, v0.4h, v8.h[0] \n"
"smlal v13.4s, v0.4h, v9.h[0] \n"
"smlal v14.4s, v0.4h, v10.h[0] \n"
"smlal v15.4s, v0.4h, v11.h[0] \n"
"smlal2 v16.4s, v0.8h, v8.h[0] \n"
"smlal2 v17.4s, v0.8h, v9.h[0] \n"
"smlal2 v18.4s, v0.8h, v10.h[0] \n"
"smlal2 v19.4s, v0.8h, v11.h[0] \n"
"smlal v12.4s, v1.4h, v8.h[1] \n"
"smlal v13.4s, v1.4h, v9.h[1] \n"
"smlal v14.4s, v1.4h, v10.h[1] \n"
"smlal v15.4s, v1.4h, v11.h[1] \n"
"smlal2 v16.4s, v1.8h, v8.h[1] \n"
"smlal2 v17.4s, v1.8h, v9.h[1] \n"
"smlal2 v18.4s, v1.8h, v10.h[1] \n"
"smlal2 v19.4s, v1.8h, v11.h[1] \n"
"smlal v20.4s, v2.4h, v4.h[2] \n"
"smlal v21.4s, v2.4h, v5.h[2] \n"
"smlal v22.4s, v2.4h, v6.h[2] \n"
"smlal v23.4s, v2.4h, v7.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[2] \n"
"smlal2 v25.4s, v2.8h, v5.h[2] \n"
"smlal2 v26.4s, v2.8h, v6.h[2] \n"
"smlal2 v27.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v2.4h, v8.h[2] \n"
"smlal v13.4s, v2.4h, v9.h[2] \n"
"smlal v14.4s, v2.4h, v10.h[2] \n"
"smlal v15.4s, v2.4h, v11.h[2] \n"
"smlal2 v16.4s, v2.8h, v8.h[2] \n"
"smlal2 v17.4s, v2.8h, v9.h[2] \n"
"smlal2 v18.4s, v2.8h, v10.h[2] \n"
"smlal2 v19.4s, v2.8h, v11.h[2] \n"
"smlal v20.4s, v3.4h, v4.h[3] \n"
"smlal v21.4s, v3.4h, v5.h[3] \n"
"smlal v22.4s, v3.4h, v6.h[3] \n"
"smlal v23.4s, v3.4h, v7.h[3] \n"
"smlal2 v24.4s, v3.8h, v4.h[3] \n"
"smlal2 v25.4s, v3.8h, v5.h[3] \n"
"smlal2 v26.4s, v3.8h, v6.h[3] \n"
"smlal2 v27.4s, v3.8h, v7.h[3] \n"
"smlal v12.4s, v3.4h, v8.h[3] \n"
"smlal v13.4s, v3.4h, v9.h[3] \n"
"smlal v14.4s, v3.4h, v10.h[3] \n"
"smlal v15.4s, v3.4h, v11.h[3] \n"
"smlal2 v16.4s, v3.8h, v8.h[3] \n"
"smlal2 v17.4s, v3.8h, v9.h[3] \n"
"smlal2 v18.4s, v3.8h, v10.h[3] \n"
"smlal2 v19.4s, v3.8h, v11.h[3] \n"
"smlal v20.4s, v0.4h, v4.h[4] \n"
"smlal v21.4s, v0.4h, v5.h[4] \n"
"smlal v22.4s, v0.4h, v6.h[4] \n"
"smlal v23.4s, v0.4h, v7.h[4] \n"
"smlal2 v24.4s, v0.8h, v4.h[4] \n"
"smlal2 v25.4s, v0.8h, v5.h[4] \n"
"smlal2 v26.4s, v0.8h, v6.h[4] \n"
"smlal2 v27.4s, v0.8h, v7.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n" //load a2, a3
"smlal v20.4s, v1.4h, v4.h[5] \n"
"smlal v21.4s, v1.4h, v5.h[5] \n"
"smlal v22.4s, v1.4h, v6.h[5] \n"
"smlal v23.4s, v1.4h, v7.h[5] \n"
"smlal2 v24.4s, v1.8h, v4.h[5] \n"
"smlal2 v25.4s, v1.8h, v5.h[5] \n"
"smlal2 v26.4s, v1.8h, v6.h[5] \n"
"smlal2 v27.4s, v1.8h, v7.h[5] \n"
"smlal v12.4s, v0.4h, v8.h[4] \n"
"smlal v13.4s, v0.4h, v9.h[4] \n"
"smlal v14.4s, v0.4h, v10.h[4] \n"
"smlal v15.4s, v0.4h, v11.h[4] \n"
"smlal2 v16.4s, v0.8h, v8.h[4] \n"
"smlal2 v17.4s, v0.8h, v9.h[4] \n"
"smlal2 v18.4s, v0.8h, v10.h[4] \n"
"smlal2 v19.4s, v0.8h, v11.h[4] \n"
"smlal v12.4s, v1.4h, v8.h[5] \n"
"smlal v13.4s, v1.4h, v9.h[5] \n"
"smlal v14.4s, v1.4h, v10.h[5] \n"
"smlal v15.4s, v1.4h, v11.h[5] \n"
"smlal2 v16.4s, v1.8h, v8.h[5] \n"
"smlal2 v17.4s, v1.8h, v9.h[5] \n"
"smlal2 v18.4s, v1.8h, v10.h[5] \n"
"smlal2 v19.4s, v1.8h, v11.h[5] \n"
"smlal v20.4s, v2.4h, v4.h[6] \n"
"smlal v21.4s, v2.4h, v5.h[6] \n"
"smlal v22.4s, v2.4h, v6.h[6] \n"
"smlal v23.4s, v2.4h, v7.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n" //load a0, a1
"smlal2 v24.4s, v2.8h, v4.h[6] \n"
"smlal2 v25.4s, v2.8h, v5.h[6] \n"
"smlal2 v26.4s, v2.8h, v6.h[6] \n"
"smlal2 v27.4s, v2.8h, v7.h[6] \n"
"sub %[b], %[b], #128 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v20.4s, v3.4h, v4.h[7] \n"
"smlal v21.4s, v3.4h, v5.h[7] \n"
"smlal v22.4s, v3.4h, v6.h[7] \n"
"smlal v23.4s, v3.4h, v7.h[7] \n"
"smlal2 v24.4s, v3.8h, v4.h[7] \n"
"smlal2 v25.4s, v3.8h, v5.h[7] \n"
"smlal2 v26.4s, v3.8h, v6.h[7] \n"
"smlal2 v27.4s, v3.8h, v7.h[7] \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n" //load b0, b1
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n" //load b2, b3
"smlal v12.4s, v2.4h, v8.h[6] \n"
"smlal v13.4s, v2.4h, v9.h[6] \n"
"smlal v14.4s, v2.4h, v10.h[6] \n"
"smlal v15.4s, v2.4h, v11.h[6] \n"
"smlal2 v16.4s, v2.8h, v8.h[6] \n"
"smlal2 v17.4s, v2.8h, v9.h[6] \n"
"smlal2 v18.4s, v2.8h, v10.h[6] \n"
"smlal2 v19.4s, v2.8h, v11.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"smlal v12.4s, v3.4h, v8.h[7] \n"
"smlal v13.4s, v3.4h, v9.h[7] \n"
"smlal v14.4s, v3.4h, v10.h[7] \n"
"smlal v15.4s, v3.4h, v11.h[7] \n"
"smlal2 v16.4s, v3.8h, v8.h[7] \n"
"smlal2 v17.4s, v3.8h, v9.h[7] \n"
"smlal2 v18.4s, v3.8h, v10.h[7] \n"
"smlal2 v19.4s, v3.8h, v11.h[7] \n"
"bne 1b \n"
"2: \n"
"stp q20, q24, [%[c]], #32 \n"
"stp q21, q25, [%[c]], #32 \n"
"stp q22, q26, [%[c]], #32 \n"
"stp q23, q27, [%[c]], #32 \n"
"stp q12, q16, [%[c]], #32 \n"
"stp q13, q17, [%[c]], #32 \n"
"stp q14, q18, [%[c]], #32 \n"
"stp q15, q19, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4","v5", "v6", "v7", "v8", "v9",
"v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20",
"v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"
);
// clang-format on
b += 64;
}
for (; n > 3; n -= 4) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smull v8.4s, v0.4h, v4.h[0] \n"
"smull v9.4s, v0.4h, v5.h[0] \n"
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n"
"smull2 v10.4s, v0.8h, v4.h[0] \n"
"smull2 v11.4s, v0.8h, v5.h[0] \n"
"smlal v8.4s, v1.4h, v4.h[1] \n"
"smlal v9.4s, v1.4h, v5.h[1] \n"
"smlal2 v10.4s, v1.8h, v4.h[1] \n"
"smlal2 v11.4s, v1.8h, v5.h[1] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smull v12.4s, v0.4h, v6.h[0] \n"
"smull v13.4s, v0.4h, v7.h[0] \n"
"smull2 v14.4s, v0.8h, v6.h[0] \n"
"smull2 v15.4s, v0.8h, v7.h[0] \n"
"smlal v12.4s, v1.4h, v6.h[1] \n"
"smlal v13.4s, v1.4h, v7.h[1] \n"
"smlal2 v14.4s, v1.8h, v6.h[1] \n"
"smlal2 v15.4s, v1.8h, v7.h[1] \n"
"smlal v8.4s, v2.4h, v4.h[2] \n"
"smlal v9.4s, v2.4h, v5.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[2] \n"
"smlal2 v11.4s, v2.8h, v5.h[2] \n"
"smlal v8.4s, v3.4h, v4.h[3] \n"
"smlal v9.4s, v3.4h, v5.h[3] \n"
"smlal2 v10.4s, v3.8h, v4.h[3] \n"
"smlal2 v11.4s, v3.8h, v5.h[3] \n"
"smlal v12.4s, v2.4h, v6.h[2] \n"
"smlal v13.4s, v2.4h, v7.h[2] \n"
"smlal2 v14.4s, v2.8h, v6.h[2] \n"
"smlal2 v15.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v3.4h, v6.h[3] \n"
"smlal v13.4s, v3.4h, v7.h[3] \n"
"smlal2 v14.4s, v3.8h, v6.h[3] \n"
"smlal2 v15.4s, v3.8h, v7.h[3] \n"
"smlal v8.4s, v0.4h, v4.h[4] \n"
"smlal v9.4s, v0.4h, v5.h[4] \n"
"smlal2 v10.4s, v0.8h, v4.h[4] \n"
"smlal2 v11.4s, v0.8h, v5.h[4] \n"
"smlal v8.4s, v1.4h, v4.h[5] \n"
"smlal v9.4s, v1.4h, v5.h[5] \n"
"smlal2 v10.4s, v1.8h, v4.h[5] \n"
"smlal2 v11.4s, v1.8h, v5.h[5] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v12.4s, v0.4h, v6.h[4] \n"
"smlal v13.4s, v0.4h, v7.h[4] \n"
"smlal2 v14.4s, v0.8h, v6.h[4] \n"
"smlal2 v15.4s, v0.8h, v7.h[4] \n"
"smlal v12.4s, v1.4h, v6.h[5] \n"
"smlal v13.4s, v1.4h, v7.h[5] \n"
"smlal2 v14.4s, v1.8h, v6.h[5] \n"
"smlal2 v15.4s, v1.8h, v7.h[5] \n"
"smlal v8.4s, v2.4h, v4.h[6] \n"
"smlal v9.4s, v2.4h, v5.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[6] \n"
"smlal2 v11.4s, v2.8h, v5.h[6] \n"
"smlal v8.4s, v3.4h, v4.h[7] \n"
"smlal v9.4s, v3.4h, v5.h[7] \n"
"smlal2 v10.4s, v3.8h, v4.h[7] \n"
"smlal2 v11.4s, v3.8h, v5.h[7] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v12.4s, v2.4h, v6.h[6] \n"
"smlal v13.4s, v2.4h, v7.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smlal2 v14.4s, v2.8h, v6.h[6] \n"
"smlal2 v15.4s, v2.8h, v7.h[6] \n"
"smlal v12.4s, v3.4h, v6.h[7] \n"
"smlal v13.4s, v3.4h, v7.h[7] \n"
"smlal2 v14.4s, v3.8h, v6.h[7] \n"
"smlal2 v15.4s, v3.8h, v7.h[7] \n"
"beq 2f \n"
"1: \n"
"smlal v8.4s, v0.4h, v4.h[0] \n"
"smlal v9.4s, v0.4h, v5.h[0] \n"
"ld1 {v6.8h, v7.8h}, [%[b]], #32 \n"
"smlal2 v10.4s, v0.8h, v4.h[0] \n"
"smlal2 v11.4s, v0.8h, v5.h[0] \n"
"smlal v8.4s, v1.4h, v4.h[1] \n"
"smlal v9.4s, v1.4h, v5.h[1] \n"
"smlal2 v10.4s, v1.8h, v4.h[1] \n"
"smlal2 v11.4s, v1.8h, v5.h[1] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v12.4s, v0.4h, v6.h[0] \n"
"smlal v13.4s, v0.4h, v7.h[0] \n"
"smlal2 v14.4s, v0.8h, v6.h[0] \n"
"smlal2 v15.4s, v0.8h, v7.h[0] \n"
"smlal v12.4s, v1.4h, v6.h[1] \n"
"smlal v13.4s, v1.4h, v7.h[1] \n"
"smlal2 v14.4s, v1.8h, v6.h[1] \n"
"smlal2 v15.4s, v1.8h, v7.h[1] \n"
"smlal v8.4s, v2.4h, v4.h[2] \n"
"smlal v9.4s, v2.4h, v5.h[2] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[2] \n"
"smlal2 v11.4s, v2.8h, v5.h[2] \n"
"smlal v8.4s, v3.4h, v4.h[3] \n"
"smlal v9.4s, v3.4h, v5.h[3] \n"
"smlal2 v10.4s, v3.8h, v4.h[3] \n"
"smlal2 v11.4s, v3.8h, v5.h[3] \n"
"smlal v12.4s, v2.4h, v6.h[2] \n"
"smlal v13.4s, v2.4h, v7.h[2] \n"
"smlal2 v14.4s, v2.8h, v6.h[2] \n"
"smlal2 v15.4s, v2.8h, v7.h[2] \n"
"smlal v12.4s, v3.4h, v6.h[3] \n"
"smlal v13.4s, v3.4h, v7.h[3] \n"
"smlal2 v14.4s, v3.8h, v6.h[3] \n"
"smlal2 v15.4s, v3.8h, v7.h[3] \n"
"smlal v8.4s, v0.4h, v4.h[4] \n"
"smlal v9.4s, v0.4h, v5.h[4] \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v0.8h, v4.h[4] \n"
"smlal2 v11.4s, v0.8h, v5.h[4] \n"
"smlal v8.4s, v1.4h, v4.h[5] \n"
"smlal v9.4s, v1.4h, v5.h[5] \n"
"smlal2 v10.4s, v1.8h, v4.h[5] \n"
"smlal2 v11.4s, v1.8h, v5.h[5] \n"
"smlal v12.4s, v0.4h, v6.h[4] \n"
"smlal v13.4s, v0.4h, v7.h[4] \n"
"smlal2 v14.4s, v0.8h, v6.h[4] \n"
"smlal2 v15.4s, v0.8h, v7.h[4] \n"
"smlal v12.4s, v1.4h, v6.h[5] \n"
"smlal v13.4s, v1.4h, v7.h[5] \n"
"smlal2 v14.4s, v1.8h, v6.h[5] \n"
"smlal2 v15.4s, v1.8h, v7.h[5] \n"
"smlal v8.4s, v2.4h, v4.h[6] \n"
"smlal v9.4s, v2.4h, v5.h[6] \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal2 v10.4s, v2.8h, v4.h[6] \n"
"smlal2 v11.4s, v2.8h, v5.h[6] \n"
"smlal v8.4s, v3.4h, v4.h[7] \n"
"smlal v9.4s, v3.4h, v5.h[7] \n"
"smlal2 v10.4s, v3.8h, v4.h[7] \n"
"smlal2 v11.4s, v3.8h, v5.h[7] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v12.4s, v2.4h, v6.h[6] \n"
"smlal v13.4s, v2.4h, v7.h[6] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v4.8h, v5.8h}, [%[b]], #32 \n"
"smlal2 v14.4s, v2.8h, v6.h[6] \n"
"smlal2 v15.4s, v2.8h, v7.h[6] \n"
"smlal v12.4s, v3.4h, v6.h[7] \n"
"smlal v13.4s, v3.4h, v7.h[7] \n"
"smlal2 v14.4s, v3.8h, v6.h[7] \n"
"smlal2 v15.4s, v3.8h, v7.h[7] \n"
"bne 1b \n"
"2: \n"
"stp q8, q10, [%[c]], #32 \n"
"stp q9, q11, [%[c]], #32 \n"
"stp q12, q14, [%[c]], #32 \n"
"stp q13, q15, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4","v5", "v6", "v7", "v8", "v9",
"v10", "v11","v12", "v13", "v14", "v15", "cc", "memory"
);
// clang-format on
b += 32;
}
for (; n > 0; --n) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile(
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"ld1 {v4.8h}, [%[b]], #16 \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smull v5.4s, v0.4h, v4.h[0] \n"
"smull2 v6.4s, v0.8h, v4.h[0] \n"
"ld1 {v10.8h, v11.8h}, [%[a]], #32 \n"
"smlal v5.4s, v1.4h, v4.h[1] \n"
"smlal2 v6.4s, v1.8h, v4.h[1] \n"
"ld1 {v12.8h, v13.8h}, [%[a]], #32 \n"
"smlal v5.4s, v2.4h, v4.h[2] \n"
"smlal2 v6.4s, v2.8h, v4.h[2] \n"
"smlal v5.4s, v3.4h, v4.h[3] \n"
"smlal2 v6.4s, v3.8h, v4.h[3] \n"
"sub %[b], %[b], #16 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v5.4s, v10.4h, v4.h[4] \n"
"smlal2 v6.4s, v10.8h, v4.h[4] \n"
"smlal v5.4s, v11.4h, v4.h[5] \n"
"smlal2 v6.4s, v11.8h, v4.h[5] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal v5.4s, v12.4h, v4.h[6] \n"
"smlal2 v6.4s, v12.8h, v4.h[6] \n"
"smlal v5.4s, v13.4h, v4.h[7] \n"
"smlal2 v6.4s, v13.8h, v4.h[7] \n"
"beq 2f \n"
"1: \n"
"ld1 {v4.8h}, [%[b]], #16 \n"
"ld1 {v2.8h, v3.8h}, [%[a]], #32 \n"
"smlal v5.4s, v0.4h, v4.h[0] \n"
"smlal2 v6.4s, v0.8h, v4.h[0] \n"
"ld1 {v10.8h, v11.8h}, [%[a]], #32 \n"
"smlal v5.4s, v1.4h, v4.h[1] \n"
"smlal2 v6.4s, v1.8h, v4.h[1] \n"
"ld1 {v12.8h, v13.8h}, [%[a]], #32 \n"
"smlal v5.4s, v2.4h, v4.h[2] \n"
"smlal2 v6.4s, v2.8h, v4.h[2] \n"
"smlal v5.4s, v3.4h, v4.h[3] \n"
"smlal2 v6.4s, v3.8h, v4.h[3] \n"
"sub %[b], %[b], #16 \n"
"add %[b], %[b], %[ldb] \n"
"smlal v5.4s, v10.4h, v4.h[4] \n"
"smlal2 v6.4s, v10.8h, v4.h[4] \n"
"smlal v5.4s, v11.4h, v4.h[5] \n"
"smlal2 v6.4s, v11.8h, v4.h[5] \n"
"subs %w[cnt], %w[cnt], #1 \n"
"ld1 {v0.8h, v1.8h}, [%[a]], #32 \n"
"smlal v5.4s, v12.4h, v4.h[6] \n"
"smlal2 v6.4s, v12.8h, v4.h[6] \n"
"smlal v5.4s, v13.4h, v4.h[7] \n"
"smlal2 v6.4s, v13.8h, v4.h[7] \n"
"bne 1b \n"
"2: \n"
"st1 {v5.4s, v6.4s}, [%[c]], #32 \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "v0", "v1", "v2", "v3", "v4","v5", "v6", "cc", "memory"
);
// clang-format on
b += 8;
}
#else
for (; n > 3; n -= 4) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile (
"vld1.16 {d0-d3}, [%[b]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vld1.16 {d4-d7}, [%[b]]! \n"
"vmull.s16 q8, d8, d0[0] \n"
"vmull.s16 q9, d8, d2[0] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmull.s16 q10, d9, d0[0] \n"
"vmull.s16 q11, d9, d2[0] \n"
"vmlal.s16 q8, d10, d0[1] \n"
"vmlal.s16 q9, d10, d2[1] \n"
"vmlal.s16 q10, d11, d0[1] \n"
"vmlal.s16 q11, d11, d2[1] \n"
"vmull.s16 q12, d8, d4[0] \n"
"vmull.s16 q13, d8, d6[0] \n"
"vmull.s16 q14, d9, d4[0] \n"
"vmull.s16 q15, d9, d6[0] \n"
"vmlal.s16 q12, d10, d4[1] \n"
"vmlal.s16 q13, d10, d6[1] \n"
"vmlal.s16 q14, d11, d4[1] \n"
"vmlal.s16 q15, d11, d6[1] \n"
"vmlal.s16 q8, d12, d0[2] \n"
"vmlal.s16 q9, d12, d2[2] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q10, d13, d0[2] \n"
"vmlal.s16 q11, d13, d2[2] \n"
"vmlal.s16 q8, d14, d0[3] \n"
"vmlal.s16 q9, d14, d2[3] \n"
"vmlal.s16 q10, d15, d0[3] \n"
"vmlal.s16 q11, d15, d2[3] \n"
"vmlal.s16 q12, d12, d4[2] \n"
"vmlal.s16 q13, d12, d6[2] \n"
"vmlal.s16 q14, d13, d4[2] \n"
"vmlal.s16 q15, d13, d6[2] \n"
"vmlal.s16 q12, d14, d4[3] \n"
"vmlal.s16 q13, d14, d6[3] \n"
"vmlal.s16 q14, d15, d4[3] \n"
"vmlal.s16 q15, d15, d6[3] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[0] \n"
"vmlal.s16 q9, d8, d3[0] \n"
"vmlal.s16 q10, d9, d1[0] \n"
"vmlal.s16 q11, d9, d3[0] \n"
"vmlal.s16 q8, d10, d1[1] \n"
"vmlal.s16 q9, d10, d3[1] \n"
"vmlal.s16 q10, d11, d1[1] \n"
"vmlal.s16 q11, d11, d3[1] \n"
"vmlal.s16 q8, d12, d1[2] \n"
"vmlal.s16 q9, d12, d3[2] \n"
"vmlal.s16 q10, d13, d1[2] \n"
"vmlal.s16 q11, d13, d3[2] \n"
"vmlal.s16 q8, d14, d1[3] \n"
"vmlal.s16 q9, d14, d3[3] \n"
"vmlal.s16 q10, d15, d1[3] \n"
"vmlal.s16 q11, d15, d3[3] \n"
"vld1.16 {d0-d3}, [%[b]]! \n"
"vmlal.s16 q12, d8, d5[0] \n"
"vmlal.s16 q13, d8, d7[0] \n"
"vmlal.s16 q14, d9, d5[0] \n"
"vmlal.s16 q15, d9, d7[0] \n"
"vmlal.s16 q12, d10, d5[1] \n"
"vmlal.s16 q13, d10, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmlal.s16 q14, d11, d5[1] \n"
"vmlal.s16 q15, d11, d7[1] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q12, d12, d5[2] \n"
"vmlal.s16 q13, d12, d7[2] \n"
"vmlal.s16 q14, d13, d5[2] \n"
"vmlal.s16 q15, d13, d7[2] \n"
"vmlal.s16 q12, d14, d5[3] \n"
"vmlal.s16 q13, d14, d7[3] \n"
"vmlal.s16 q14, d15, d5[3] \n"
"vmlal.s16 q15, d15, d7[3] \n"
"beq 2f \n"
"1: \n"
"vld1.16 {d4-d7}, [%[b]]! \n"
"vmlal.s16 q8, d8, d0[0] \n"
"vmlal.s16 q9, d8, d2[0] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q10, d9, d0[0] \n"
"vmlal.s16 q11, d9, d2[0] \n"
"vmlal.s16 q8, d10, d0[1] \n"
"vmlal.s16 q9, d10, d2[1] \n"
"vmlal.s16 q10, d11, d0[1] \n"
"vmlal.s16 q11, d11, d2[1] \n"
"vmlal.s16 q12, d8, d4[0] \n"
"vmlal.s16 q13, d8, d6[0] \n"
"vmlal.s16 q14, d9, d4[0] \n"
"vmlal.s16 q15, d9, d6[0] \n"
"vmlal.s16 q12, d10, d4[1] \n"
"vmlal.s16 q13, d10, d6[1] \n"
"vmlal.s16 q14, d11, d4[1] \n"
"vmlal.s16 q15, d11, d6[1] \n"
"vmlal.s16 q8, d12, d0[2] \n"
"vmlal.s16 q9, d12, d2[2] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q10, d13, d0[2] \n"
"vmlal.s16 q11, d13, d2[2] \n"
"vmlal.s16 q8, d14, d0[3] \n"
"vmlal.s16 q9, d14, d2[3] \n"
"vmlal.s16 q10, d15, d0[3] \n"
"vmlal.s16 q11, d15, d2[3] \n"
"vmlal.s16 q12, d12, d4[2] \n"
"vmlal.s16 q13, d12, d6[2] \n"
"vmlal.s16 q14, d13, d4[2] \n"
"vmlal.s16 q15, d13, d6[2] \n"
"vmlal.s16 q12, d14, d4[3] \n"
"vmlal.s16 q13, d14, d6[3] \n"
"vmlal.s16 q14, d15, d4[3] \n"
"vmlal.s16 q15, d15, d6[3] \n"
"sub %[b], %[b], #64 \n"
"add %[b], %[b], %[ldb] \n"
"vld1.16 {d12-d15}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[0] \n"
"vmlal.s16 q9, d8, d3[0] \n"
"vmlal.s16 q10, d9, d1[0] \n"
"vmlal.s16 q11, d9, d3[0] \n"
"vmlal.s16 q8, d10, d1[1] \n"
"vmlal.s16 q9, d10, d3[1] \n"
"vmlal.s16 q10, d11, d1[1] \n"
"vmlal.s16 q11, d11, d3[1] \n"
"vmlal.s16 q8, d12, d1[2] \n"
"vmlal.s16 q9, d12, d3[2] \n"
"vmlal.s16 q10, d13, d1[2] \n"
"vmlal.s16 q11, d13, d3[2] \n"
"vmlal.s16 q8, d14, d1[3] \n"
"vmlal.s16 q9, d14, d3[3] \n"
"vmlal.s16 q10, d15, d1[3] \n"
"vmlal.s16 q11, d15, d3[3] \n"
"vld1.16 {d0-d3}, [%[b]]! \n"
"vmlal.s16 q12, d8, d5[0] \n"
"vmlal.s16 q13, d8, d7[0] \n"
"vmlal.s16 q14, d9, d5[0] \n"
"vmlal.s16 q15, d9, d7[0] \n"
"vmlal.s16 q12, d10, d5[1] \n"
"vmlal.s16 q13, d10, d7[1] \n"
"subs %[cnt], %[cnt], #1 \n"
"vmlal.s16 q14, d11, d5[1] \n"
"vmlal.s16 q15, d11, d7[1] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q12, d12, d5[2] \n"
"vmlal.s16 q13, d12, d7[2] \n"
"vmlal.s16 q14, d13, d5[2] \n"
"vmlal.s16 q15, d13, d7[2] \n"
"vmlal.s16 q12, d14, d5[3] \n"
"vmlal.s16 q13, d14, d7[3] \n"
"vmlal.s16 q14, d15, d5[3] \n"
"vmlal.s16 q15, d15, d7[3] \n"
"bne 1b \n"
"2: \n"
"vst1.32 {d16-d17}, [%[c]]! \n"
"vst1.32 {d20-d21}, [%[c]]! \n"
"vst1.32 {d18-d19}, [%[c]]! \n"
"vst1.32 {d22-d23}, [%[c]]! \n"
"vst1.32 {d24-d25}, [%[c]]! \n"
"vst1.32 {d28-d29}, [%[c]]! \n"
"vst1.32 {d26-d27}, [%[c]]! \n"
"vst1.32 {d30-d31}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8",
"q9", "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory"
);
// clang-format on
b += 32;
}
for (; n > 0; --n) {
int cnt = kcnt;
const int16_t* a_ptr = A_packed;
const int16_t* b_ptr = b;
// clang-format off
asm volatile (
"vld1.16 {d0-d1}, [%[b]]! \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmull.s16 q8, d4, d0[0] \n"
"vmull.s16 q9, d5, d0[0] \n"
"sub %[b], %[b], #16 \n"
"vmlal.s16 q8, d6, d0[1] \n"
"vmlal.s16 q9, d7, d0[1] \n"
"add %[b], %[b], %[ldb] \n"
"subs %[cnt], %[cnt], #1 \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d0[2] \n"
"vmlal.s16 q9, d9, d0[2] \n"
"vmlal.s16 q8, d10, d0[3] \n"
"vmlal.s16 q9, d11, d0[3] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d1[0] \n"
"vmlal.s16 q9, d5, d1[0] \n"
"vmlal.s16 q8, d6, d1[1] \n"
"vmlal.s16 q9, d7, d1[1] \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[2] \n"
"vmlal.s16 q9, d9, d1[2] \n"
"vmlal.s16 q8, d10, d1[3] \n"
"vmlal.s16 q9, d11, d1[3] \n"
"beq 2f \n"
"1:\n"
"vld1.16 {d0-d1}, [%[b]]! \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d0[0] \n"
"vmlal.s16 q9, d5, d0[0] \n"
"sub %[b], %[b], #16 \n"
"vmlal.s16 q8, d6, d0[1] \n"
"vmlal.s16 q9, d7, d0[1] \n"
"add %[b], %[b], %[ldb] \n"
"subs %[cnt], %[cnt], #1 \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d0[2] \n"
"vmlal.s16 q9, d9, d0[2] \n"
"vmlal.s16 q8, d10, d0[3] \n"
"vmlal.s16 q9, d11, d0[3] \n"
"vld1.16 {d8-d11}, [%[a]]! \n"
"vmlal.s16 q8, d4, d1[0] \n"
"vmlal.s16 q9, d5, d1[0] \n"
"vmlal.s16 q8, d6, d1[1] \n"
"vmlal.s16 q9, d7, d1[1] \n"
"vld1.16 {d4-d7}, [%[a]]! \n"
"vmlal.s16 q8, d8, d1[2] \n"
"vmlal.s16 q9, d9, d1[2] \n"
"vmlal.s16 q8, d10, d1[3] \n"
"vmlal.s16 q9, d11, d1[3] \n"
"bne 1b \n"
"2: \n"
"vst1.32 {d16-d19}, [%[c]]! \n"
: [a] "+r" (a_ptr),
[b] "+r" (b_ptr),
[c] "+r" (C),
[cnt] "+r" (cnt)
: [ldb] "r" (ldb_byte)
: "q0", "q1", "q2", "q3", "q4","q5", "q6", "q7", "q8",
"q9", "cc", "memory"
);
// clang-format on
b += 8;
}
#endif
A_packed += lda;
}
}
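// For reference: a hypothetical scalar model of what the NEON kernels above
// compute, assuming the packing implied by the loads and stores (A packed as
// 8-row panels with 8 rows per k step, B packed in blocks of 8 k values per
// column, C written as 8 rows per column). Illustrative sketch only, not part
// of this patch.
//
// void sgemm_c8_int16_ref(int M, int N, int K,
//                         const int16_t* A_packed,  // [m_round/8][k_round][8]
//                         const int16_t* B,         // [k_round/8][N][8]
//                         int32_t* C) {             // [m_round/8][N][8]
//   const int m_round = (M + 7) / 8 * 8;
//   const int k_round = (K + 7) / 8 * 8;
//   for (int mb = 0; mb < m_round / 8; ++mb) {
//     const int16_t* a = A_packed + mb * 8 * k_round;
//     for (int n = 0; n < N; ++n) {
//       int32_t acc[8] = {0};
//       for (int k = 0; k < k_round; ++k) {
//         const int32_t bkn = B[(k / 8) * 8 * N + n * 8 + (k % 8)];
//         for (int r = 0; r < 8; ++r) {
//           acc[r] += static_cast<int32_t>(a[k * 8 + r]) * bkn;
//         }
//       }
//       for (int r = 0; r < 8; ++r) {
//         *C++ = acc[r];
//       }
//     }
//   }
// }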
void sgemm_prepack_c4(int M,
int N,
int K,
......
...@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M,
const float* B,
float* C,
ARMContext* ctx);
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
...@@ -206,6 +206,20 @@ void pooling_basic(const float* din,
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/
#define P2x2S2P1_MAX \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fmax v4.4s, v0.4s, v6.4s\n" /* max */ \
"fmax v5.4s, v2.4s, v8.4s\n" /* max */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fmax v6.4s, v4.4s, v5.4s\n" /* max reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_MAX \
"1: \n" \
"fmax v4.4s, v0.4s, v1.4s\n" /* max */ \
...@@ -217,6 +231,21 @@ void pooling_basic(const float* din,
"st1 {v6.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P2x2S2P1_AVG \
"ext v6.16b, %[vzero].16b, v1.16b, #12\n" /* 1357-0135 */ \
"ext v8.16b, %[vzero].16b, v3.16b, #12\n" /* 1357-0135 */ \
"sub %[dr0], %[dr0], #4\n" /* sub */ \
"sub %[dr1], %[dr1], #4\n" /* sub */ \
"fadd v4.4s, v0.4s, v6.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"fadd v5.4s, v2.4s, v8.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
"ld2 {v0.4s, v1.4s}, [%[dr0]], #32\n" /* load q0-q1, dr0, 0-7*/ \
"ld2 {v2.4s, v3.4s}, [%[dr1]], #32\n" /* load q2-q3, dr1, 0-7*/ \
"fadd v6.4s, v4.4s, v5.4s\n" /* add reduce */ \
"subs %w[cnt_num], %w[cnt_num], #1\n" /* subs cnt_num, #1*/ \
"fmul v4.4s, v6.4s, %[vcoef_left].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"ble 2f\n" /* bne s3_max_loop_mid */
#define P2x2S2P0_AVG \
"1: \n" /* load bias to q2, q3*/ \
"fadd v4.4s, v0.4s, v1.4s\n" /* add 0, 2, 4, 6 and 1, 3, 5, 7 */ \
...@@ -228,6 +257,7 @@ void pooling_basic(const float* din,
"fmul v4.4s, v6.4s, %[vcoef].4s\n" /* mul coef */ \
"st1 {v4.4s}, [%[dr_out]], #16\n" /* store 4 out, dr_out */ \
"bne 1b\n" /* bne s3_max_loop_mid */
#define P3x3S1_INIT \
"ldr q0, [%[dr0]], #16\n" /* load q0, dr0, 0-3*/ \
"ldr q1, [%[dr1]], #16\n" /* load q1, dr1, 0-3*/ \
...@@ -518,16 +548,45 @@ void pooling_basic(const float* din,
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n"
#define P2x2S2P1_MAX \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vmax.f32 q8, q0, q4 @ max \n" \
"vmax.f32 q9, q2, q5 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vmax.f32 q5, q9, q8 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vst1.f32 {d10-d11}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne \n"
#define P2x2S2P0_MAX \
"1: @ main loop\n" \
"vmax.f32 q4, q0, q1 @ max \n" \
"vmax.f32 q5, q2, q3 @ max \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
-  "vmax.f32 q6, q4, q5 @ max reduce\n" \
+  "vmax.f32 q8, q4, q5 @ max reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
-  "vst1.f32 {d12-d13}, [%[dr_out]]! @ store 4 out \n" \
+  "vst1.f32 {d16-d17}, [%[dr_out]]! @ store 4 out \n" \
-  "bne 1b @ bne "
+  "bne 1b @ bne \n"
#define P2x2S2P1_AVG \
"vext.32 q4, %q[vzero], q1, #3 @ 1357-0135\n" \
"vext.32 q5, %q[vzero], q3, #3 @ 1357-0135\n" \
"sub %[dr0], #4 @sub \n" \
"sub %[dr1], #4 @sub \n" \
"vadd.f32 q9, q0, q4 @ add \n" \
"vadd.f32 q8, q2, q5 @ add \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load \n" \
"vadd.f32 q5, q9, q8 @ add reduce\n" \
"subs %[cnt_num], #1 @ subs cnt_num \n" \
"vmul.f32 q4, q5, %q[vcoef_left] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"ble 2f @ bne\n"
#define P2x2S2P0_AVG \
"1: @ main loop\n" \
...@@ -535,9 +594,9 @@ void pooling_basic(const float* din,
"vadd.f32 q5, q2, q3 @ add 0, 2, 4, 6 \n" \
"vld2.f32 {d0-d3}, [%[dr0]]! @ load d0-d3 \n" \
"vld2.f32 {d4-d7}, [%[dr1]]! @ load d4-d7 \n" \
-  "vadd.f32 q6, q4, q5 @ add reduce \n" \
+  "vadd.f32 q8, q4, q5 @ add reduce \n" \
"subs %[cnt_num], #1 @ subs \n" \
-  "vmul.f32 q4, q6, %q[vcoef] @ mul coef \n" \
+  "vmul.f32 q4, q8, %q[vcoef] @ mul coef \n" \
"vst1.f32 {d8-d9}, [%[dr_out]]! @ store 4 out \n" \
"bne 1b @ bne \n"
...@@ -1037,17 +1096,17 @@ void pooling1x1s2p0_max(const float* din,
TargetFree(TARGET(kARM), write_ptr);
}
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
int pad_bottom, int pad_bottom,
int pad_right) { int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1095,7 +1154,7 @@ void pooling2x2s2_max(const float* din, ...@@ -1095,7 +1154,7 @@ void pooling2x2s2_max(const float* din,
[dr_out] "+r"(dr_out), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num) [cnt_num] "+r"(cnt_num)
: :
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif #endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
...@@ -1121,18 +1180,18 @@ void pooling2x2s2_max(const float* din, ...@@ -1121,18 +1180,18 @@ void pooling2x2s2_max(const float* din,
} }
} }
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive, bool exclusive,
int pad_bottom, int pad_bottom,
int pad_right) { int pad_right) {
int size_channel_out = wout * hout; int size_channel_out = wout * hout;
int size_channel_in = win * hin; int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout); auto data_out = static_cast<float*>(dout);
...@@ -1158,12 +1217,14 @@ void pooling2x2s2_avg(const float* din, ...@@ -1158,12 +1217,14 @@ void pooling2x2s2_avg(const float* din,
const float* data_in_channel = data_in_batch + c * size_channel_in; const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel; const float* r0 = data_in_channel;
const float* r1 = r0 + win; const float* r1 = r0 + win;
vcoef = vdupq_n_f32(0.25f);
for (int h = 0; h < hout; h++) { for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel; float* dr_out = data_out_channel;
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = zero_ptr; dr1 = zero_ptr;
vcoef = vdupq_n_f32(0.5f);
} }
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
...@@ -1184,7 +1245,7 @@ void pooling2x2s2_avg(const float* din, ...@@ -1184,7 +1245,7 @@ void pooling2x2s2_avg(const float* din,
[dr_out] "+r"(dr_out), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num) [cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef) : [vcoef] "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6"); : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8");
#endif #endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
...@@ -1194,8 +1255,14 @@ void pooling2x2s2_avg(const float* din, ...@@ -1194,8 +1255,14 @@ void pooling2x2s2_avg(const float* din,
int wstart = 0; int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem); int wend = std::min(wstart + K, rem);
float coef = 0.5f / (wend - wstart); float coef = 0.25f;
float tmp = 0.f; float tmp = 0.f;
if (wend - wstart == 1 && pad_right == 0) {
coef *= 2;
}
if (h * S + K - P > hin && pad_bottom == 0) {
coef *= 2;
}
for (int i = wstart; i < wend; i++) { for (int i = wstart; i < wend; i++) {
tmp += dr0[i] + dr1[i]; tmp += dr0[i] + dr1[i];
} }
...@@ -1212,6 +1279,235 @@ void pooling2x2s2_avg(const float* din, ...@@ -1212,6 +1279,235 @@ void pooling2x2s2_avg(const float* din,
TargetFree(TARGET(kARM), zero_ptr); TargetFree(TARGET(kARM), zero_ptr);
} }
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
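// note: despite its name, "vzero" is filled with -FLT_MAX so that the lane
// shifted in for the left padding can never win the max reduction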
float32x4_t vzero = vdupq_n_f32(std::numeric_limits<float>::lowest());
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
if (h == 0) {
dr0 = r0;
dr1 = r0;
r0 = r1;
r1 = r0 + win;
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = dr0;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_MAX P2x2S2P0_MAX "2: \n" /* end */
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vzero] "w"(vzero)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = wend == st ? 0.f : dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
}
*(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
}
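
For pad == 1 the vertical window of output row h covers input rows 2*h - 1 and 2*h, and the kernel folds the padded rows into the row pointers instead of branching inside the inner loop. A sketch of that mapping (the helper is illustrative, not part of the patch; the fully out-of-range bottom row is zero-filled and skipped in the kernel itself):

// max pooling replaces a missing row by a duplicate of the valid row (same
// maximum); average pooling replaces it by a zero row and adjusts the
// coefficient instead
static void p1_rows(int h, int hin, int win, const float* base,
                    const float* zero_row, bool is_max,
                    const float** r0, const float** r1) {
  int i0 = 2 * h - 1;
  int i1 = 2 * h;
  *r0 = (i0 < 0) ? (is_max ? base : zero_row) : base + i0 * win;
  *r1 = (i1 >= hin) ? (is_max ? *r0 : zero_row) : base + i1 * win;
}
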
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
auto data_in = static_cast<const float*>(din);
const int K = 2;
const int P = 1;
const int S = 2;
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
float32x4_t vzero = vdupq_n_f32(0.f);
memset(zero_ptr, 0, win * sizeof(float));
if (w_unroll_remian == 0) {
w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4;
}
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* data_out_channel = data_out_batch + c * size_channel_out;
const float* data_in_channel = data_in_batch + c * size_channel_in;
const float* r0 = data_in_channel;
const float* r1 = r0 + win;
for (int h = 0; h < hout; h++) {
float* dr_out = data_out_channel;
auto dr0 = r0;
auto dr1 = r1;
float coef_h = 0.5f;
if (h == 0) {
dr0 = zero_ptr;
dr1 = r0;
r0 = r1;
r1 = r0 + win;
if (exclusive) {
coef_h = 1.f;
}
} else {
r0 = r1 + win;
r1 = r0 + win;
}
if (h * S + K - P > hin) {
dr1 = zero_ptr;
if (exclusive) {
coef_h = 1.f;
}
if (h * S + K - P > hin + 1) {
memset(dr_out, 0, wout * sizeof(float));
continue;
}
}
float coef_left_most = exclusive ? coef_h : coef_h / 2;
float32x4_t vcoef = vdupq_n_f32(coef_h / 2);
float coef_left[4] = {
coef_left_most, coef_h / 2, coef_h / 2, coef_h / 2};
float32x4_t vcoef_left = vld1q_f32(coef_left);
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
#ifdef __aarch64__
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v8");
#else
asm volatile(
P2x2S2_INIT P2x2S2P1_AVG P2x2S2P0_AVG "2: \n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: [vcoef] "w"(vcoef),
[vzero] "w"(vzero),
[vcoef_left] "w"(vcoef_left)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9");
#endif
dr0 -= 8;
dr1 -= 8;
}
// deal with right pad
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, win);
int st = wstart > 0 ? wstart : 0;
float tmp = 0.f;
float coef = coef_h / 2;
if (exclusive && wend - st == 1) {
coef = coef_h;
}
for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i];
}
*(dr_out++) = tmp * coef;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
wstart += S;
}
data_out_channel += wout;
}
}
}
TargetFree(TARGET(kARM), zero_ptr);
}
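
The averaging factor in pooling2x2s2p1_avg is split into a vertical part coef_h (0.5 normally, 1.0 when one of the two rows is padding and exclusive is set) and a horizontal part folded into vcoef / vcoef_left, so the left-most, left-padded output is multiplied by coef_h instead of coef_h / 2. Put differently, assuming "exclusive" means dividing by the valid elements only:

// per-output coefficient = 1 / (number of elements counted for that output)
static float pool2x2_avg_coef(int valid_rows, int valid_cols, bool exclusive) {
  if (!exclusive) return 0.25f;             // always divide by K * K = 4
  return 1.0f / (valid_rows * valid_cols);  // divide by valid inputs only
}
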
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
float* dout, float* dout,
int num, int num,
...@@ -2240,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2240,6 +2536,9 @@ void pooling3x3s2p0_max(const float* din,
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win;  // > 0: the last output column needs right padding
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in; const float* data_in_batch = data_in + n * chin * size_channel_in;
...@@ -2266,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2266,6 +2565,7 @@ void pooling3x3s2p0_max(const float* din,
} }
} }
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
int cnt_remain = remain;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
...@@ -2289,46 +2589,80 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2289,46 +2589,80 @@ void pooling3x3s2p0_max(const float* din,
"v9", "v9",
"v10", "v10",
"v11"); "v11");
#else
asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr2] "+r"(dr2),
[dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
#endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
dr2 -= 8; dr2 -= 8;
} int rem = win - (w_unroll_size * 4) * S;
// deal with right pad int wstart = 0;
int rem = win - (w_unroll_size * 4) * S; for (int j = 0; j < w_unroll_remian; ++j) {
int wstart = 0; int wend = std::min(wstart + K, rem);
for (int j = 0; j < w_unroll_remian; ++j) { float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
int wend = std::min(wstart + K, rem); for (int i = wstart; i < wend; i++) {
float tmp = dr0[wstart]; // std::numeric_limits<float>::min(); tmp = std::max(tmp, dr0[i]);
for (int i = wstart; i < wend; i++) { tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr2[i]);
tmp = std::max(tmp, dr1[i]); }
tmp = std::max(tmp, dr2[i]); *(dr_out++) = tmp;
wstart += S;
} }
*(dr_out++) = tmp; #else
wstart += S; asm volatile(
P3x3S2P0_INIT P3x3S2P0_MAX
"cmp %[remain], #0 @cmp cnt_num\n"
"sub %[dr0], #32 @sub - 8\n"
"sub %[dr1], #32 @sub - 8\n"
"sub %[dr2], #32 @sub - 8\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load \n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load \n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load \n"
"vmov.f32 s3,s2 @mov \n"
"vmov.f32 s7,s6 @mov \n"
"vmov.f32 s11,s10 @mov \n"
"vmax.f32 q0, q0, q1 @max n"
"sub %[dr0], #8 @add w \n"
"sub %[dr1], #8 @add w \n"
"sub %[dr2], #8 @add w \n"
"vmax.f32 q0, q0, q2 @max \n"
"vpmax.f32 d0, d0, d1 @pmax \n"
"vpmax.f32 d0, d0, d0 @pmax \n"
"subs %[remain], #1 @subs \n"
"vst1.f32 d0[0], [%[dr_out]]! @vst \n"
"bne 2b @bne \n"
"4: @exit\n"
: [dr0] "+r"(dr0),
[dr1] "+r"(dr1),
[dr2] "+r"(dr2),
[dr_out] "+r"(dr_out),
[remain] "+r"(cnt_remain),
[cnt_num] "+r"(cnt_num)
:
: "cc",
"memory",
"q0",
"q1",
"q2",
"q3",
"q4",
"q5",
"q6",
"q7",
"q8",
"q9",
"q10",
"q11");
if (right) {
int wstart = (w_unroll_size * 4 + remain) * S;
int wend = std::min(wstart + K, win);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
}
#endif
} }
r0 = r2; r0 = r2;
...@@ -2368,6 +2702,9 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2368,6 +2702,9 @@ void pooling3x3s2p0_avg(const float* din,
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
// avoid reading past the input row: fold the last unrolled block into the scalar tail
w_unroll_size -= 1;
w_unroll_remian += 4;
auto zero_ptr = auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float))); static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float)); memset(zero_ptr, 0, win * sizeof(float));
......
...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din, ...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive, bool exclusive,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
float* dout, float* dout,
......
...@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din, ...@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din,
} }
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) { for (j = 4 * nn; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]); max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++; din_max_ptr++;
} }
...@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din, ...@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din,
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) { for (j = 4 * nn; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0]; sum_data += dout_sum_ptr[0];
din_sum_ptr++; din_sum_ptr++;
......
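
The two hunks above fix the scalar tail of the large-axis softmax: the NEON part consumes 4 * nn elements (nn = axis_size >> 2), so the tail now restarts at 4 * nn instead of deriving the start index from the loop counter j. A minimal sketch of the pattern for the max reduction, in plain C++ and with an illustrative helper name:

#include <algorithm>
#include <cfloat>

static float reduce_max_ref(const float* din, int axis_size) {
  int nn = axis_size >> 2;
  float max_data = -FLT_MAX;
  // vectorized body (NEON in the real kernel) handles 4 * nn elements
  for (int j = 0; j < nn; ++j) {
    for (int l = 0; l < 4; ++l) {
      max_data = std::max(max_data, din[4 * j + l]);
    }
  }
  // the scalar tail must resume exactly where the vector body stopped
  for (int j = 4 * nn; j < axis_size; ++j) {
    max_data = std::max(max_data, din[j]);
  }
  return max_data;
}
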
...@@ -50,13 +50,14 @@ class PoolingPE : public PE { ...@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0}; PoolingArgs args = {0};
args.mode = param_.type; args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>(); args.image.address = input->data<float16>();
args.image.channels = input->shape().channel(); args.image.channels = input->shape().channel();
args.image.height = input->shape().height(); args.image.height = input->shape().height();
args.image.width = input->shape().width(); args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0]; args.image.pad_height = paddings[0];
args.image.pad_width = param_.paddings[1]; args.image.pad_width = paddings[2];
args.image.scale_address = input->scale(); args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>(); args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale(); args.output.scale_address = output->scale();
...@@ -69,8 +70,7 @@ class PoolingPE : public PE { ...@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args; param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// && // && (k_width > 7 || k_height > 7);
// (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255); (k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE; // use_cpu_ = param_.type == AVERAGE;
...@@ -86,12 +86,13 @@ class PoolingPE : public PE { ...@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape()); float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input); float_input.copyFrom(input);
float16* data_out = output->data<float16>(); float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height(); int image_height = input->shape().height();
int image_width = input->shape().width(); int image_width = input->shape().width();
int image_channels = input->shape().channel(); int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0]; int image_pad_h = paddings[0];
int image_pad_w = param_.paddings[1]; int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1]; int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0]; int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0]; int kernel_step_h = param_.strides[0];
......
...@@ -71,6 +71,9 @@ void ConcatCompute::Run() { ...@@ -71,6 +71,9 @@ void ConcatCompute::Run() {
auto* axis_tensor_data = axis_tensor->data<int>(); auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0]; axis = axis_tensor_data[0];
} }
if (axis < 0) {
axis += inputs[0]->dims().size();
}
switch (inputs.front()->precision()) { switch (inputs.front()->precision()) {
case PRECISION(kFloat): case PRECISION(kFloat):
......
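
The added check normalizes a negative concat axis before the precision switch, matching Paddle's negative-axis convention. A one-line sketch of the rule (illustrative helper):

// e.g. axis = -1 on a 4-D input resolves to axis = 3
static int normalize_axis(int axis, int rank) {
  return axis < 0 ? axis + rank : axis;
}
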
...@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// VLOG(3) << "invoking dw conv"; // VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation) { no_dilation) {
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv"; // VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 && } else if (param.groups == 1 && kw == 3 && stride == 2 &&
...@@ -122,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -122,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) { pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run WinogradConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
...@@ -169,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -169,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) { pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run WinogradConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/conv_winograd.h" #include "lite/kernels/arm/conv_winograd.h"
#include <vector>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
...@@ -166,6 +165,189 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -166,6 +165,189 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
} }
} }
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
int threads = ctx.threads();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
//! update workspace size
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3];
int oc = o_dims[1];
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
auto pad = *(param.paddings);
int pad_h0 = pad[0];
int pad_h1 = pad[1];
int pad_w0 = pad[2];
int pad_w1 = pad[3];
int oc_pad = (oc + 7) / 8 * 8;
int ic_pad = (ic + 7) / 8 * 8;
const int new_input_size =
ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
oc_pad * oh * ow * sizeof(int32_t);
int tmp_input_thread_size_byte =
tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t);
int tmp_output_thread_size_byte =
tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t);
const int temp_size =
(tmp_input_thread_size_byte + tmp_output_thread_size_byte +
wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) *
threads;
workspace_size_ = temp_size + new_input_size;
//! update trans weights impl
// choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
// we only support 2x2 now
choose_small_ = true;
float w_fact = 0.25;
if (choose_small_) {
wino_iw = 4;
if (last_function_ == 0) {
return;
}
last_function_ = 0;
} else {
wino_iw = 6;
if (last_function_ == 1) {
return;
}
last_function_ = 1;
}
/// update scale
for (auto& ws : w_scale_) {
ws *= w_fact;
}
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<int16_t>();
if (!choose_small_) {
} else {
lite::arm::math::weight_trans_c8_4x4_int8(
weights_data_,
param.filter->template data<int8_t>(),
ic,
oc,
trans_tmp_ptr);
}
free(trans_tmp_ptr);
}
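
The workspace sized above holds the zero-padded int8 input, an int32 output buffer, and per-thread tile buffers. A restatement of the same arithmetic with the regions named (parameter names mirror the local variables; this is a sketch, not an API of the kernel):

#include <cstddef>
#include <cstdint>

static size_t wino_int8_workspace_bytes(int ic_pad, int oc_pad, int ih, int iw,
                                        int oh, int ow, int pad_h0, int pad_h1,
                                        int pad_w0, int pad_w1, int wino_iw,
                                        int tile_block, int threads) {
  // zero-padded int8 input plus an int32 output accumulator
  size_t new_input_size =
      (size_t)ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
      (size_t)oc_pad * oh * ow * sizeof(int32_t);
  // per-thread transformed-input (int16) and transformed-output (int32) tile
  // buffers, plus one wino_iw x wino_iw scratch tile for the transforms
  size_t per_thread =
      (size_t)tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t) +
      (size_t)tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t) +
      (size_t)wino_iw * wino_iw * (8 + 8 * sizeof(int32_t));
  return new_input_size + per_thread * threads;
}
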
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::PrepareForRun() {
auto& param = this->Param<param_t>();
w_scale_ = param.weight_scale;
if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
LOG(FATAL) << "weights scale size must equal to filter size";
return;
}
if (w_scale_.size() == 1) {
for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
w_scale_.push_back(w_scale_[0]);
}
}
float input_scale = param.input_scale;
for (auto& ws : w_scale_) {
ws *= input_scale;
}
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
auto ptr_in = param.bias->template data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i];
}
}
if (OutType == PRECISION(kInt8)) {
float output_scale = param.output_scale;
for (auto& ws : w_scale_) {
ws /= output_scale;
}
if (param.bias) {
auto ptr = bias_.mutable_data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] /= output_scale;
}
}
}
ReInitWhenNeeded();
}
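
PrepareForRun and ReInitWhenNeeded fold every dequantization factor into w_scale_, so the kernel applies a single per-channel multiply to the int32 accumulator. A sketch of the resulting factor follows; treating w_fact = 0.25 as the compensation for the 2x2 transform scaling is an assumption, the code above only shows the constant:

// effective per-output-channel scale
static float effective_scale(float weight_scale, float input_scale,
                             float w_fact, bool int8_out, float output_scale) {
  float s = weight_scale * input_scale * w_fact;  // fp32 output path
  return int8_out ? s / output_scale : s;         // int8 output also rescales
}
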
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
const auto* i_data = param.x->template data<int8_t>();
const auto* w_data = weights_.data<int16_t>();
const auto* b_data = param.bias ? bias_.data<float>() : nullptr;
// const float* i_data;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
// now always choose small
if (OutType == PRECISION(kInt8)) {
auto* o_data = param.output->template mutable_data<int8_t>();
lite::arm::math::conv_compute_2x2_3x3_int8<int8_t>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
} else {
auto* o_data = param.output->template mutable_data<float>();
lite::arm::math::conv_compute_2x2_3x3_int8<float>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
}
#ifdef LITE_WITH_PROFILE
kernel_func_name_ = "conv_compute_2x2_3x3_int8";
#endif
}
template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -44,7 +45,34 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> { ...@@ -44,7 +45,34 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
bool choose_small_{false}; bool choose_small_{false};
int wino_iw{8}; int wino_iw{8};
}; };
template <PrecisionType OutType>
class WinogradConv<PRECISION(kInt8), OutType>
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
WinogradConv() = default;
~WinogradConv() {}
virtual void PrepareForRun();
virtual void ReInitWhenNeeded();
virtual void Run();
#ifdef LITE_WITH_PROFILE
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
std::string kernel_func_name_{"NotImplForConvWino"};
#endif
protected:
using param_t = operators::ConvParam;
Tensor weights_;
Tensor bias_;
DDim last_shape_;
int workspace_size_{0};
int last_function_{-1};
bool choose_small_{true};
int wino_iw{4};
std::vector<float> w_scale_;
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -58,6 +58,7 @@ void PoolCompute::Run() { ...@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -107,35 +108,65 @@ void PoolCompute::Run() { ...@@ -107,35 +108,65 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2p0_max(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2p0_avg(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive, exclusive,
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -165,7 +196,7 @@ void PoolCompute::Run() { ...@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -195,7 +226,7 @@ void PoolCompute::Run() { ...@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -225,7 +256,7 @@ void PoolCompute::Run() { ...@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
......
...@@ -34,7 +34,7 @@ void SoftmaxCompute::Run() { ...@@ -34,7 +34,7 @@ void SoftmaxCompute::Run() {
int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int axis_size = x_dims[axis]; int axis_size = x_dims[axis];
if (inner_num == 1) { if (inner_num == 1) {
if (axis_size >= 4) { if (axis_size > 4) {
lite::arm::math::softmax_inner1_large_axis( lite::arm::math::softmax_inner1_large_axis(
din, dout, outer_num, axis_size); din, dout, outer_num, axis_size);
} else { } else {
......
...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode, ...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests"); DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result"); DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size"); DEFINE_int32(batch, 1, "batch size");
...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias"); ...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim; typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam; typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in, DDim compute_out_dim(const DDim& dim_in,
...@@ -165,7 +166,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -165,7 +166,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias); param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias); bias_fp32.CopyDataFrom(*param_int8_out.bias);
} }
if (flag_relu) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_relu; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_relu) {
param_fp32_out.fuse_relu = true;
param_int8_out.fuse_relu = true;
}
param_fp32_out.activation_param = act_param;
param_int8_out.activation_param = act_param;
}
std::vector<float> scale_in{1.f / 127}; std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f}; std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
std::vector<float> scale_w(weight_dim[0], 1.f / 127); std::vector<float> scale_w(weight_dim[0], 1.f / 127);
...@@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(dims, test_conv_int8(dims,
weights_dim, weights_dim,
1, 1,
......
...@@ -179,6 +179,141 @@ bool test_sgemm_c4( ...@@ -179,6 +179,141 @@ bool test_sgemm_c4(
#endif #endif
return true; return true;
} }
bool test_sgemm_c8(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
int m_round = (m + 7) / 8 * 8;
int k_round = (k + 7) / 8 * 8;
int size_a = m * k;
int size_b = n * k;
int size_a_c4 = m_round * k_round;
int size_b_c8 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c8;
Tensor tc;
Tensor tc_basic;
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c8.Resize({size_b_c8});
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
ta.set_precision(PRECISION(kInt16));
tb.set_precision(PRECISION(kInt16));
ta_c4.set_precision(PRECISION(kInt16));
tb_c8.set_precision(PRECISION(kInt16));
tc.set_precision(PRECISION(kInt32));
tc_basic.set_precision(PRECISION(kInt32));
tbias.set_precision(PRECISION(kInt32));
fill_tensor_rand(ta);
fill_tensor_rand(tb);
fill_tensor_rand(tbias);
fill_tensor_rand(tc);
auto da = ta.mutable_data<int16_t>();
auto db = tb.mutable_data<int16_t>();
auto da_c4 = ta_c4.mutable_data<int16_t>();
auto db_c8 = tb_c8.mutable_data<int16_t>();
auto dc_basic = tc_basic.mutable_data<int32_t>();
auto dbias = tbias.mutable_data<int32_t>();
// trans A, B to c8
basic_trans_mat_to_c8(da, da_c4, k, m, k, true);
basic_trans_mat_to_c8(db, db_c8, n, k, n, false);
LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c8(false,
false,
m,
n,
k,
1,
da,
k,
db,
n,
0,
dc_basic,
n,
dbias,
false,
false);
}
Timer t0;
LOG(INFO) << "basic test end";
#ifdef LITE_WITH_ARM
//! compute
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<int32_t>();
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
}
LOG(INFO) << "basic test end";
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
t0.Stop();
}
LOG(INFO) << "basic test end";
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.LapTimes().Avg()
<< " ms, min time: " << t0.LapTimes().Min()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kInt32));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c8: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c8: ";
print_tensor(tb_c8);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
...@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { ...@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
#endif #endif
LOG(INFO) << "run basic sgemm_c4 test"; LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) { for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) { for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234}) { for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false, true}) { for (auto& has_bias : {false}) {
for (auto& has_relu : {false, true}) { for (auto& has_relu : {false}) {
for (auto& th : {1, 2, 4}) { for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4( auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th); m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
...@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { ...@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
} }
} }
} }
TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false}) {
for (auto& has_relu : {false}) {
for (auto& th : {1}) {
auto flag = test_sgemm_c8(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
#endif #endif
...@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { ...@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!"; << ", relu: " << FLAGS_flag_relu << " failed!!";
} }
flag = test_sgemm_c8(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!"; << " passed!!";
......
...@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input, ...@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input,
} }
} }
} }
template <typename type>
static void basic_trans_mat_to_c8(const type* input,
type* output,
const int ldin,
const int M,
const int K,
bool pack_k) {
const int m_round = (M + 7) / 8 * 8;
int k_round = (K + 7) / 8 * 8;
if (!pack_k) {
k_round = K;
}
const int m_loop = m_round / 8;
type zero_buf[K];
memset(zero_buf, 0, K * sizeof(type));
for (int i = 0; i < m_loop; ++i) {
const type* in0 = input + i * 8 * ldin;
const type* in1 = in0 + ldin;
const type* in2 = in1 + ldin;
const type* in3 = in2 + ldin;
const type* in4 = in3 + ldin;
const type* in5 = in4 + ldin;
const type* in6 = in5 + ldin;
const type* in7 = in6 + ldin;
if (8 * (i + 1) - M > 0) {
switch (8 * (i + 1) - M) {
case 7:
in1 = zero_buf;
case 6:
in2 = zero_buf;
case 5:
in3 = zero_buf;
case 4:
in4 = zero_buf;
case 3:
in5 = zero_buf;
case 2:
in6 = zero_buf;
case 1:
in7 = zero_buf;
default:
break;
}
}
for (int j = 0; j < K; ++j) {
*output++ = *in0++;
*output++ = *in1++;
*output++ = *in2++;
*output++ = *in3++;
*output++ = *in4++;
*output++ = *in5++;
*output++ = *in6++;
*output++ = *in7++;
}
for (int j = K; j < k_round; ++j) {
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
}
}
}
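
basic_trans_mat_to_c8 interleaves blocks of eight rows column by column, zero-filling rows past M and, when pack_k is set, columns from K up to k_round. As a sketch, a source element (row, col) with col < K ends up at:

// index inside the c8-packed buffer
static inline int c8_index(int row, int col, int k_round) {
  return (row / 8) * 8 * k_round + col * 8 + (row % 8);
}

This matches the two calls in test_sgemm_c8 earlier in the diff: A is packed with pack_k = true (K padded to a multiple of 8), B with pack_k = false.
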
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a, static void basic_gemm_c4(bool trans_a,
...@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a, ...@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a,
free(tmp_c); free(tmp_c);
} }
template <typename type, typename type2>
static void basic_gemm_c8(bool trans_a,
bool trans_b,
int m,
int n,
int k,
type2 alpha,
const type* a,
int lda,
const type* b,
int ldb,
type2 beta,
type2* c,
int ldc,
const type2* bias,
bool flag_bias = false,
bool flag_relu = false) {
type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
memset(tmp_c, 0, m * ldc * sizeof(type2));
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
auto sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (trans_b) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
if (flag_relu) {
tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
tmp_c[i * ldc + j] = tmp;
}
}
}
//! trans c to c8
basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false);
free(tmp_c);
}
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm(bool trans_a, static void basic_gemm(bool trans_a,
bool trans_b, bool trans_b,
......
...@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT ...@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
fill_tensor_host_const_impl( fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size); tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break; break;
case PRECISION(kInt16):
fill_tensor_host_const_impl(
tensor.mutable_data<int16_t>(), static_cast<int16_t>(value), size);
break;
case PRECISION(kInt32): case PRECISION(kInt32):
fill_tensor_host_const_impl( fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size); tensor.mutable_data<int>(), static_cast<int>(value), size);
...@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) { ...@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
} }
} }
template <> template <>
void fill_tensor_host_rand_impl<int16_t>(int16_t* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = (rand() % 256 - 128) * 2; // NOLINT
}
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio, void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) { int64_t size) {
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
...@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT ...@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT
case PRECISION(kInt8): case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size); fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break; break;
case PRECISION(kInt16):
fill_tensor_host_rand_impl(tensor.mutable_data<int16_t>(), size);
break;
case PRECISION(kInt32): case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size); fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break; break;
......
...@@ -678,15 +678,9 @@ void resize(const uint8_t* src, ...@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) { } else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth); nv21_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) { } else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth); bgr_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) { } else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4; w_in = srcw * 4;
w_out = dstw * 4; w_out = dstw * 4;
......