提交 cd8b664f 编写于 作者: Z zhanyuan

Optimize the post process of arm64 matmul int8

上级 25bbf5c6
......@@ -24,7 +24,7 @@
//void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
// const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
// int multiplier, int left_shift, int right_shift);
// int multiplier, int left_shift, int right_shift, int row, int col, int stride);
// x0: a(left matrix ptr)
// x1: b(right matrix ptr)
......@@ -40,13 +40,18 @@
// w11: multiplier
// w12: left_shift
// w13: right_shift
// w14: row
// w15: col
// w24: stride
MatmulInt8Neon64:
sub sp, sp, #160
sub sp, sp, #192
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
stp x25, x26, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
......@@ -54,25 +59,28 @@ MatmulInt8Neon64:
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
ldr w24, [sp, #64]
mov w15, #0 // b col index
mov w16, #0 // a row index
mov w17, #4 // sizeof(int8)*4
mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
mov w17, #1
mov x25, x2
L1:
cmp w15, w4
cmp w4, #0 // if at the end of col4
beq End1
mov w16, #0 // reset a row index
mov w16, w3 // reset a row4 counter
mov w23, w14 // reset a row counter
mov x17, x0 // reload a ptr
mov x22, x6 // reload a_sums ptr
L2:
cmp w16, w3
cmp w16, #0
beq End2
mov x18, x1 // reload b ptr
mov x19, x7 // reload bias ptr
mov x19, x7 // reload bias ptr
mov w20, w5 // reload depth
dup v16.4s, wzr
dup v17.4s, wzr
......@@ -256,21 +264,128 @@ End3:
sqxtn v15.8b, v13.8h
sqxtn2 v15.16b, v14.8h
st1 {v15.16b}, [x2], #16
add w16, w16, #4 // a row index + 4
cmp w23, #4
blt Write // if rows < 4
cmp w15, #4
blt Write // if cols < 4
st1 {v15.s}[0], [x2], x24
st1 {v15.s}[1], [x2], x24
st1 {v15.s}[2], [x2], x24
st1 {v15.s}[3], [x2], x24
b Endwrite
Write:
cmp w15, #4
beq WriteCol4
cmp w15, #3
beq WriteCol3
cmp w15, #2
beq WriteCol2
cmp w15, #1
beq WriteCol1
WriteCol4:
st1 {v15.s}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.s}[1], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.s}[2], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.s}[3], [x2], x24
b Endwrite
WriteCol3:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
st1 {v15.b}[2], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
st1 {v15.b}[6], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
st1 {v15.b}[10], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
st1 {v15.b}[14], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol2:
mov x26, x2
st1 {v15.b}[0], [x26], #1
st1 {v15.b}[1], [x26], #1
add x2, x2, x24
cmp w23, #1
beq Endwrite
mov x26, x2
st1 {v15.b}[4], [x26], #1
st1 {v15.b}[5], [x26], #1
add x2, x2, x24
cmp w23, #2
beq Endwrite
mov x26, x2
st1 {v15.b}[8], [x26], #1
st1 {v15.b}[9], [x26], #1
add x2, x2, x24
cmp w23, #3
beq Endwrite
mov x26, x2
st1 {v15.b}[12], [x26], #1
st1 {v15.b}[13], [x26], #1
add x2, x2, x24
b Endwrite
WriteCol1:
st1 {v15.b}[0], [x2], x24
cmp w23, #1
beq Endwrite
st1 {v15.b}[4], [x2], x24
cmp w23, #2
beq Endwrite
st1 {v15.b}[8], [x2], x24
cmp w23, #3
beq Endwrite
st1 {v15.b}[12], [x2], x24
b Endwrite
Endwrite:
sub w16, w16, #4 // a row4 counter - 4
sub w23, w23, #4 // a row counter - 4
b L2
End2:
add w15, w15, #4 // b col index + 4
sub w4, w4, #4 // b col4 counter - 4
sub w15, w15, #4 // b col counter - 4
add x1, x1, x21 // b ptr + stride
add x7, x7, #16 // bias ptr + stride
add x25, x25, #4 // output + stride(4 * sizeof(int8))
mov x2, x25
b L1
End1:
sub sp, sp, #160
sub sp, sp, #192
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ldp x25, x26, [sp], #16
ret
#endif
......@@ -228,19 +228,3 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh
return;
}
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp) {
/* (int32_t)row8x8-major * multiplier => (int8_t)row-major */
for (int r = 0; r < plane; r++) {
for (int c = 0; c < oc; c++) {
int c8div = c / 8, c8mod = c % 8;
int src_index = c8div * plane8 * 8 + r * 8 + c8mod;
int dst_index = r * oc + c;
int32_t value = in[src_index];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp;
value = MSMIN(CHAR_MAX, value);
value = MSMAX(CHAR_MIN, value);
out[dst_index] = (int8_t)value;
}
}
}
......@@ -117,25 +117,6 @@ void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
}
}
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp) {
/* col8-major * row8-major => row8x8-major */
for (int row = 0; row < row8; row++) {
for (int col = 0; col < col8; col++) {
int r8div = row / 8, r8mod = row % 8;
int c8div = col / 8, c8mod = col % 8;
size_t ci = c8div * row8 * 8 + row * 8 + c8mod;
int32_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r8div * deep * 8 + d * 8 + r8mod;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + ((int32_t)a[ai] - a_zp) * ((int32_t)b[bi] - b_zp);
}
c[ci] = value;
}
}
}
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias) {
/* row4x16-major * row16x4-major => row4x4-major */
......@@ -191,6 +172,36 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
return;
}
/* row4x16-major * col16x4-major => row4x4-major */
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride) {
int8_t *output = dst;
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM;
int r4mod = r % C4NUM;
int c4div = c / C4NUM;
int c4mod = c % C4NUM;
int value = 0;
for (int d = 0; d < deep16; d++) {
int d16div = d / C16NUM;
int d16mod = d % C16NUM;
size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
size_t bi = c4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value += a[ai] * b[bi];
}
value -= a_sums[r];
value += bias[c];
value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + out_zp;
value = MSMIN(INT8_MAX, value);
value = MSMAX(INT8_MIN, value);
output[c] = (int8_t)value;
}
output += stride;
}
}
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16) {
int stride = sizeof(int8_t) * 16 * 4;
for (int r = 0; r < row; ++r) {
......@@ -213,23 +224,28 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_1
}
}
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst) {
// dst: weight_zp * input_row_sums
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
for (int r = 0; r < row; ++r) {
int sum = 0;
for (int c = 0; c < col; ++c) {
int src_idx = r * col + c;
dst[r] += a[src_idx];
sum += input[src_idx];
}
dst[r] *= b_zp;
sum *= weight_zp;
dst[r] = sum;
}
}
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst) {
// dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
for (int c = 0; c < col; ++c) {
int sum = 0;
for (int r = 0; r < row; ++r) {
int src_idx = r * col + c;
dst[c] += b[src_idx];
sum += weight[src_idx];
}
dst[c] = row * a_zp * b_zp - a_zp * dst[c];
dst[c] = row * input_zp * weight_zp - input_zp * sum;
if (bias) {
dst[c] += bias[c];
}
......
......@@ -24,8 +24,6 @@
#ifdef __cplusplus
extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
const int a_zp, const int b_zp);
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
......@@ -39,15 +37,16 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
void RowMajor2Asums(int8_t *a, int row, int col, int b_zp, int *dst);
void RowMajor2Bbias(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void Row4x4Major2RowMajor(int8_t *src, int row4, int8_t *dst, int row, int cow);
void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
int stride);
#ifdef ENABLE_ARM64
// bias = bias + depth * a_zp * b_zp - a_zp * b_sums
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift,
int right_shift);
int right_shift, int row, int col, int stride);
void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
......
......@@ -39,36 +39,32 @@ int FullconnectionInt8CPUKernel::ReSize() {
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
a_c8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
b_r8_ptr_ =
reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
if (!b_r8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
r4_ = UP_ROUND(fc_param_->row_, 4);
c4_ = UP_ROUND(fc_param_->col_, 4);
d16_ = UP_ROUND(fc_param_->deep_, 16);
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, r4_ * sizeof(int));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!weight_bias_sums_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
auto weight_data = reinterpret_cast<int8_t *>(in_tensors_[1]->Data());
RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
if (!c_r8x8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
if (!bias_ptr_) {
return RET_MEMORY_FAILED;
}
memset(bias_ptr_, 0, bias_len);
RowMajor2Row4x16Major(weight_data, fc_param_->col_, fc_param_->deep_, b_c16x4_ptr_, d16_);
if (in_tensors_.size() == 3) {
auto bias_len = fc_param_->col_8_ * sizeof(int);
bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
if (!bias_ptr_) return RET_MEMORY_FAILED;
memcpy(bias_ptr_, in_tensors_[2]->Data(), bias_len);
} else {
bias_ptr_ = NULL;
}
auto input_tensor = in_tensors_[0];
......@@ -93,18 +89,32 @@ int FullconnectionInt8CPUKernel::ReSize() {
CalculateActivationRangeQuantized(fc_param_->act_type_ == ActType_Relu, fc_param_->act_type_ == ActType_Relu6,
quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
&quant_params_.out_act_max);
CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
bias_ptr_, weight_bias_sums_);
return RET_OK;
}
int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto &p = quant_params_;
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, fc_param_->col_ - task_id * thread_stride_ * C4NUM);
auto &q = quant_params_;
auto &p = fc_param_;
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * C4NUM * d16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * C4NUM;
auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
auto cur_c = output_ptr + task_id * thread_stride_ * C4NUM;
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, q.out_act_min,
q.out_act_max, q.output.zp_, q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res,
p->col_ * sizeof(int8_t));
#else
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, q.out_act_min, q.out_act_max, q.output.zp_,
q.quant_multiplier, q.left_shift, q.right_shift, p->row_, cur_oc_res, d16_, p->col_);
#endif
return RET_OK;
}
......@@ -124,13 +134,10 @@ int FullconnectionInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare failed.";
return RET_ERROR;
}
auto a_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
auto output_ptr = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
auto &p = quant_params_;
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
PostFuncInt8C8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, p.quant_multiplier, p.left_shift,
p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
return RET_OK;
}
......
......@@ -41,28 +41,36 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
private:
void FreeTmpBuffer() {
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
if (a_r4x16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (bias_ptr_ != nullptr) {
ctx_->allocator->Free(bias_ptr_);
bias_ptr_ = nullptr;
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
}
MatmulQuantArg quant_params_;
int8_t *a_c8_ptr_ = nullptr;
int8_t *b_r8_ptr_ = nullptr;
int *c_r8x8_ptr_ = nullptr;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int *bias_ptr_ = nullptr;
int r4_;
int c4_;
int d16_;
};
} // namespace mindspore::kernel
......
......@@ -48,46 +48,23 @@ int MatmulInt8CPUKernel::ReSize() {
params_->row_8_ = UP_ROUND(params_->row_, 8);
params_->col_8_ = UP_ROUND(params_->col_, 8);
#ifdef ENABLE_ARM64
r4_ = UP_ROUND(params_->row_, 4);
c4_ = UP_ROUND(params_->col_, 4);
d16_ = UP_ROUND(params_->deep_, 16);
a_r4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4d16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4d16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c4d16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c4d16_ptr_) return RET_MEMORY_FAILED;
memset(b_c4d16_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
c_r4c4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * c4_ * sizeof(int8_t)));
if (!c_r4c4_ptr_) return RET_MEMORY_FAILED;
memset(c_r4c4_ptr_, 0, r4_ * c4_ * sizeof(int8_t));
a_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!a_sums_) return RET_MEMORY_FAILED;
memset(a_sums_, 0, r4_ * sizeof(int));
b_bias_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!b_bias_) return RET_MEMORY_FAILED;
memset(b_bias_, 0, c4_ * sizeof(int));
a_r4x16_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(r4_ * d16_ * sizeof(int8_t)));
if (!a_r4x16_ptr_) return RET_MEMORY_FAILED;
memset(a_r4x16_ptr_, 0, r4_ * d16_ * sizeof(int8_t));
b_c16x4_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(c4_ * d16_ * sizeof(int8_t)));
if (!b_c16x4_ptr_) return RET_MEMORY_FAILED;
memset(b_c16x4_ptr_, 0, c4_ * d16_ * sizeof(int8_t));
input_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(r4_ * sizeof(int)));
if (!input_sums_) return RET_MEMORY_FAILED;
memset(input_sums_, 0, r4_ * sizeof(int));
weight_bias_sums_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(c4_ * sizeof(int)));
if (!weight_bias_sums_) return RET_MEMORY_FAILED;
memset(weight_bias_sums_, 0, c4_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(c4_, 4));
thread_stride_ = UP_DIV(UP_DIV(c4_, 4), thread_count_);
#else
a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(int8_t)));
if (!a_c8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(int8_t));
b_r8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(int8_t)));
if (!b_r8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(int8_t));
c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(int)));
if (!c_r8x8_ptr_) {
return RET_MEMORY_FAILED;
}
memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(int));
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
#endif
auto input_tensor = in_tensors_[0];
auto params = input_tensor->GetQuantParams();
......@@ -112,27 +89,25 @@ int MatmulInt8CPUKernel::ReSize() {
}
int MatmulInt8CPUKernel::RunImpl(int task_id) {
#ifdef ENABLE_ARM64
int cur_oc = MSMIN(thread_stride_, UP_DIV(c4_, 4) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_c4d16_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_c = c_r4c4_ptr_ + task_id * thread_stride_ * 4 * r4_;
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, params_->col_ - task_id * thread_stride_ * C4NUM);
auto cur_b = b_c16x4_ptr_ + task_id * thread_stride_ * 4 * d16_;
auto cur_bias = weight_bias_sums_ + task_id * thread_stride_ * 4;
auto cur_c = c_ptr_ + task_id * thread_stride_ * 4;
auto &p = quant_params_;
MatmulInt8Neon64(a_r4d16_ptr_, cur_b, cur_c, r4_, c4_, d16_, a_sums_, b_bias_, INT_MIN, INT_MAX, p.output.zp_,
p.quant_multiplier, p.left_shift, p.right_shift);
#ifdef ENABLE_ARM64
MatmulInt8Neon64(a_r4x16_ptr_, cur_b, cur_c, r4_, cur_oc * C4NUM, d16_, input_sums_, cur_bias, INT8_MIN, INT8_MAX,
p.output.zp_, p.quant_multiplier, p.left_shift, p.right_shift, params_->row_, cur_oc_res,
params_->col_ * sizeof(int8_t));
#else
int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
if (cur_oc <= 0) {
return RET_OK;
}
auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
MatMulInt8(a_c8_ptr_, cur_b, cur_c, params_->row_8_, cur_oc * 8, params_->deep_, quant_params_.input.zp_,
quant_params_.weight.zp_);
MatmulInt8(a_r4x16_ptr_, cur_b, cur_c, input_sums_, cur_bias, INT8_MIN, INT8_MAX, p.output.zp_, p.quant_multiplier,
p.left_shift, p.right_shift, params_->row_, cur_oc_res, d16_, params_->col_);
#endif
return RET_OK;
}
......@@ -162,43 +137,27 @@ int MatmulInt8CPUKernel::Run() {
for (int i = 0; i < params_->batch; ++i) {
auto cur_a_ptr = a_ptr + i * a_stride;
auto cur_b_ptr = b_ptr + i * b_stride;
auto cur_c_ptr = c_ptr + i * c_stride;
#ifdef ENABLE_ARM64
if (params_->a_transpose_) {
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4d16_ptr_, d16_);
RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
} else {
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4d16_ptr_, d16_);
RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
}
if (params_->b_transpose_) {
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c4d16_ptr_, d16_);
RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
} else {
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c4d16_ptr_, d16_);
RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
}
c_ptr_ = c_ptr + i * c_stride;
auto &q = quant_params_;
RowMajor2Asums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, a_sums_);
RowMajor2Bbias(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, b_bias_);
LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
Row4x4Major2RowMajor(c_r4c4_ptr_, r4_, cur_c_ptr, params_->row_, params_->col_);
#else
if (params_->a_transpose_) {
RowMajor2Row8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->deep_, params_->row_);
} else {
RowMajor2Col8MajorInt8(cur_a_ptr, a_c8_ptr_, params_->row_, params_->deep_);
}
if (params_->b_transpose_) {
RowMajor2Col8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->col_, params_->deep_);
} else {
RowMajor2Row8MajorInt8(cur_b_ptr, b_r8_ptr_, params_->deep_, params_->col_);
CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
ret = LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
return ret;
}
LiteBackendParallelLaunch(MatmulInt8Run, this, thread_count_);
auto &q = quant_params_;
SimplePostFuncInt8(c_r8x8_ptr_, cur_c_ptr, params_->col_, params_->row_, params_->row_8_, q.quant_multiplier,
q.left_shift, q.right_shift, q.output.zp_);
#endif
}
return RET_OK;
}
} // namespace mindspore::kernel
......@@ -39,57 +39,32 @@ class MatmulInt8CPUKernel : public MatmulBaseCPUKernel {
private:
void FreeTmpBuffer() {
#ifdef ENABLE_ARM64
if (a_r4d16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4d16_ptr_);
a_r4d16_ptr_ = nullptr;
if (a_r4x16_ptr_ != nullptr) {
ctx_->allocator->Free(a_r4x16_ptr_);
a_r4x16_ptr_ = nullptr;
}
if (b_c4d16_ptr_ != nullptr) {
ctx_->allocator->Free(b_c4d16_ptr_);
b_c4d16_ptr_ = nullptr;
if (b_c16x4_ptr_ != nullptr) {
ctx_->allocator->Free(b_c16x4_ptr_);
b_c16x4_ptr_ = nullptr;
}
if (c_r4c4_ptr_ != nullptr) {
ctx_->allocator->Free(c_r4c4_ptr_);
c_r4c4_ptr_ = nullptr;
if (input_sums_ != nullptr) {
ctx_->allocator->Free(input_sums_);
input_sums_ = nullptr;
}
if (a_sums_ != nullptr) {
ctx_->allocator->Free(a_sums_);
a_sums_ = nullptr;
if (weight_bias_sums_ != nullptr) {
ctx_->allocator->Free(weight_bias_sums_);
weight_bias_sums_ = nullptr;
}
if (b_bias_ != nullptr) {
ctx_->allocator->Free(b_bias_);
b_bias_ = nullptr;
}
#else
if (a_c8_ptr_ != nullptr) {
ctx_->allocator->Free(a_c8_ptr_);
a_c8_ptr_ = nullptr;
}
if (b_r8_ptr_ != nullptr) {
ctx_->allocator->Free(b_r8_ptr_);
b_r8_ptr_ = nullptr;
}
if (c_r8x8_ptr_ != nullptr) {
ctx_->allocator->Free(c_r8x8_ptr_);
c_r8x8_ptr_ = nullptr;
}
#endif
}
MatmulQuantArg quant_params_;
#ifdef ENABLE_ARM64
int8_t *a_r4d16_ptr_ = nullptr;
int8_t *b_c4d16_ptr_ = nullptr;
int8_t *c_r4c4_ptr_ = nullptr;
int *a_sums_ = nullptr;
int *b_bias_ = nullptr;
int8_t *a_r4x16_ptr_ = nullptr;
int8_t *b_c16x4_ptr_ = nullptr;
int8_t *c_ptr_ = nullptr;
int *input_sums_ = nullptr;
int *weight_bias_sums_ = nullptr;
int r4_;
int c4_;
int d16_;
#else
int8_t *a_c8_ptr_ = nullptr;
int8_t *b_r8_ptr_ = nullptr;
int *c_r8x8_ptr_ = nullptr;
#endif
}; // namespace mindspore::kernel
} // namespace mindspore::kernel
......
......@@ -134,54 +134,6 @@ TEST_F(TestDeconvInt8, PackInputTest1) {
CompareOutputData(dst, co, 8 * 32, 1);
}
TEST_F(TestDeconvInt8, MatMulTest1) {
int8_t a_row_major_10_12[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105,
68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, 57, -41, -51, 77,
1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, 57, 61, 55, -20, 3, 2,
17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, -7, -16, -47, 112, 114, -26, -98, 53,
15, -49, 26, 19, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10};
int32_t zp_a = 15;
int8_t a_col8_major[16 * 12] = {0};
int8_t b_col_major_12_18[] = {
92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99,
77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103,
-3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80,
-75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33,
-77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46,
-53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42,
43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46,
35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7,
-8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81,
40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51,
40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49};
int32_t zp_b = -20;
int8_t b_row8_major[12 * 24] = {0};
int32_t co_row_major_10_18[] = {
32005, 3597, 16595, -3458, 6627, -6663, 818, -3910, 10228, 15079, -19205, -10203, -3178, -10046,
10374, -6199, 5330, 12163, 1819, 20533, 17382, 18283, 9778, 9185, -12623, -26234, -11987, 7904,
8144, -1603, 27611, -10190, -20053, 4999, -28389, 21852, 24680, 25858, 23506, 17944, 11768, 24378,
-6102, -4675, -23460, 10434, -47579, 1986, 12018, -19418, -7248, 4938, -32613, -941, 8171, -4788,
3325, -11310, -8351, -14786, 6909, 16401, 2017, -6456, 11242, 7393, -9119, 17312, 2646, -14402,
7201, -9949, 23986, 17607, 27461, -1547, 2783, 7558, 19487, 11158, -2686, 6328, -8225, -11668,
21858, -2079, -8671, -639, -1544, 1235, 1156, 6582, 2829, -10311, -2692, 5154, 1527, 10870,
106, -8189, -24174, -1846, -15399, -3598, 14874, -5591, -619, -13667, -6053, -31103, -24499, 13008,
9143, -17982, 28437, 2176, -2114, -11631, 10779, -1032, -24690, -3112, 2125, 432, 20270, -33859,
8907, 10063, 1603, 3761, 4805, 4904, -15594, 10786, 4287, -13591, -18777, -1679, 2109, -2243,
12051, -8504, -6558, 4209, 13606, -25803, 27922, 12092, 7140, 27142, -12267, 2339, -26224, 23674,
-26579, -11398, -1823, -18976, 3641, 4415, -24878, -2045, 15937, 41465, 12601, -14513, -17619, -5728,
334, -424, 8147, -1369, 5984, 11000, 19016, 4456, -25920, 4506, 5930, 15458};
int32_t c_row8x8_major[16 * 24] = {0};
int32_t out_row_major[180] = {0};
RowMajor2Col8MajorInt8(a_row_major_10_12, a_col8_major, 10, 12);
RowMajor2Col8MajorInt8(b_col_major_12_18, b_row8_major, 18, 12);
MatMulInt8(a_col8_major, b_row8_major, c_row8x8_major, 16, 24, 12, zp_a, zp_b);
Row8x8Major2RowMajor(reinterpret_cast<float *>(c_row8x8_major), reinterpret_cast<float *>(out_row_major), 10, 18, 18);
CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
}
TEST_F(TestDeconvInt8, InputSumTest1) {
int8_t packed_a[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111,
......
......@@ -29,99 +29,128 @@ class TestFcInt8 : public mindspore::CommonTest {
TestFcInt8() {}
};
int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
float input_max = 20;
float input_min = -20;
float weight_max = 1;
float weight_min = -1;
float output_max = 20;
float output_min = -20;
struct TensorInfo {
float *data;
int *data_int;
float min;
float max;
int len;
std::vector<int> *shape;
};
double input_scale =
(input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
double weight_scale =
(weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
double output_scale =
(output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
*scale = output_scale;
*zeropoint = output_zp;
extern void QuantProcess(float *input, int len, float min, float max, float *scale, int *zero_point, int8_t *output);
extern lite::tensor::Tensor *MakeQuantTensor(int8_t *data, int len, std::vector<int> *shape, float scale, int zp);
Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
in_t->MallocData();
float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885, 5.060793, -6.1471, -1.7680453, -6.5721383,
17.87506, -5.1192183, 10.742863, 1.4536934, 19.693445, 19.45783, 5.063163, 0.5234792};
Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
in_quant_arg->zeroPoint = input_zp;
in_quant_arg->scale = input_scale;
in_t->AddQuantParam(*in_quant_arg);
inputs_->push_back(in_t);
lite::tensor::Tensor *MakeIntTensor(int *data, int len, std::vector<int> *shape) {
auto tensor =
new lite::tensor::Tensor(kNumberTypeInt32, *shape, schema::Format_NHWC, static_cast<schema::NodeType>(1));
tensor->MallocData();
auto tensor_ptr = reinterpret_cast<int *>(tensor->Data());
memcpy(tensor_ptr, data, len * sizeof(int));
return tensor;
}
Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
weight_t->MallocData();
float weight[] = {-0.24438887, 0.06738146, -0.8169129, 0.21510671, -0.012470592, -0.053063435,
0.6050155, 0.8656233, 0.12911413, -0.028635843, -0.034080597, -0.10622552,
-0.012254699, -0.01312836, 0.25241964, -0.4706142, 0.2451482, -0.9558459,
0.4481974, 0.33251503, -0.011705584, -0.1720293, -0.39410214, -0.73637343};
Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
weight_quant_arg->zeroPoint = weight_zp;
weight_quant_arg->scale = weight_scale;
weight_t->AddQuantParam(*weight_quant_arg);
inputs_->push_back(weight_t);
void FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs, std::vector<lite::tensor::Tensor *> *outputs,
TensorInfo *in, TensorInfo *weight, TensorInfo *bias, TensorInfo *out) {
float in_scale, weight_scale, out_scale;
int in_zp, weight_zp, out_zp;
int8_t *in_data = new int8_t[in->len];
int8_t *weight_data = new int8_t[weight->len];
QuantProcess(in->data, in->len, in->min, in->max, &in_scale, &in_zp, in_data);
auto in_tensor = MakeQuantTensor(in_data, in->len, in->shape, in_scale, in_zp);
inputs->push_back(in_tensor);
QuantProcess(weight->data, weight->len, weight->min, weight->max, &weight_scale, &weight_zp, weight_data);
auto weight_tensor = MakeQuantTensor(weight_data, weight->len, weight->shape, weight_scale, weight_zp);
inputs->push_back(weight_tensor);
auto bias_tensor = MakeIntTensor(bias->data_int, bias->len, bias->shape);
inputs->push_back(bias_tensor);
QuantProcess(out->data, out->len, out->min, out->max, &out_scale, &out_zp, nullptr);
auto out_tensor = MakeQuantTensor(nullptr, out->len, out->shape, out_scale, out_zp);
outputs->push_back(out_tensor);
delete[] in_data;
delete[] weight_data;
}
Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
bias_t->MallocData();
memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
inputs_->push_back(bias_t);
TEST_F(TestFcInt8, fctest1) {
float in[] = {4.259103407444801, 5.992151035772917, -9.495343223733581, 3.0509999931426215, -16.635707833991095,
-14.72005749234452, 2.8290916795754093, -15.827977973039049, -16.98208477063347, 2.8801101778935347,
-0.5905297521382735, 18.042746010536085, 3.913511213700396, 11.571264917136105, 19.084257392926148,
8.571560238377568, 17.58868010598305, 12.433311533838427, 4.548078598583526, 15.609650071521138,
6.663372887795717, 17.581323475674594, 1.453277207446778, -6.119351424589654, -16.87310296820285,
11.906066592064796, -13.290100998834653, 19.627129875430548, 16.034262583959162, 10.255738135902781,
12.134650347811792, -5.5882066903433305, 15.554050723026322, 15.288481461776783, 17.651080309797287,
-9.258779162183215, 4.218532791445092, -6.205309122668545, 1.2220458021156908, 1.6800736573947326};
TensorInfo in_params;
in_params.data = in;
in_params.len = 40;
std::vector<int> in_shape{5, 2, 2, 2};
in_params.shape = &in_shape;
in_params.min = -20;
in_params.max = 20;
Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
out_t->MallocData();
auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
output_quant_arg->zeroPoint = output_zp;
output_quant_arg->scale = output_scale;
out_t->AddQuantParam(*output_quant_arg);
outputs_->push_back(out_t);
float weight[] = {
-0.586269014312498, 0.10845796767603733, 0.8455159907124523, 0.20261291069007226, 0.7564258582027543,
0.4505005038790615, -0.607259232240795, -0.6962171798923924, 0.7967573009922135, -0.46069496925353715,
-0.2967638879316592, -0.7025557337565955, -0.5313515272071268, 0.07584168670764102, -0.6860034691410029,
0.9218806800279316, -0.07408538201953907, -0.7933652717840096, 0.6636691558029275, -0.30198695606477477,
0.790225747868754, -0.9478140254555916, 0.4537316306461665, 0.1776848732022871, -0.7492316745474277,
-0.5825825240770948, 0.5680842804542614, -0.9255552309192772, 0.20866577718844725, 0.9570928647172854,
0.18172570688854406, -0.26442830241827253, -0.24765169216720873, -0.19512285277145702, 0.1120696020054861,
0.7558578199370625, -0.15032457481135109, -0.08485585411928809, 0.6343014796699504, 0.026380085222785787,
-0.40516674259120444, -0.7407588590646037, -0.28521396461492454, 0.2555841827858194, 0.023640857478332444,
-0.6540694390119834, 0.7439705499824205, -0.7579774562590929};
TensorInfo weight_params;
weight_params.data = weight;
weight_params.len = 48;
std::vector<int> weight_shape{6, 8};
weight_params.shape = &weight_shape;
weight_params.min = -1;
weight_params.max = 1;
*correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));
int bias[6] = {0};
TensorInfo bias_params;
bias_params.data_int = bias;
bias_params.len = 6;
std::vector<int> bias_shape{6};
bias_params.shape = &bias_shape;
matmal_param->b_transpose_ = true;
matmal_param->a_transpose_ = false;
matmal_param->has_bias_ = true;
matmal_param->act_type_ = ActType_No;
return out_t->ElementsNum();
}
float correct[] = {-19.170732, -7.5019627, -13.015462, -27.760283, 4.1447954, 20.660276, 4.0412164, -33.750015,
-4.560128, 7.1035166, 27.976341, 9.75216, 14.383608, -12.87587, -24.688887, -12.185722,
3.7933283, -19.266382, 17.193876, -49.99205, -15.480089, -3.1659412, 19.470417, 13.758459,
4.0713396, 4.614437, 11.296907, -7.244551, -11.143417, -21.233654};
TensorInfo out_params;
out_params.data = correct;
out_params.len = 30;
std::vector<int> out_shape{5, 6};
out_params.shape = &out_shape;
out_params.min = -50;
out_params.max = 50;
TEST_F(TestFcInt8, fcint8) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto matmul_param = new MatMulParameter();
float *correct;
double output_scale;
int output_zp;
int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
lite::Context *ctx = new lite::Context;
auto fc_param = new MatMulParameter();
fc_param->a_transpose_ = false;
fc_param->b_transpose_ = true;
fc_param->has_bias_ = true;
fc_param->act_type_ = ActType_No;
std::vector<lite::tensor::Tensor *> inputs;
std::vector<lite::tensor::Tensor *> outputs;
FcInt8TestInit(&inputs, &outputs, &in_params, &weight_params, &bias_params, &out_params);
auto ctx = new lite::Context;
ctx->thread_num_ = 2;
kernel::FullconnectionInt8CPUKernel *fc = new kernel::FullconnectionInt8CPUKernel(
reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx, nullptr);
kernel::FullconnectionInt8CPUKernel *fc =
new kernel::FullconnectionInt8CPUKernel(reinterpret_cast<OpParameter *>(fc_param), inputs, outputs, ctx, nullptr);
fc->Init();
fc->Run();
float fout[6] = {0};
Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
fout);
CompareOutputData(fout, correct, 6, 0.2);
delete matmul_param;
float out_scale;
int out_zp;
QuantProcess(correct, out_params.len, out_params.min, out_params.max, &out_scale, &out_zp, nullptr);
float *out = new float[out_params.len];
Dequantize(reinterpret_cast<int8_t *>(outputs[0]->Data()), outputs[0]->ElementsNum(), out_scale, out_zp, out);
CompareOutputData(out, correct, 6, 0.3);
delete fc;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
free(correct);
for (auto t : inputs) delete t;
for (auto t : outputs) delete t;
delete[] out;
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部