Commit 73b4a27f authored by mindspore-ci-bot, committed by Gitee

!4841 Matmul_int8 arm64 neon optimize

Merge pull request !4841 from zhanyuan/dev
......@@ -18,6 +18,7 @@
#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
......@@ -89,14 +90,24 @@ int DeConvInt8CPUKernel::Init() {
void DeConvInt8CPUKernel::CheckSupportOptimize() {
matmul_func_ = nullptr;
support_optimize_ = false;
support_optimize_ = true;
#ifdef ENABLE_ARM64
/* todo */
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
dlerror();
*(reinterpret_cast<void **>(&matmul_func_)) = dlsym(optimize_op_handler, "MatMulR4Int8_optimize_handler");
auto dlopen_error = dlerror();
if (dlopen_error != nullptr) {
MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
support_optimize_ = false;
matmul_func_ = nullptr;
}
} else {
support_optimize_ = false;
matmul_func_ = nullptr;
}
#endif
support_optimize_ = true;
matmul_func_ = MatMulOptR4Int8;
}
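For context, the optimized kernel above is resolved at runtime with dlsym and the code falls back to the portable path when the symbol is missing. A minimal sketch of that pattern, assuming only the symbol name MatMulR4Int8_optimize_handler from this change; LoadOptimizedMatmul and the error handling are illustrative, not the actual kernel code:

```cpp
#include <dlfcn.h>
#include <cstdint>
#include <cstdio>

// Signature matches MATMUL_OPT_R4_FUNC as redefined in this change.
typedef void (*MatmulOptFunc)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4,
                              int deep_16, const int *input_sum, const int *bias);

// Hypothetical helper: resolve the optimized kernel from an already-opened
// shared library, returning nullptr so the caller can fall back to the C path.
MatmulOptFunc LoadOptimizedMatmul(void *handle) {
  if (handle == nullptr) {
    return nullptr;  // optimized library was never opened
  }
  dlerror();  // clear stale error state before dlsym
  auto func = reinterpret_cast<MatmulOptFunc>(dlsym(handle, "MatMulR4Int8_optimize_handler"));
  if (const char *err = dlerror()) {
    std::fprintf(stderr, "load matmul func failed: %s\n", err);
    return nullptr;
  }
  return func;
}
```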
int DeConvInt8CPUKernel::InitParam() {
......@@ -109,15 +120,10 @@ int DeConvInt8CPUKernel::InitParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_;
if (support_optimize_) {
input_trans_func_ = RowMajor2Row16x4MajorInt8;
size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
thread_stride_ = UP_DIV(oc4, thread_count_);
} else {
/*todo */
}
input_trans_func_ = RowMajor2Row16x4MajorInt8;
size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
thread_stride_ = UP_DIV(oc4, thread_count_);
return RET_OK;
}
......
......@@ -47,9 +47,29 @@ int MatmulInt8CPUKernel::ReSize() {
params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
#ifdef ENABLE_ARM64
r4_ = UP_ROUND(params_->row_, 4);
c4_ = UP_ROUND(params_->col_, 4);
d16_ = UP_ROUND(params_->deep_, 16);
a_r4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4d16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c4d16_ptr_) return RET_MEMORY_FAILED;
memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
c_r4c4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t)));
if (!c_r4c4_ptr_) return RET_MEMORY_FAILED;
memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t));
a_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!a_sums_) return RET_MEMORY_FAILED;
memset(a_sums_, 0, r4_ * sizeof(int));
b_bias_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!b_bias_) return RET_MEMORY_FAILED;
memset(b_bias_, 0, c4_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
#else
a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
......@@ -65,6 +85,9 @@ int MatmulInt8CPUKernel::ReSize() {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
#endif
auto input_tensor = in_tensors_[0];
auto params = input_tensor->GetQuantParams();
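The ARM64 branch of ReSize() pads the operands to the tile sizes the assembly kernel expects (rows to 4, columns to 4, depth to 16) and then splits the 4-wide column blocks across threads. A small worked sketch of that arithmetic, assuming the usual UP_ROUND/UP_DIV semantics (round up to a multiple, divide rounding up); the sample shapes are made up:

```cpp
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define UP_ROUND(x, y) (UP_DIV(x, y) * (y))

int main() {
  const int row = 6, col = 5, deep = 20, thread_num = 4;  // made-up shapes
  const int r4 = UP_ROUND(row, 4);     // 8:  rows padded to the 4-row tile
  const int c4 = UP_ROUND(col, 4);     // 8:  cols padded to the 4-col tile
  const int d16 = UP_ROUND(deep, 16);  // 32: depth padded to the 16-deep tile
  const int oc_blocks = UP_DIV(c4, 4);                                        // 2 column blocks of width 4
  const int thread_count = thread_num < oc_blocks ? thread_num : oc_blocks;   // MSMIN -> 2
  const int thread_stride = UP_DIV(oc_blocks, thread_count);                  // blocks per thread -> 1
  std::printf("r4=%d c4=%d d16=%d threads=%d stride=%d\n", r4, c4, d16, thread_count, thread_stride);
  return 0;
}
```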
......@@ -89,14 +112,27 @@ int MatmulInt8CPUKernel::ReSize() {
}
int MatmulInt8CPUKernel::RunImpl(int task_id) {
#ifdef ENABLE_ARM64
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_;
auto &p = quant_params_;
MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, c4_, d16_, a_sums_, b_bias_, INT_MIN, INT_MAX, p.output.zp_,
p.quant_multiplier, p.left_shift, p.right_shift);
#else
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_,
quant_params_.weight.zp_);
#endif
return RET_OK;
}
......@@ -127,6 +163,24 @@ int MatmulInt8CPUKernel::Run() {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
auto cur_c_ptr = c_ptr + i * c_stride;
#ifdef ENABLE_ARM64
if (params_->a_transpose_) {
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_);
} else {
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_);
}
if (params_->b_transpose_) {
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_);
} else {
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_);
}
auto &q = quant_params_;
RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_);
RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, b_bias_);
LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_);
#else
if (params_->a_transpose_) {
RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
} else {
......@@ -141,6 +195,7 @@ int MatmulInt8CPUKernel::Run() {
auto &q = quant_params_;
SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier,
q.left_shift, q.right_shift, q.output.zp_);
#endif
}
return RET_OK;
......
......@@ -39,6 +39,28 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
private:
void FreeTmpBuffer() {
#ifdef ENABLE_ARM64
if (a_r4d16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4d16_ptr_);
a_r4d16_ptr_ = nullptr;
}
if (b_c4d16_ptr_ != nullptr) {
ctx_->allocator->Free(b_c4d16_ptr_);
b_c4d16_ptr_ = nullptr;
}
if (c_r4c4_ptr_ != nullptr) {
ctx_->allocator->Free(c_r4c4_ptr_);
c_r4c4_ptr_ = nullptr;
}
if (a_sums_ != nullptr) {
ctx_->allocator->Free(a_sums_);
a_sums_ = nullptr;
}
if (b_bias_ != nullptr) {
ctx_->allocator->Free(b_bias_);
b_bias_ = nullptr;
}
#else
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
......@@ -51,12 +73,24 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
#endif
}
MatmulQuantArg quant_params_;
#ifdef ENABLE_ARM64
int8_t *a_r4d16_ptr_ = nullptr;
int8_t *b_c4d16_ptr_ = nullptr;
int8_t *c_r4c4_ptr_ = nullptr;
int *a_sums_ = nullptr;
int *b_bias_ = nullptr;
int r4_;
int c4_;
int d16_;
#else
int8_t *a_c8_ptr_ = nullptr;
int8_t *b_r8_ptr_ = nullptr;
int *c_r8x8_ptr_ = nullptr;
};
#endif
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MATMUL_INT8_H_
#ifdef __aarch64__
.text
.align 5
.global MatmulInt8Neon64
#ifndef __APPLE__
.type MatmulInt8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------|
// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] |
// | ... ... ... ... |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15] |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s v17.4s v18.4s v19.4s |
// |v1.b[0] ... v1.b[15] | |v20.4s v21.4s v22.4s v23.4s |
// |v2.b[0] ... v2.b[15] | |v24.4s v25.4s v26.4s v27.4s |
// |v3.b[0] ... v3.b[15] | |v28.4s v29.4s v30.4s v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
// w8: act_min
// w9: act_max
// w10: out_zp
// w11: multiplier
// w12: left_shift
// w13: right_shift
MatmulInt8Neon64:
sub sp, sp, #160
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
L1:
cmp w15, w4
beq End1
mov w16, #0 // reset a row index
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, w3
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w20, #0
beq End3
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x17], #16
ld1 {v2.16b}, [x17], #16
ld1 {v3.16b}, [x17], #16
ld1 {v4.16b}, [x18], #16
ld1 {v5.16b}, [x18], #16
ld1 {v6.16b}, [x18], #16
ld1 {v7.16b}, [x18], #16
smull v8.8h, v4.8b, v0.8b
smull v9.8h, v5.8b, v0.8b
smull v10.8h, v6.8b, v0.8b
smull v11.8h, v7.8b, v0.8b
smull v12.8h, v4.8b, v1.8b
smull v13.8h, v5.8b, v1.8b
smull v14.8h, v6.8b, v1.8b
smull v15.8h, v7.8b, v1.8b
smlal2 v8.8h, v4.16b, v0.16b
smlal2 v9.8h, v5.16b, v0.16b
smlal2 v10.8h, v6.16b, v0.16b
smlal2 v11.8h, v7.16b, v0.16b
smlal2 v12.8h, v4.16b, v1.16b
smlal2 v13.8h, v5.16b, v1.16b
smlal2 v14.8h, v6.16b, v1.16b
smlal2 v15.8h, v7.16b, v1.16b
sadalp v16.4s, v8.8h
sadalp v17.4s, v9.8h
sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h
sadalp v20.4s, v12.8h
sadalp v21.4s, v13.8h
sadalp v22.4s, v14.8h
sadalp v23.4s, v15.8h
smull v8.8h, v4.8b, v2.8b
smull v9.8h, v5.8b, v2.8b
smull v10.8h, v6.8b, v2.8b
smull v11.8h, v7.8b, v2.8b
smull v12.8h, v4.8b, v3.8b
smull v13.8h, v5.8b, v3.8b
smull v14.8h, v6.8b, v3.8b
smull v15.8h, v7.8b, v3.8b
smlal2 v8.8h, v4.16b, v2.16b
smlal2 v9.8h, v5.16b, v2.16b
smlal2 v10.8h, v6.16b, v2.16b
smlal2 v11.8h, v7.16b, v2.16b
smlal2 v12.8h, v4.16b, v3.16b
smlal2 v13.8h, v5.16b, v3.16b
smlal2 v14.8h, v6.16b, v3.16b
smlal2 v15.8h, v7.16b, v3.16b
sadalp v24.4s, v8.8h
sadalp v25.4s, v9.8h
sadalp v26.4s, v10.8h
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
subs w20, w20, #16 // depth - 16
b L3
End3:
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
addp v16.4s, v16.4s, v18.4s
addp v17.4s, v20.4s, v22.4s
addp v18.4s, v24.4s, v26.4s
addp v19.4s, v28.4s, v30.4s
// Add (Bias+Depth*Za*Zb-Za*Bsums)
ld1 {v15.4s}, [x19], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s
// Subtract (Asums*Zb)
ld1 {v14.4s}, [x22], #16
dup v20.4s, v14.s[0]
dup v21.4s, v14.s[1]
dup v22.4s, v14.s[2]
dup v23.4s, v14.s[3]
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s
// Apply left shift
dup v13.4s, w12
sqshl v16.4s, v16.4s, v13.4s
sqshl v17.4s, v17.4s, v13.4s
sqshl v18.4s, v18.4s, v13.4s
sqshl v19.4s, v19.4s, v13.4s
// Apply the fixed-point part of the multiplier.
dup v12.4s, w11
sqrdmulh v16.4s, v16.4s, v12.4s
sqrdmulh v17.4s, v17.4s, v12.4s
sqrdmulh v18.4s, v18.4s, v12.4s
sqrdmulh v19.4s, v19.4s, v12.4s
// Apply right shift
dup v11.4s, w13
and v20.16b, v11.16b, v16.16b
sshr v20.4s, v20.4s, #31
sqadd v16.4s, v16.4s, v20.4s
srshl v16.4s, v16.4s, v11.4s
and v21.16b, v11.16b, v17.16b
sshr v21.4s, v21.4s, #31
sqadd v17.4s, v17.4s, v21.4s
srshl v17.4s, v17.4s, v11.4s
and v22.16b, v11.16b, v18.16b
sshr v22.4s, v22.4s, #31
sqadd v18.4s, v18.4s, v22.4s
srshl v18.4s, v18.4s, v11.4s
and v23.16b, v11.16b, v19.16b
sshr v23.4s, v23.4s, #31
sqadd v19.4s, v19.4s, v23.4s
srshl v19.4s, v19.4s, v11.4s
// Add the destination zero point
dup v10.4s, w10
add v16.4s, v16.4s, v10.4s
add v17.4s, v17.4s, v10.4s
add v18.4s, v18.4s, v10.4s
add v19.4s, v19.4s, v10.4s
// Apply the act_min bound
dup v9.4s, w8
smax v16.4s, v16.4s, v9.4s
smax v17.4s, v17.4s, v9.4s
smax v18.4s, v18.4s, v9.4s
smax v19.4s, v19.4s, v9.4s
// Apply the act_max bound
dup v8.4s, w9
smin v16.4s, v16.4s, v8.4s
smin v17.4s, v17.4s, v8.4s
smin v18.4s, v18.4s, v8.4s
smin v19.4s, v19.4s, v8.4s
// int32 -> int16
sqxtn v13.4h, v16.4s
sqxtn2 v13.8h, v17.4s
sqxtn v14.4h, v18.4s
sqxtn2 v14.8h, v19.4s
// int16 -> int8
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h
st1 {v15.16b}, [x2], #16
add w16, w16, #4 // a row index + 4
b L2
End2:
add w15, w15, #4 // b col index + 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
b L1
End1:
sub sp, sp, #160
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ret
#endif
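After the accumulation loop, MatmulInt8Neon64 requantizes each int32 accumulator: saturating left shift, sqrdmulh fixed-point multiply, rounding right shift (srshl with the and/sshr/sqadd nudge), add the output zero point, clamp to the activation range, and narrow to int8. A scalar sketch of that sequence, not bit-exact with the NEON saturation and rounding fix-up; here right_shift is treated as a non-negative count, whereas the assembly receives it as the (negative) operand to srshl:

```cpp
#include <cstdint>

// Scalar reference for one accumulator; multiplier/left_shift/right_shift are
// the same quantized-multiplier parameters passed to MatmulInt8Neon64.
int8_t RequantizeOne(int32_t acc, int32_t multiplier, int left_shift, int right_shift,
                     int32_t out_zp, int32_t act_min, int32_t act_max) {
  int32_t x = static_cast<int32_t>(static_cast<int64_t>(acc) << left_shift);  // sqshl (saturation omitted)
  int64_t prod = static_cast<int64_t>(x) * multiplier;
  int32_t high = static_cast<int32_t>((prod + (1ll << 30)) >> 31);            // sqrdmulh: rounded high half of the doubled product
  int32_t rounding = right_shift > 0 ? (1 << (right_shift - 1)) : 0;
  int32_t shifted = (high + rounding) >> right_shift;                         // rounding right shift (srshl with negative shift)
  shifted += out_zp;                                                          // add destination zero point
  if (shifted < act_min) shifted = act_min;                                   // smax
  if (shifted > act_max) shifted = act_max;                                   // smin
  return static_cast<int8_t>(shifted);                                        // sqxtn narrowing
}
```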
#ifdef __aarch64__
.text
.align 5
.global MatMulR4Int8Neon64
#ifndef __APPLE__
.type MatMulR4Int8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------|
// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] |
// | ... ... ... ... |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15] |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s v17.4s v18.4s v19.4s |
// |v1.b[0] ... v1.b[15] | |v20.4s v21.4s v22.4s v23.4s |
// |v2.b[0] ... v2.b[15] | |v24.4s v25.4s v26.4s v27.4s |
// |v3.b[0] ... v3.b[15] | |v28.4s v29.4s v30.4s v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
//void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
// const int *input_sum, const int *bias)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
MatMulR4Int8Neon64:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
L1:
cmp w15, w4
beq End1
mov w16, #0 // reset a row index
mov x17, x0 // reload a ptr
mov x13, x6 // reload a_sums ptr
L2:
cmp w16, w3
beq End2
mov x18, x1 // reload b ptr
mov x10, x7 // reload bias ptr
mov w11, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w11, #0
beq End3
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x17], #16
ld1 {v2.16b}, [x17], #16
ld1 {v3.16b}, [x17], #16
ld1 {v4.16b}, [x18], #16
ld1 {v5.16b}, [x18], #16
ld1 {v6.16b}, [x18], #16
ld1 {v7.16b}, [x18], #16
smull v8.8h, v4.8b, v0.8b
smull v9.8h, v5.8b, v0.8b
smull v10.8h, v6.8b, v0.8b
smull v11.8h, v7.8b, v0.8b
smull v12.8h, v4.8b, v1.8b
smull v13.8h, v5.8b, v1.8b
smull v14.8h, v6.8b, v1.8b
smull v15.8h, v7.8b, v1.8b
smlal2 v8.8h, v4.16b, v0.16b
smlal2 v9.8h, v5.16b, v0.16b
smlal2 v10.8h, v6.16b, v0.16b
smlal2 v11.8h, v7.16b, v0.16b
smlal2 v12.8h, v4.16b, v1.16b
smlal2 v13.8h, v5.16b, v1.16b
smlal2 v14.8h, v6.16b, v1.16b
smlal2 v15.8h, v7.16b, v1.16b
sadalp v16.4s, v8.8h
sadalp v17.4s, v9.8h
sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h
sadalp v20.4s, v12.8h
sadalp v21.4s, v13.8h
sadalp v22.4s, v14.8h
sadalp v23.4s, v15.8h
smull v8.8h, v4.8b, v2.8b
smull v9.8h, v5.8b, v2.8b
smull v10.8h, v6.8b, v2.8b
smull v11.8h, v7.8b, v2.8b
smull v12.8h, v4.8b, v3.8b
smull v13.8h, v5.8b, v3.8b
smull v14.8h, v6.8b, v3.8b
smull v15.8h, v7.8b, v3.8b
smlal2 v8.8h, v4.16b, v2.16b
smlal2 v9.8h, v5.16b, v2.16b
smlal2 v10.8h, v6.16b, v2.16b
smlal2 v11.8h, v7.16b, v2.16b
smlal2 v12.8h, v4.16b, v3.16b
smlal2 v13.8h, v5.16b, v3.16b
smlal2 v14.8h, v6.16b, v3.16b
smlal2 v15.8h, v7.16b, v3.16b
sadalp v24.4s, v8.8h
sadalp v25.4s, v9.8h
sadalp v26.4s, v10.8h
sadalp v27.4s, v11.8h
sadalp v28.4s, v12.8h
sadalp v29.4s, v13.8h
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
subs w11, w11, #16 // depth - 16
b L3
End3:
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
addp v16.4s, v16.4s, v18.4s
addp v17.4s, v20.4s, v22.4s
addp v18.4s, v24.4s, v26.4s
addp v19.4s, v28.4s, v30.4s
// Add (Bias+Depth*Za*Zb-Za*Bsums)
ld1 {v15.4s}, [x10], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s
// Subtract (Asums*Zb)
ld1 {v14.4s}, [x13], #16
dup v20.4s, v14.s[0]
dup v21.4s, v14.s[1]
dup v22.4s, v14.s[2]
dup v23.4s, v14.s[3]
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
add w16, w16, #4 // a row index + 4
b L2
End2:
add w15, w15, #4 // b col index + 4
add x1, x1, x12 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
b L1
End1:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif
#ifdef __aarch64__
.text
.align 5
.global MatMulOptR4Int8Neon64
#ifndef __APPLE__
.type MatMulOptR4Int8Neon64, %function
#endif
//
// int8 RM 16x4 block
// /-----------------------------------------|
// |v4.b[0] v5.b[0] v6.b[0] v7.b[0] |
// | ... ... ... ... |
// |v4.b[15] v5.b[15] v6.b[15] v7.b[15] |
// \-----------------------------------------/
// int8 LM 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s v17.4s v18.4s v19.4s |
// |v1.b[0] ... v1.b[15] | |v20.4s v21.4s v22.4s v23.4s |
// |v2.b[0] ... v2.b[15] | |v24.4s v25.4s v26.4s v27.4s |
// |v3.b[0] ... v3.b[15] | |v28.4s v29.4s v30.4s v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
//void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
// const int *input_sum, const int *bias)
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
// x2: out ptr
// w3: row4
// w4: col4
// w5: deep16
// x6: a_sums
// x7: bias
MatMulOptR4Int8Neon64:
sub sp, sp, #128
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w12, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
L1:
cmp w15, w4
beq End1
mov w16, #0 // reset a row index
mov x17, x0 // reload a ptr
mov x13, x6 // reload a_sums ptr
L2:
cmp w16, w3
beq End2
mov x18, x1 // reload b ptr
mov x10, x7 // reload bias ptr
mov w11, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
dup v18.4s, wzr
dup v19.4s, wzr
dup v20.4s, wzr
dup v21.4s, wzr
dup v22.4s, wzr
dup v23.4s, wzr
dup v24.4s, wzr
dup v25.4s, wzr
dup v26.4s, wzr
dup v27.4s, wzr
dup v28.4s, wzr
dup v29.4s, wzr
dup v30.4s, wzr
dup v31.4s, wzr
L3:
cmp w11, #0
beq End3
ld1 {v0.16b}, [x17], #16
ld1 {v1.16b}, [x17], #16
ld1 {v2.16b}, [x17], #16
ld1 {v3.16b}, [x17], #16
ld1 {v4.16b}, [x18], #16
ld1 {v5.16b}, [x18], #16
ld1 {v6.16b}, [x18], #16
ld1 {v7.16b}, [x18], #16
sdot v16.4s, v4.16b, v0.16b
sdot v17.4s, v5.16b, v0.16b
sdot v18.4s, v6.16b, v0.16b
sdot v19.4s, v7.16b, v0.16b
sdot v20.4s, v4.16b, v1.16b
sdot v21.4s, v5.16b, v1.16b
sdot v22.4s, v6.16b, v1.16b
sdot v23.4s, v7.16b, v1.16b
sdot v24.4s, v4.16b, v2.16b
sdot v25.4s, v5.16b, v2.16b
sdot v26.4s, v6.16b, v2.16b
sdot v27.4s, v7.16b, v2.16b
sdot v28.4s, v4.16b, v3.16b
sdot v29.4s, v5.16b, v3.16b
sdot v30.4s, v6.16b, v3.16b
sdot v31.4s, v7.16b, v3.16b
subs w11, w11, #16 // depth - 16
b L3
End3:
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
addp v20.4s, v20.4s, v21.4s
addp v22.4s, v22.4s, v23.4s
addp v24.4s, v24.4s, v25.4s
addp v26.4s, v26.4s, v27.4s
addp v28.4s, v28.4s, v29.4s
addp v30.4s, v30.4s, v31.4s
addp v16.4s, v16.4s, v18.4s
addp v17.4s, v20.4s, v22.4s
addp v18.4s, v24.4s, v26.4s
addp v19.4s, v28.4s, v30.4s
// Add (Bias+Depth*Za*Zb-Za*Bsums)
ld1 {v15.4s}, [x10], #16
add v16.4s, v16.4s, v15.4s
add v17.4s, v17.4s, v15.4s
add v18.4s, v18.4s, v15.4s
add v19.4s, v19.4s, v15.4s
// Subtract (Asums*Zb)
ld1 {v14.4s}, [x13], #16
dup v20.4s, v14.s[0]
dup v21.4s, v14.s[1]
dup v22.4s, v14.s[2]
dup v23.4s, v14.s[3]
sub v16.4s, v16.4s, v20.4s
sub v17.4s, v17.4s, v21.4s
sub v18.4s, v18.4s, v22.4s
sub v19.4s, v19.4s, v23.4s
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
add w16, w16, #4 // a row index + 4
b L2
End2:
add w15, w15, #4 // b col index + 4
add x1, x1, x12 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
b L1
End1:
sub sp, sp, #128
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ret
#endif
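Unlike the two kernels above, MatMulOptR4Int8Neon64 relies on the ARMv8.2-A dot-product extension: each sdot lane accumulates a 4-element int8 dot product directly into a 32-bit accumulator, replacing the smull/smlal2/sadalp chain. A scalar sketch of what a single sdot instruction computes, purely for illustration:

```cpp
#include <cstdint>

// Behavioural model of "sdot vd.4s, vn.16b, vm.16b": every 32-bit lane of the
// destination accumulates the dot product of the matching four int8 lanes.
void SdotReference(int32_t acc[4], const int8_t n[16], const int8_t m[16]) {
  for (int lane = 0; lane < 4; ++lane) {
    for (int k = 0; k < 4; ++k) {
      acc[lane] += static_cast<int32_t>(n[lane * 4 + k]) * static_cast<int32_t>(m[lane * 4 + k]);
    }
  }
}
```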
......@@ -269,7 +269,7 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
int stride, bool write_nhwc) {
#ifdef __aarch64__
#ifdef ENABLE_ARM64
MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
#else
MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
......
......@@ -31,7 +31,7 @@ void MatMul(const float *a, const float *b, float *c, const float *bias, ActType
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
#ifdef __aarch64__
#ifdef ENABLE_ARM64
void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, bool write_nhwc);
#endif
......
......@@ -197,7 +197,7 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32
size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param,
MATMUL_OPT_R4_FUNC matmul_func) {
if (matmul_func != NULL) {
matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep);
matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum);
} else {
/* todo normal int8 deconv */
}
......
......@@ -74,8 +74,8 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
}
}
void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
size_t row_4, size_t col_4, size_t deep_16) {
void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias) {
/* row4x16-major * row16x4-major => row4x4-major */
for (int r = 0; r < row_4; r++) {
for (int c = 0; c < col_4; c++) {
......@@ -96,3 +96,61 @@ void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32
}
return;
}
#ifdef ENABLE_ARM64
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
int stride_n = r / 4 * (col_16 / 16) + c / 16;
int src_idx = r * col + c;
dst[stride * stride_n + r % 4 * 16 + c % 16] = src[src_idx];
}
}
}
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16) {
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
int stride_n = c / 4 * (row_16 / 16) + r / 16;
int src_idx = r * col + c;
dst[stride * stride_n + c % 4 * 16 + r % 16] = src[src_idx];
}
}
}
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
for (int r = 0; r < row; ++r) {
for (int c = 0; c < col; ++c) {
int src_idx = r * col + c;
dst[r] += a[src_idx];
}
dst[r] *= b_zp;
}
}
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
for (int c = 0; c < col; ++c) {
for (int r = 0; r < row; ++r) {
int src_idx = r * col + c;
dst[c] += b[src_idx];
}
dst[c] = row * a_zp * b_zp - a_zp * dst[c];
if (bias) {
dst[c] += bias[c];
}
}
}
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow) {
int stride = sizeof(int8_t) * 4 * 4;
for (int r = 0; r < row; ++r) {
for (int c = 0; c < cow; ++c) {
int stride_n = c / 4 * (row4 / 4) + r / 4;
int dst_idx = r * cow + c;
dst[dst_idx] = src[stride * stride_n + r % 4 * 4 + c % 4];
}
}
}
#endif
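Taken together with the Run() hunk above, these helpers chain as follows for one non-transposed batch. This is a hedged sketch only: buffer names mirror the kernel members, the include path for the declarations is assumed, and the quantization parameters passed to MatmulInt8Neon64 are placeholders rather than real values.

```cpp
#include <climits>
#include <cstdint>
#include "nnacl/int8/matmul_int8.h"  // assumed location of the declarations in this change

// Pack, compute the zero-point corrections, run the NEON kernel, unpack.
// r4/c4/d16 are row, col and deep rounded up to 4/4/16; all destination
// buffers are assumed pre-allocated and zeroed as in ReSize().
void MatmulInt8Arm64Sketch(int8_t *a, int8_t *b, int8_t *c, int row, int col, int deep,
                           int r4, int c4, int d16, int a_zp, int b_zp,
                           int8_t *a_r4d16, int8_t *b_c4d16, int8_t *c_r4c4,
                           int *a_sums, int *b_bias) {
  RowMajor2Row4x16Major(a, row, deep, a_r4d16, d16);       // A: row-major -> 4x16 tiles
  RowMajor2Col16x4Major(b, deep, col, b_c4d16, d16);       // B: row-major -> 16x4 tiles
  RowMajor2Asums(a, row, deep, b_zp, a_sums);              // per-row sums of A, scaled by b_zp
  RowMajor2Bbias(b, deep, col, a_zp, b_zp, NULL, b_bias);  // deep*a_zp*b_zp - a_zp*col_sums (+ bias)
  // out_zp/multiplier/left_shift/right_shift below are placeholders, not real quant params
  MatmulInt8Neon64(a_r4d16, b_c4d16, c_r4c4, r4, c4, d16, a_sums, b_bias,
                   INT_MIN, INT_MAX, /*out_zp=*/0, /*multiplier=*/1073741824,
                   /*left_shift=*/0, /*right_shift=*/0);
  Row4x4Major2RowMajor(c_r4c4, r4, c, row, col);           // 4x4 tiles -> row-major output
}
```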
......@@ -23,14 +23,29 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp);
void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
size_t row_4, size_t col_4, size_t deep_16);
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
const int a_zp, const int b_zp);
void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
#ifdef ENABLE_ARM64
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
int right_shift);
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
#endif
#ifdef __cplusplus
}
#endif
......
......@@ -19,8 +19,8 @@
#include "nnacl/op_base.h"
typedef void (*MATMUL_OPT_R4_FUNC)(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias,
const int32_t *input_sum, size_t row_4, size_t col_4, size_t deep_16);
typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
......
......@@ -23,6 +23,10 @@ extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
size_t out_multiplier, size_t shift_before, size_t shift_after);
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
#ifdef __cplusplus
}
#endif
......@@ -35,4 +39,9 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
act_max, out_zp, out_multiplier, shift_before, shift_after);
}
void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias) {
return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias);
}
#endif
......@@ -64,7 +64,7 @@ class OptimizeModule {
optimized_op_handler_ = dlopen(OPTIMIZE_SHARED_LIBRARY_PATH, RTLD_LAZY);
#endif
if (optimized_op_handler_ == nullptr) {
printf("Open optimize shared library failed.\n");
printf("Open optimize shared library failed: %s\n", dlerror());
}
}
......
......@@ -26,9 +26,7 @@ const double dNormalizer = 0x1p54;
const int dNormalizerBias = 54;
const int iMantissaBits = 31;
void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,
int *right_shift) {
void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier, int *right_shift) {
if (quantized_multiplier == NULL || right_shift == NULL) {
return;
}
......@@ -55,10 +53,9 @@ uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { return roun
int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); }
void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini,
int *maxi) {
int32_t min = CHAR_MIN;
int32_t max = CHAR_MAX;
void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, int *maxi) {
int32_t min = INT8_MIN;
int32_t max = INT8_MAX;
int32_t quantized_zero = QuantizeToInt8(0, scale, zp);
int32_t quantized_six = QuantizeToInt8(6, scale, zp);
if (is_relu) {
......@@ -77,8 +74,8 @@ void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp,
void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
for (int i = 0; i < length; ++i) {
int q = (int)round(input_data[i] / scale + zero_point);
q = q > CHAR_MAX ? CHAR_MAX : q;
q = q < CHAR_MIN ? CHAR_MIN : q;
q = q > SCHAR_MAX ? SCHAR_MAX : q;
q = q < SCHAR_MIN ? SCHAR_MIN : q;
output_data[i] = (int8_t)q;
}
}
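The affine mapping used in Quantize() is q = round(x / scale + zero_point), clamped to the int8 range. A quick check with made-up values:

```cpp
#include <climits>
#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.5f;
  const int zero_point = 10;
  const float samples[] = {3.0f, -100.0f, 100.0f};
  for (float x : samples) {
    int q = static_cast<int>(std::round(x / scale + zero_point));
    q = q > SCHAR_MAX ? SCHAR_MAX : q;
    q = q < SCHAR_MIN ? SCHAR_MIN : q;
    std::printf("%.1f -> %d\n", x, q);  // 3.0 -> 16, -100.0 -> -128, 100.0 -> 127
  }
  return 0;
}
```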
......
......@@ -270,7 +270,7 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
7894, -51, 0, 0, -4775, -29785, 0, 0, -12597, 4088, 0, 0, -17420, 1815,
0, 0, 15796, 3101, 0, 0, -37969, -10818, 0, 0, 12714, -7827, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
MatMulOptR4Int8(tmp_output, packed_a, packed_b, weight_sum, input_sum, 12, 24, 16);
MatMulOptR4Int8(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum);
CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0);
}
......
......@@ -116,7 +116,6 @@ TEST_F(TestMatmulInt8, mmint8) {
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
fout);
CompareOutputData(fout, correct, 6, 0.3);
delete matmul_param;
delete mm;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
......