提交 6fefe939 编写于 作者: L ling

deconv int8

上级 d541e261
...@@ -27,7 +27,10 @@ using mindspore::lite::RET_OK; ...@@ -27,7 +27,10 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DeConv2D; using mindspore::schema::PrimitiveType_DeConv2D;
namespace mindspore::kernel { namespace mindspore::kernel {
DeConvInt8CPUKernel::~DeConvInt8CPUKernel() { FreeTmpBuffer(); } DeConvInt8CPUKernel::~DeConvInt8CPUKernel() {
FreeTmpBuffer();
ConvolutionBaseCPUKernel::FreeQuantParam();
}
void DeConvInt8CPUKernel::FreeTmpBuffer() { void DeConvInt8CPUKernel::FreeTmpBuffer() {
if (weight_ptr_ != nullptr) { if (weight_ptr_ != nullptr) {
...@@ -46,20 +49,18 @@ void DeConvInt8CPUKernel::FreeTmpBuffer() { ...@@ -46,20 +49,18 @@ void DeConvInt8CPUKernel::FreeTmpBuffer() {
free(tmp_output_); free(tmp_output_);
tmp_output_ = nullptr; tmp_output_ = nullptr;
} }
ConvolutionBaseCPUKernel::FreeQuantParam(); if (input_sum_ != nullptr) {
free(input_sum_);
input_sum_ = nullptr;
}
return;
} }
int DeConvInt8CPUKernel::ReSize() { int DeConvInt8CPUKernel::ReSize() {
FreeTmpBuffer(); FreeTmpBuffer();
ConvolutionBaseCPUKernel::Init(); ConvolutionBaseCPUKernel::Init();
int error_code = ConvolutionBaseCPUKernel::SetQuantParam(); int error_code = InitParam();
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv int8 SetQuantParam error!";
return error_code;
}
error_code = InitParam();
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv int8 InitParam error!"; MS_LOG(ERROR) << "deconv int8 InitParam error!";
return error_code; return error_code;
...@@ -79,76 +80,117 @@ int DeConvInt8CPUKernel::ReSize() { ...@@ -79,76 +80,117 @@ int DeConvInt8CPUKernel::ReSize() {
return RET_OK; return RET_OK;
} }
// Kernel initialisation entry point.
// Returns RET_OK immediately while InferShapeDone() is false. Otherwise it
// selects the matmul path, loads the quantisation parameters and hands over
// to ReSize(), whose status becomes the return value.
int DeConvInt8CPUKernel::Init() {
  if (!InferShapeDone()) {
    return RET_OK;
  }

  CheckSupportOptimize();

  const int quant_status = ConvolutionBaseCPUKernel::SetQuantParam();
  if (quant_status == RET_OK) {
    return ReSize();
  }

  // Quant-param setup failed: report it and propagate the original code.
  MS_LOG(ERROR) << "deconv int8 SetQuantParam error!";
  return quant_status;
}
// Chooses the matmul implementation used by this kernel.
// The members are first reset to the "no optimised path" state; the ARM64
// section is still a placeholder, so the optimised R4 matmul is currently
// selected unconditionally at the end.
void DeConvInt8CPUKernel::CheckSupportOptimize() {
  support_optimize_ = false;
  matmul_func_ = nullptr;

#ifdef ENABLE_ARM64
  /* todo */
#endif

  matmul_func_ = MatMulOptR4Int8;
  support_optimize_ = true;
}
int DeConvInt8CPUKernel::InitParam() { int DeConvInt8CPUKernel::InitParam() {
fc_param_ = new MatMulParameter(); matmul_param_ = new MatMulParameter();
fc_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_;
fc_param_->deep_ = conv_param_->input_channel_; matmul_param_->deep_ = conv_param_->input_channel_;
fc_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_;
fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM);
fc_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_; if (support_optimize_) {
input_trans_func_ = RowMajor2Row16x4MajorInt8;
size_t oc8 = UP_DIV(conv_param_->output_channel_, C8NUM); size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
thread_count_ = MSMIN(op_parameter_->thread_num_, oc8); thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
thread_stride_ = UP_DIV(oc8, thread_count_) * C8NUM; thread_stride_ = UP_DIV(oc4, thread_count_);
} else {
/*todo */
}
return RET_OK; return RET_OK;
} }
int DeConvInt8CPUKernel::InitBiasWeight() { int DeConvInt8CPUKernel::InitBiasWeight() {
if (in_tensors_.size() == 3) { size_t size = UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(int32_t);
size_t size = UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(int32_t);
bias_data_ = malloc(size); bias_data_ = malloc(size);
if (bias_data_ == nullptr) { if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!"; MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!";
return RET_ERROR; return RET_ERROR;
} }
memset(bias_data_, 0, size); memset(bias_data_, 0, size);
if (in_tensors_.size() == 3) {
memcpy(bias_data_, in_tensors_[0]->Data(), conv_param_->output_channel_ * sizeof(int32_t)); memcpy(bias_data_, in_tensors_[0]->Data(), conv_param_->output_channel_ * sizeof(int32_t));
} else {
bias_data_ = nullptr;
} }
/* weight: ichwoc(nhwc) -> oc8 * h * w * inc * 8 */ size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(conv_param_->input_channel_, C16NUM) *
size_t size = conv_param_->kernel_w_ * conv_param_->kernel_h_ * UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_ * sizeof(int8_t);
conv_param_->input_channel_ * sizeof(int8_t);
weight_ptr_ = reinterpret_cast<int8_t *>(malloc(size)); weight_ptr_ = reinterpret_cast<int8_t *>(malloc(size));
if (weight_ptr_ == nullptr) { if (weight_ptr_ == nullptr) {
MS_LOG(ERROR) << "deconv int8 malloc weight_ptr_ error!"; MS_LOG(ERROR) << "deconv int8 malloc weight_ptr_ error!";
return RET_ERROR; return RET_ERROR;
} }
memset(weight_ptr_, 0, size); memset(weight_ptr_, static_cast<int8_t>(conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_), size);
PackNHWCToC8HWN8Int8(in_tensors_[1]->Data(), weight_ptr_, conv_param_->input_channel_, DeConvWeightTransInt8(reinterpret_cast<int8_t *>(in_tensors_[1]->Data()), weight_ptr_, conv_param_->input_channel_,
conv_param_->kernel_h_ * conv_param_->kernel_w_, conv_param_->output_channel_); conv_param_->output_channel_, conv_param_->kernel_h_ * conv_param_->kernel_w_,
support_optimize_);
size = UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_;
weight_sum_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
if (weight_sum_ == nullptr) {
MS_LOG(ERROR) << "deconv int8 malloc weight_sum_ error!";
return RET_ERROR;
}
memset(weight_sum_, 0, size * sizeof(int32_t));
DeConvPackWeightSum(weight_ptr_, weight_sum_, conv_param_->conv_quant_arg_.input_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, UP_ROUND(matmul_param_->deep_, C16NUM),
size, support_optimize_);
return RET_OK; return RET_OK;
} }
int DeConvInt8CPUKernel::InitData() { int DeConvInt8CPUKernel::InitData() {
int size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * conv_param_->input_channel_; int size =
UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C4NUM) * UP_ROUND(conv_param_->input_channel_, C16NUM);
input_ptr_ = reinterpret_cast<int8_t *>(malloc(size * sizeof(int8_t))); input_ptr_ = reinterpret_cast<int8_t *>(malloc(size * sizeof(int8_t)));
if (input_ptr_ == nullptr) { if (input_ptr_ == nullptr) {
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
memset(input_ptr_, 0, size * sizeof(int8_t)); memset(input_ptr_, static_cast<int8_t>(conv_param_->conv_quant_arg_.input_quant_args_[0].zp_), size * sizeof(int8_t));
size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C4NUM) *
UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_; UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_;
tmp_buffer_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t))); tmp_buffer_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
if (tmp_buffer_ == nullptr) { if (tmp_buffer_ == nullptr) {
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
size = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->output_h_ * conv_param_->output_w_; size = UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->output_h_ * conv_param_->output_w_;
tmp_output_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t))); tmp_output_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
if (tmp_output_ == nullptr) { if (tmp_output_ == nullptr) {
return RET_MEMORY_FAILED; return RET_MEMORY_FAILED;
} }
return RET_OK;
}
int DeConvInt8CPUKernel::Init() { size = UP_ROUND(matmul_param_->row_, C4NUM);
if (!InferShapeDone()) { input_sum_ = reinterpret_cast<int32_t *>(malloc(size * sizeof(int32_t)));
return RET_OK; if (input_sum_ == nullptr) {
return RET_MEMORY_FAILED;
} }
return ReSize();
return RET_OK;
} }
int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
...@@ -161,46 +203,26 @@ int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { ...@@ -161,46 +203,26 @@ int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
return RET_OK; return RET_OK;
} }
int DeConvInt8PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
auto deconv = reinterpret_cast<DeConvInt8CPUKernel *>(cdata);
auto error_code = deconv->DoPostFunc(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "DeConvInt8PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int DeConvInt8CPUKernel::DoDeconv(int task_id) { int DeConvInt8CPUKernel::DoDeconv(int task_id) {
int cur_oc = MSMIN(thread_stride_, UP_ROUND(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); int cur_oc = MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_);
int cur_oc_res = MSMIN(thread_stride_ * C4NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C4NUM);
if (cur_oc <= 0) { if (cur_oc <= 0) {
return RET_OK; return RET_OK;
} }
int input_plane = conv_param_->input_h_ * conv_param_->input_w_;
int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_;
DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * kernel_plane * conv_param_->input_channel_,
tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, fc_param_->row_8_,
cur_oc * kernel_plane, fc_param_->deep_, conv_param_);
return RET_OK;
}
int DeConvInt8CPUKernel::DoPostFunc(int task_id) {
int input_plane = conv_param_->input_h_ * conv_param_->input_w_; int input_plane = conv_param_->input_h_ * conv_param_->input_w_;
int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_;
int output_plane = conv_param_->output_h_ * conv_param_->output_w_; int output_plane = conv_param_->output_h_ * conv_param_->output_w_;
int cur_oc = MSMIN(thread_stride_, conv_param_->output_channel_ - task_id * thread_stride_); DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * C4NUM * kernel_plane * conv_param_->input_channel_,
if (cur_oc <= 0) { tmp_buffer_ + task_id * thread_stride_ * C4NUM * input_plane * kernel_plane, weight_sum_, input_sum_,
return RET_OK; UP_ROUND(matmul_param_->row_, C4NUM), cur_oc * C4NUM * kernel_plane,
} UP_ROUND(matmul_param_->deep_, C16NUM), conv_param_, matmul_func_);
DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * C4NUM * input_plane * kernel_plane,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_, reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM,
tmp_output_ + task_id * thread_stride_ * output_plane, output_ptr_ + task_id * thread_stride_, cur_oc, tmp_output_ + task_id * thread_stride_ * C4NUM * output_plane,
conv_param_); output_ptr_ + task_id * thread_stride_ * C4NUM, cur_oc_res, conv_param_, support_optimize_);
return RET_OK; return RET_OK;
} }
...@@ -214,20 +236,18 @@ int DeConvInt8CPUKernel::Run() { ...@@ -214,20 +236,18 @@ int DeConvInt8CPUKernel::Run() {
int8_t *src_out = reinterpret_cast<int8_t *>(out_tensors_[0]->Data()); int8_t *src_out = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
RowMajor2Col8MajorInt8(src_in + batch_index * fc_param_->row_ * conv_param_->input_channel_, input_ptr_, input_trans_func_(src_in + batch_index * matmul_param_->row_ * conv_param_->input_channel_, input_ptr_,
fc_param_->row_, fc_param_->deep_); matmul_param_->row_, matmul_param_->deep_);
output_ptr_ = src_out + batch_index * fc_param_->col_; output_ptr_ = src_out + batch_index * matmul_param_->col_;
DeConvPackInputSum(input_ptr_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
UP_ROUND(matmul_param_->row_, C4NUM), UP_ROUND(matmul_param_->deep_, C16NUM), support_optimize_);
int error_code = LiteBackendParallelLaunch(DeConvInt8Run, this, thread_count_); int error_code = LiteBackendParallelLaunch(DeConvInt8Run, this, thread_count_);
if (error_code != RET_OK) { if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv int8 run error! error_code[" << error_code << "]"; MS_LOG(ERROR) << "deconv int8 run error! error_code[" << error_code << "]";
return RET_ERROR; return RET_ERROR;
} }
error_code = LiteBackendParallelLaunch(DeConvInt8PostFuncRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "deconv int8 post run error! error_code[" << error_code << "]";
return RET_ERROR;
}
} }
return RET_OK; return RET_OK;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "include/errorcode.h" #include "include/errorcode.h"
#include "src/runtime/kernel/arm/nnacl/matmul_parameter.h" #include "src/runtime/kernel/arm/nnacl/matmul_parameter.h"
#include "src/runtime/kernel/arm/nnacl/int8/deconv.h" #include "src/runtime/kernel/arm/nnacl/int8/deconv.h"
#include "src/runtime/kernel/arm/nnacl/int8/common_func.h"
#include "src/runtime/kernel/arm/nnacl/int8/matmul_int8.h" #include "src/runtime/kernel/arm/nnacl/int8/matmul_int8.h"
#include "src/runtime/kernel/arm/base/layout_transform.h" #include "src/runtime/kernel/arm/base/layout_transform.h"
#include "src/runtime/kernel/arm/base/convolution_base.h" #include "src/runtime/kernel/arm/base/convolution_base.h"
...@@ -43,23 +44,28 @@ class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel { ...@@ -43,23 +44,28 @@ class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel {
public: public:
int DoDeconv(int task_id); int DoDeconv(int task_id);
int DoPostFunc(int task_id);
private: private:
void FreeTmpBuffer();
int InitData(); int InitData();
int InitParam(); int InitParam();
int InitBiasWeight(); int InitBiasWeight();
void CheckSupportOptimize();
private: private:
void FreeTmpBuffer();
MatMulParameter *fc_param_ = nullptr;
int8_t *weight_ptr_ = nullptr;
int8_t *input_ptr_ = nullptr; /* record c8 input*/
int32_t *tmp_buffer_ = nullptr; /* record matmul result */ int32_t *tmp_buffer_ = nullptr; /* record matmul result */
int32_t *tmp_output_ = nullptr; /* record post c8 result */ int32_t *tmp_output_ = nullptr; /* record post c8 result */
int32_t *input_sum_ = nullptr; /* record in * w_zp */
int32_t *weight_sum_ = nullptr; /* record w_v * in_zp - in_zp * w_zp */
int8_t *input_ptr_ = nullptr; /* packed input */
int8_t *weight_ptr_ = nullptr; /* packed weight */
int8_t *output_ptr_ = nullptr; int8_t *output_ptr_ = nullptr;
size_t thread_count_; size_t thread_count_ = 1;
size_t thread_stride_; size_t thread_stride_ = 0;
MATMUL_OPT_R4_FUNC matmul_func_;
MAT_TRANS_FUNC input_trans_func_;
MatMulParameter *matmul_param_ = nullptr;
bool support_optimize_ = true;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_
...@@ -129,8 +129,8 @@ int FullconnectionInt8CPUKernel::Run() { ...@@ -129,8 +129,8 @@ int FullconnectionInt8CPUKernel::Run() {
auto &p = quant_params_; auto &p = quant_params_;
RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
LiteBackendParallelLaunch(FcInt8Run, this, thread_count_); LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_, PostFuncInt8C8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, p.quant_multiplier, p.left_shift,
p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max); p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
return RET_OK; return RET_OK;
} }
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "include/context.h" #include "include/context.h"
#include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" #include "src/runtime/kernel/arm/nnacl/quantization/quantize.h"
#include "src/runtime/kernel/arm/base/fullconnection_base.h" #include "src/runtime/kernel/arm/base/fullconnection_base.h"
#include "src/runtime/kernel/arm/nnacl/int8/common_func.h"
using mindspore::lite::Context; using mindspore::lite::Context;
......
...@@ -228,27 +228,6 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh ...@@ -228,27 +228,6 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh
return; return;
} }
/*
 * Requantises an int32 accumulator matrix into a row-major int8 matrix.
 *
 * Source layout: channels are grouped in blocks of 8; block b starts at
 * b * plane8 * 8, and inside a block element (r, lane) sits at r * 8 + lane.
 * Destination layout: plain row-major, element (r, c) at r * oc + c.
 *
 * For every element: add bias[c] (when bias is not NULL), apply
 * MultiplyByQuantizedMultiplier, add zp, then clamp to [mini, maxi].
 */
void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
                  int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) {
  for (int row = 0; row < plane; row++) {
    for (int ch = 0; ch < oc; ch++) {
      const int block = ch / 8;
      const int lane = ch % 8;
      const int src_pos = block * plane8 * 8 + row * 8 + lane;
      const int dst_pos = row * oc + ch;

      int32_t acc = in[src_pos];
      if (bias != NULL) {
        acc += bias[ch];
      }
      acc = MultiplyByQuantizedMultiplier(acc, multiplier, left_shift, right_shift) + zp;
      acc = MSMIN(maxi, acc);
      acc = MSMAX(mini, acc);
      out[dst_pos] = (int8_t)acc;
    }
  }
}
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp) { int32_t left_shift, int32_t right_shift, int32_t zp) {
/* (int32_t)row8x8-major * multiplier => (int8_t)row-major */ /* (int32_t)row8x8-major * multiplier => (int8_t)row-major */
...@@ -265,4 +244,3 @@ void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane ...@@ -265,4 +244,3 @@ void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane
} }
} }
} }
...@@ -31,8 +31,6 @@ int8_t MinInt8(int8_t a, int8_t b); ...@@ -31,8 +31,6 @@ int8_t MinInt8(int8_t a, int8_t b);
int8_t MaxInt8(int8_t a, int8_t b); int8_t MaxInt8(int8_t a, int8_t b);
void ReluFp32(float *data, float *dst, int ele_num); void ReluFp32(float *data, float *dst, int ele_num);
void Relu6Fp32(float *data, float *dst, int ele_num); void Relu6Fp32(float *data, float *dst, int ele_num);
void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp); int32_t left_shift, int32_t right_shift, int32_t zp);
void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/int8/common_func.h"
/*
 * Shared requantisation routine for channel-blocked int32 accumulators.
 *
 * in              : accumulators; channels are grouped in blocks of `size`,
 *                   consecutive blocks are in_plane_stride elements apart and
 *                   element (r, lane) of a block sits at r * size + lane.
 * out             : int8 destination, element (r, c) written at r * out_oc_stride + c.
 * bias            : optional per-channel bias (may be NULL).
 * oc, plane       : number of channels / rows to convert.
 * multiplier, left_shift, right_shift, zp : forwarded to
 *                   MultiplyByQuantizedMultiplier, zp is added afterwards.
 * mini, maxi      : clamp range applied last.
 * size            : channel block width; a value of 0 makes the call a no-op.
 */
void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane,
                          size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int8_t mini, int8_t maxi,
                          int32_t left_shift, int32_t right_shift, int32_t zp, int size) {
  if (size == 0) {
    /* avoids the division by zero below */
    return;
  }
  for (int row = 0; row < plane; row++) {
    for (int ch = 0; ch < oc; ch++) {
      const int block = ch / size;
      const int lane = ch % size;
      const int src_pos = block * in_plane_stride + row * size + lane;
      const int dst_pos = row * out_oc_stride + ch;

      int32_t acc = in[src_pos];
      if (bias != NULL) {
        acc += bias[ch];
      }
      acc = MultiplyByQuantizedMultiplier(acc, multiplier, left_shift, right_shift) + zp;
      acc = MSMIN(maxi, acc);
      acc = MSMAX(mini, acc);
      out[dst_pos] = (int8_t)acc;
    }
  }
}
/*
 * ((int32_t)row8x8-major + bias) * multiplier + output_zp => clamp => (int8_t)row-major.
 * Thin wrapper over PostConvFuncCommInt8 with a channel block of C8NUM; the
 * output row stride equals oc and source blocks are UP_ROUND(plane, C8NUM) * C8NUM apart.
 */
void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier,
                    int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) {
  const int block_stride = UP_ROUND(plane, C8NUM) * C8NUM;
  PostConvFuncCommInt8(in, out, bias, oc, plane, oc, block_stride, multiplier, mini, maxi, left_shift, right_shift, zp,
                       C8NUM);
}
/*
 * ((int32_t)row4x4-major + bias) * multiplier + output_zp => clamp => (int8_t)row-major.
 * Thin wrapper over PostConvFuncCommInt8 with a channel block of C4NUM; `stride`
 * is the output row stride and source blocks are UP_ROUND(plane, C4NUM) * C4NUM apart.
 */
void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier,
                    int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) {
  const int block_stride = UP_ROUND(plane, C4NUM) * C4NUM;
  PostConvFuncCommInt8(in, out, bias, oc, plane, stride, block_stride, multiplier, mini, maxi, left_shift, right_shift,
                       zp, C4NUM);
}
...@@ -27,6 +27,11 @@ ...@@ -27,6 +27,11 @@
extern "C" { extern "C" {
#endif #endif
void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
#ifdef ENABLE_ARM #ifdef ENABLE_ARM
void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
size_t oc4, size_t offset); size_t oc4, size_t offset);
......
...@@ -16,17 +16,10 @@ ...@@ -16,17 +16,10 @@
#include "nnacl/int8/deconv.h" #include "nnacl/int8/deconv.h"
#include "nnacl/int8/matmul_int8.h" #include "nnacl/int8/matmul_int8.h"
#include "nnacl/int8/common_func.h"
int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, int DeConvPostInt8C8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param) {
MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.input_quant_args_[0].zp_,
conv_param->conv_quant_arg_.filter_quant_args_[0].zp_);
return NNACL_OK;
}
int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param) { ConvParameter *conv_param) {
/* row8x8-major(ih*iw x oc*kh*kw) -> row8x8-major(oh*ow x oc) */ /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_; size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_; size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
...@@ -63,9 +56,161 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t ...@@ -63,9 +56,161 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t
} /*ih*/ } /*ih*/
} /*oc8*/ } /*oc8*/
PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8), PostFuncInt8C8(tmp, bias, out, output_channel, output_plane, conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
return NNACL_OK;
}
/*
 * Scatter-accumulates the C4-blocked matmul result into the output plane and
 * requantises it to int8.
 *
 * src            : int32 matmul result, row4x4-major (ih*iw x oc*kh*kw).
 * bias           : per-channel bias forwarded to PostFuncInt8C4.
 * tmp            : int32 scratch receiving the accumulated output, one C4 channel
 *                  block after another, each block output_plane * C4NUM elements.
 * out            : int8 destination written by PostFuncInt8C4 with a row stride of
 *                  conv_param->output_channel_.
 * output_channel : number of channels handled by this call.
 * conv_param     : supplies input/kernel/output sizes, stride, pad, dilation and
 *                  the quantisation arguments.
 *
 * Always returns NNACL_OK.
 *
 * NOTE(review): channel blocks in `tmp` are spaced output_plane * C4NUM apart here,
 * while PostFuncInt8C4 (nnacl/int8/common_func) steps blocks by
 * UP_ROUND(plane, C4NUM) * C4NUM. Confirm the two agree when output_plane is not a
 * multiple of C4NUM.
 */
int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
                     ConvParameter *conv_param) {
  /* row4x4-major(ih*iw x oc*kh*kw)  ->  row4-major(oh*ow x oc) */
  size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
  size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
  size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
  int oc4 = UP_DIV(output_channel, C4NUM);
  int in_plane4 = UP_ROUND(input_plane, C4NUM);

  int src_iw_stride = C4NUM;
  int src_ih_stride = conv_param->input_w_ * C4NUM;
  int src_kw_stride = input_plane * C4NUM;
  int src_kh_stride = input_plane * conv_param->kernel_w_ * C4NUM;
  int dst_oh_stride = conv_param->output_w_ * C4NUM;
  int dst_ow_stride = C4NUM;
  int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C4NUM;
  int dst_kw_stride = conv_param->dilation_w_ * C4NUM;

  for (int c = 0; c < oc4; c++) {
    int32_t *dst_ptr = tmp + c * output_plane * C4NUM;
    const int32_t *src_ptr = src + c * in_plane4 * kernel_plane * C4NUM;
    /* the accumulation below uses +=, so the block has to start from zero */
    memset(dst_ptr, 0, output_plane * C4NUM * sizeof(int32_t));

    for (int ih = 0; ih < conv_param->input_h_; ih++) {
      for (int iw = 0; iw < conv_param->input_w_; iw++) {
        /* top-left output position this input pixel contributes to */
        int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
        int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;

        /* clip the kernel window so oh + kh * dilation stays inside the output */
        int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
        int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
        int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
        int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
        for (int kh = kh_start; kh < kh_end; kh++) {
          for (int kw = kw_start; kw < kw_end; kw++) {
            int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride;
            int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride;
            int32_t *tmp_dst = dst_ptr + dst_index;
            const int32_t *tmp_src = src_ptr + src_index;
#ifndef ENABLE_ARM64
            for (int i = 0; i < C4NUM; i++) {
              tmp_dst[i] += tmp_src[i];
            }
#else
            /*
             * 4-lane int32 add: tmp_dst[0..3] += tmp_src[0..3].
             * The asm reads and writes memory through its pointer inputs, so the
             * "memory" clobber is required; without it the compiler may reorder
             * or cache accesses to tmp (e.g. the memset above) around this block.
             */
            asm volatile(
              "mov x0, %[tmp_src] \n"
              "mov x1, %[tmp_dst] \n"

              "ld1 {v0.4s}, [x0] \n"
              "ld1 {v1.4s}, [x1] \n"

              "add v0.4s, v0.4s, v1.4s \n"

              "st1 {v0.4s}, [x1] \n"

              :
              : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst)
              : "x0", "x1", "v0", "v1", "memory");
#endif
          } /*kw*/
        }   /*kh*/
      }     /*iw*/
    }       /*ih*/
  }         /*oc*/

  PostFuncInt8C4(tmp, bias, out, output_channel, output_plane, conv_param->output_channel_,
                 conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
                 conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
                 conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
  return NNACL_OK;
}
/*
 * Repacks the deconvolution weight for the optimised matmul path.
 *
 * Source element (ic, hw, oc) is read from
 *   ic * output_channel * plane + hw * output_channel + oc
 * and stored at
 *   hw * ic16 * oc4 + (oc / C4NUM) * ic16 * C4NUM + (ic / C16NUM) * C16NUM * C4NUM
 *     + (oc % C4NUM) * C16NUM + (ic % C16NUM)
 * where ic16 / oc4 are the channel counts rounded up to C16NUM / C4NUM.
 * Padding positions in dst are not written. Nothing happens when
 * support_optimize_ is false.
 */
void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
                           bool support_optimize_) {
  if (!support_optimize_) {
    /* todo normal int8 deconv */
    return;
  }

  const int ic16 = UP_ROUND(input_channel, C16NUM);
  const int oc4 = UP_ROUND(output_channel, C4NUM);
  for (int in_ch = 0; in_ch < input_channel; in_ch++) {
    const int ic_block = in_ch / C16NUM;
    const int ic_lane = in_ch % C16NUM;
    for (int out_ch = 0; out_ch < output_channel; out_ch++) {
      const int oc_block = out_ch / C4NUM;
      const int oc_lane = out_ch % C4NUM;
      for (int k = 0; k < plane; k++) {
        const int from = in_ch * output_channel * plane + k * output_channel + out_ch;
        const int to =
          k * ic16 * oc4 + oc_block * ic16 * C4NUM + ic_block * C16NUM * C4NUM + oc_lane * C16NUM + ic_lane;
        dst[to] = src[from];
      }
    }
  }
}
/*
 * Pre-computes the per-column zero-point correction of the packed weight:
 *   weight_sum[c] = filter_zp * input_zp * deep16 - (sum over deep16 of weight column c) * input_zp
 * `weight` is read in the 16x4 blocked layout (column block of C4NUM, depth
 * block of 16). Nothing is written when suppport_opt is false.
 */
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
                         bool suppport_opt) {
  if (!suppport_opt) {
    /* todo normal int8 deconv */
    return;
  }

  for (int col = 0; col < col4; col++) {
    const int col_block = col / C4NUM;
    const int col_lane = col % C4NUM;
    int32_t column_total = 0;
    for (int d = 0; d < deep16; d++) {
      const int d_block = d / 16;
      const int d_lane = d % 16;
      const int pos = col_block * deep16 * C4NUM + d_block * C4NUM * C16NUM + col_lane * C16NUM + d_lane;
      column_total += weight[pos];
    }
    weight_sum[col] = filter_zp * input_zp * deep16 - column_total * input_zp;
  }
}
/*
 * Computes, for every packed input row, the row sum scaled by the filter zero point:
 *   dst[r] = (sum over col16 of src row r) * filter_zp
 * `src` is read in the blocked layout (row block of C4NUM, column block of
 * C16NUM). Nothing is written when suppport_opt is false.
 */
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt) {
  if (!suppport_opt) {
    /* todo normal int8 deconv */
    return;
  }

  for (int row = 0; row < row4; row++) {
    /* the row part of the index does not depend on the column */
    const int row_block = row / C4NUM;
    const int row_lane = row % C4NUM;
    int32_t row_total = 0;
    for (int col = 0; col < col16; col++) {
      const int col_block = col / C16NUM;
      const int col_lane = col % C16NUM;
      const int pos = row_block * C4NUM * col16 + col_block * C16NUM * C4NUM + row_lane * C16NUM + col_lane;
      row_total += src[pos];
    }
    dst[row] = row_total * filter_zp;
  }
}
/*
 * Runs the deconvolution matmul through the supplied function pointer:
 *   matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep)
 * When matmul_func is NULL nothing is computed (the generic path is still a
 * todo). conv_param is not used. Always returns NNACL_OK.
 */
int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32_t *weight_sum, int32_t *input_sum,
               size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param,
               MATMUL_OPT_R4_FUNC matmul_func) {
  if (matmul_func == NULL) {
    /* todo normal int8 deconv */
    return NNACL_OK;
  }
  matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep);
  return NNACL_OK;
}
/*
 * Post-processing dispatcher for int8 deconvolution.
 * With support_optimize set, the work is delegated to DeConvPostInt8C4 and its
 * status is returned; otherwise nothing is done (generic path is still a todo)
 * and NNACL_OK is returned.
 */
int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
                   ConvParameter *conv_param, bool support_optimize) {
  if (!support_optimize) {
    /* todo normal int8 deconv post */
    return NNACL_OK;
  }
  return DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param);
}
...@@ -22,16 +22,22 @@ ...@@ -22,16 +22,22 @@
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
#include "nnacl/conv_parameter.h" #include "nnacl/conv_parameter.h"
#include "nnacl/common_func.h" #include "nnacl/common_func.h"
#include "nnacl/int8/matmul_int8.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
bool suppport_opt);
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt);
void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
bool support_optimize_);
int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32_t *weight_sum, int32_t *input_sum,
ConvParameter *conv_param); size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param,
MATMUL_OPT_R4_FUNC matmul_func);
int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param); ConvParameter *conv_param, bool support_optimize);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "nnacl/int8/matmul_int8.h" #include "nnacl/int8/matmul_int8.h"
#include <limits.h> #include <limits.h>
#include "nnacl/quantization/fixed_point.h" #include "nnacl/quantization/fixed_point.h"
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) { for (int r = 0; r < row; r++) {
int8_t *src = src_ptr + r * col; int8_t *src = src_ptr + r * col;
...@@ -29,6 +28,23 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) ...@@ -29,6 +28,23 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
} }
} }
/*
 * Repacks a row-major int8 matrix (row x col) into the row16x4 block layout.
 *
 * Rows are grouped in tiles of C4NUM and columns in tiles of C16NUM. Inside one
 * tile the C16NUM column entries of a row are contiguous and the C4NUM rows
 * follow one another. The destination row stride is col rounded up to C16NUM.
 * Only the row * col source elements are written; padding slots of dst_ptr are
 * left as the caller initialised them.
 */
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
  const int8_t *src = (const int8_t *)src_ptr;
  int8_t *dst = (int8_t *)dst_ptr;
  const int col16 = UP_ROUND(col, C16NUM);
  for (int r = 0; r < row; r++) {
    const int8_t *src_row = src + r * col;
    /* Start of this row inside its C4NUM-row tile band. */
    int8_t *dst_row = dst + (r / C4NUM) * C4NUM * col16 + (r % C4NUM) * C16NUM;
    for (int c = 0; c < col; c++) {
      dst_row[(c / C16NUM) * C16NUM * C4NUM + (c % C16NUM)] = src_row[c];
    }
  }
}
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) { for (int r = 0; r < row; r++) {
int rd8 = r / 8; int rd8 = r / 8;
...@@ -57,3 +73,26 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co ...@@ -57,3 +73,26 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
} }
} }
} }
/*
 * Reference int8 matrix multiply on packed operands:
 *   row4x16-major (a) * row16x4-major (b) => row4x4-major (dst)
 *
 * For every (r, c) with r < row_4 and c < col_4 it stores
 *   sum_{d < deep_16} a[r, d] * b[c, d] - input_sum[r] + bias[c]
 * a and b are both addressed as tiles of C4NUM rows x C16NUM depth entries;
 * dst is addressed as blocks of C4NUM columns, each block holding row_4 rows.
 *
 * All counters and offsets are size_t so they match the size_t bounds; the
 * previous int counters were compared against unsigned limits and mixed into
 * unsigned offset arithmetic.
 */
void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
                     size_t row_4, size_t col_4, size_t deep_16) {
  for (size_t r = 0; r < row_4; r++) {
    const size_t r4div = r / C4NUM;
    const size_t r4mod = r % C4NUM;
    for (size_t c = 0; c < col_4; c++) {
      const size_t c4div = c / C4NUM;
      const size_t c4mod = c % C4NUM;
      const size_t ci = c4div * row_4 * C4NUM + r * C4NUM + c4mod;
      int32_t value = 0;
      for (size_t d = 0; d < deep_16; d++) {
        const size_t d16div = d / C16NUM;
        const size_t d16mod = d % C16NUM;
        const size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod;
        const size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
        /* int8 operands are promoted to int before the multiply. */
        value = value + a[ai] * b[bi];
      }
      value -= input_sum[r];
      value += bias[c];
      dst[ci] = value;
    }
  }
}
...@@ -18,14 +18,19 @@ ...@@ -18,14 +18,19 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
const int32_t a_zp, const int32_t b_zp); const int32_t a_zp, const int32_t b_zp);
void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum,
size_t row_4, size_t col_4, size_t deep_16);
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
...@@ -19,6 +19,11 @@ ...@@ -19,6 +19,11 @@
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
/*
 * Function-pointer type for an int8 matmul kernel that writes int32 results to dst from
 * operands a and b, with bias and input_sum arrays and row_4 / col_4 / deep_16 extents.
 * NOTE(review): presumably the extents are already rounded up to 4 / 4 / 16 - confirm with callers.
 */
typedef void (*MATMUL_OPT_R4_FUNC)(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias,
const int32_t *input_sum, size_t row_4, size_t col_4, size_t deep_16);
/*
 * Function-pointer type for a matrix layout transform over a row x col matrix.
 * NOTE(review): the parameters are named destination-first here; confirm that every function
 * bound to this type uses the same argument order.
 */
typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType; typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
typedef struct MatMulParameter { typedef struct MatMulParameter {
......
...@@ -36,15 +36,6 @@ class TestDeconvInt8 : public mindspore::CommonTest { ...@@ -36,15 +36,6 @@ class TestDeconvInt8 : public mindspore::CommonTest {
TestDeconvInt8() {} TestDeconvInt8() {}
}; };
// Quantizes `size` floats into int8: iptr[i] = clamp(round(fptr[i] / scale + zp), INT8_MIN, INT8_MAX).
//   fptr  - source values (read only)
//   iptr  - destination buffer with room for `size` elements
//   zp    - zero point added after scaling
//   scale - quantization step the input is divided by
void FloatToInt8(float *fptr, int8_t *iptr, size_t size, int32_t zp, double scale) {
  for (size_t i = 0; i < size; i++) {
    // Saturate while the value is still a double: converting an out-of-range
    // floating-point value to an integer type is undefined behavior, so the
    // clamp has to happen before the narrowing cast, not after it.
    double value = round(fptr[i] / scale + zp);
    if (value > INT8_MAX) {
      value = INT8_MAX;
    }
    if (value < INT8_MIN) {
      value = INT8_MIN;
    }
    iptr[i] = static_cast<int8_t>(value);
  }
}
TEST_F(TestDeconvInt8, PackWeight1) { TEST_F(TestDeconvInt8, PackWeight1) {
int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100, int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100,
72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90, 72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90,
...@@ -164,6 +155,125 @@ TEST_F(TestDeconvInt8, MatMulTest1) { ...@@ -164,6 +155,125 @@ TEST_F(TestDeconvInt8, MatMulTest1) {
CompareOutputData(out_row_major, co_row_major_10_18, 180, 1); CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
} }
// Walks through the optimized int8 deconv matmul pipeline one stage at a time and checks each
// stage against precomputed data: input packing, weight packing, input_sum, weight_sum and
// finally MatMulOptR4Int8 on the packed operands.
TEST_F(TestDeconvInt8, MatMulOptTest1) {
// 120 source values, packed below as a 10 x 12 row-major matrix.
int8_t a_src_ptr[] = {-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111,
88, 105, 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65,
57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68,
57, 61, 55, -20, 3, 2, 17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15,
-7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 19, 8, -57, -35, -79, 118,
29, 21, 37, -48, 83, 7, 124, 113, -5, 15, -8, 107, -65, -88, 50, -47, -80, -84,
3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10};
int32_t input_zp = 15;
int8_t b_src_ptr[] = {
92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99,
77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103,
-3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80,
-75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33,
-77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46,
-53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42,
43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46,
35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7,
-8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81,
40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51,
40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49};
int32_t filter_zp = -20;
/*
* ---------------------- pack input ------------------------- */
int8_t packed_a[12 * 16] = {0};
// Pre-fill with the input zero point so slots the packer does not write hold input_zp.
memset(packed_a, static_cast<int8_t>(input_zp), 12 * 16);
int8_t correct_packed_a[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111,
88, 105, 68, 105, -74, 13, 15, 15, 15, 15, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65,
15, 15, 15, 15, 57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 15, 15, 15, 15, 40, 103,
112, 109, -41, -68, 57, 61, 55, -20, 3, 2, 15, 15, 15, 15, 17, -16, -31, 58, -4, 67, -4, -95,
-5, -72, 81, 15, 15, 15, 15, 15, -7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 15, 15,
15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67,
55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
};
RowMajor2Row16x4MajorInt8(a_src_ptr, packed_a, 10, 12);
CompareOutputData(packed_a, correct_packed_a, 16 * 12, 0);
/*
* ---------------------- pack weight ------------------------- */
int8_t packed_b[16 * 3 * 8] = {0};
// Pre-fill with the filter zero point, mirroring the input buffer above.
memset(packed_b, static_cast<int8_t>(filter_zp), 16 * 3 * 8);
int8_t correct_packed_b[] = {
92, 101, -30, -77, 0, 21, 45, 58, 34, -2, 40, -29, -20, -20, -20, -20, 27, -99, 36, 16, 57,
26, 106, 97, -25, 39, -20, 51, -20, -20, -20, -20, 22, 77, -106, -25, -6, -45, 52, -15, 96, -119,
-5, 40, -20, -20, -20, -20, 52, -83, -103, -58, 82, -57, -42, 74, -67, 53, -89, 27, -20, -20, -20,
-20, -112, 76, -3, 4, 66, 50, 31, -96, -80, 75, 60, 66, -20, -20, -20, -20, -20, 95, -86, 80,
-112, -24, -6, 15, -27, -91, -62, 67, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -57, 59, -82, -75, -61, -82, -61, -23, 66, -44, -4, 11, -20, -20, -20, -20, -2, 97, 59,
94, 33, -114, -87, -3, 41, 77, -48, -69, -20, -20, -20, -20, 89, 8, 4, 32, -77, 61, 2, -47,
41, -62, 66, 85, -20, -20, -20, -20, 32, 40, -75, -68, -53, 46, 79, -97, -43, 22, -64, -79, -20,
-20, -20, -20, 93, -109, -50, 2, 95, -53, -5, 100, -43, -44, -69, -106, -20, -20, -20, -20, -66, -20,
-106, 40, -38, 78, -42, -54, -38, 78, 62, 55, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -20, -20, -25, 67, 55, 56, 87, -24, 43, 26, -4, -67, 17, 22, -20, -20, -20, -20, -54,
-107, 104, -103, -46, 31, -106, -46, -64, -48, -89, -23, -20, -20, -20, -20, 94, 37, -117, 11, -3, -7,
-104, 35, 31, -115, 1, 62, -20, -20, -20, -20, -97, -6, -71, -98, 81, 37, 7, 26, 7, -4, 87,
69, -20, -20, -20, -20, -119, -54, -20, -70, -47, 29, 91, 100, -8, 43, 81, -74, -20, -20, -20, -20,
-98, -20, -85, -69, 43, 38, -63, -80, 6, 81, 32, 49, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20,
-20, -20, -20, -20, -20, -20};
DeConvWeightTransInt8(b_src_ptr, packed_b, 12, 6, 3, true);
/* kernel : 12x1x3x6 nhwc */
CompareOutputData(packed_b, correct_packed_b, 16 * 3 * 8, 0);
/*
* ---------------------- calculate input_sum ------------------------- */
int32_t input_sum[12] = {0};
int32_t correct_input_sum[] = {-7100, -4780, 580, -4880, -9460, -1420, -3120, -3260, -1840, -6960, -4800, -4800};
DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true);
CompareOutputData(input_sum, correct_input_sum, 12, 0);
// Debug aid: print index, actual and expected value for every mismatching input_sum entry.
for (int i = 0; i < 12; i++) {
if (input_sum[i] != correct_input_sum[i]) {
printf("%d %d %d\n", i, input_sum[i], correct_input_sum[i]);
}
}
/*
* ---------------------- calculate weight_sum ------------------------- */
int32_t weight_sum[3 * 8] = {0};
int32_t correct_weight_sum[] = {-7395, -8265, -3090, -435, -5655, -1035, 0, 0, 1695, -4770, -6630, 300,
-765, -2835, 0, 0, -7395, 4665, -2475, -4170, -2880, -1110, 0, 0};
DeConvPackWeightSum(packed_b, weight_sum, input_zp, filter_zp, 16, 24, true);
CompareOutputData(weight_sum, correct_weight_sum, 3 * 8, 0);
/*
* ---------------------- do matmul ------------------------- */
int32_t tmp_output[12 * 24] = {0};
int32_t correct_tmp_output[] = {
-1624, -19061, 1795, -17119, 14706, 417, 7306, 1357, 9653, -44022, 19414, -36187, -2041, 6874,
-5766, 3072, 9842, 2395, 12464, -18826, -12267, -17853, 4617, -19468, -15734, -6112, 2122, 14259,
11098, -9520, 12407, -15239, 10309, -34271, 9740, -14607, -5027, 12313, -508, -10808, 0, 0,
0, 0, 0, 0, 0, 0, 1604, 14898, 0, 0, -8212, 9471, 0, 0,
-23430, 6343, 0, 0, 4020, -3740, 0, 0, -9730, 22378, 0, 0, 4702, 4740,
0, 0, -7541, 5461, 0, 0, -6633, 8356, 0, 0, -16854, 9147, 0, 0,
-4018, -11524, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17194, 28501,
13376, -9359, 21454, 22425, -21049, 6603, 23479, -658, 12866, 9739, -12173, -7558, 3862, 10238,
4110, 31945, 10069, -7376, -1948, -20322, 16439, 3260, 1712, 12743, -8132, -27744, 7633, -33916,
18755, 11300, 3686, 9222, 10103, 26102, 17, 13135, 785, -6305, 0, 0, 0, 0,
0, 0, 0, 0, -27325, 14957, 0, 0, -12191, -21866, 0, 0, -21690, -18554,
0, 0, 8737, 14529, 0, 0, -1774, -19575, 0, 0, -12761, 13286, 0, 0,
20523, 2488, 0, 0, -12782, 12688, 0, 0, -1194, -10523, 0, 0, -4044, -9671,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4671, -4173, 8675, -8560,
-1597, -4946, -20214, -6752, -11439, 5138, 11119, -17661, -6690, -17301, -5541, -4356, 22347, -11778,
2389, -22030, -5176, -242, 8786, -994, 9104, -7208, 24117, 3724, -13648, -1840, 12265, 10347,
-10325, 7184, 19374, -29001, 3979, -6704, -23278, -8124, 0, 0, 0, 0, 0, 0,
0, 0, -9132, 8560, 0, 0, 19264, -10169, 0, 0, -15133, -13678, 0, 0,
7894, -51, 0, 0, -4775, -29785, 0, 0, -12597, 4088, 0, 0, -17420, 1815,
0, 0, 15796, 3101, 0, 0, -37969, -10818, 0, 0, 12714, -7827, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
// weight_sum is handed over in the bias slot of MatMulOptR4Int8.
MatMulOptR4Int8(tmp_output, packed_a, packed_b, weight_sum, input_sum, 12, 24, 16);
CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0);
}
TEST_F(TestDeconvInt8, PostAddTest1) { TEST_F(TestDeconvInt8, PostAddTest1) {
int32_t in[] = { int32_t in[] = {
-4956, -3923, 868, -8880, -4089, -5179, -4526, -4527, -10464, 99, -5826, -2995, -4519, -4519, -10509, -2505, -4956, -3923, 868, -8880, -4089, -5179, -4526, -4527, -10464, 99, -5826, -2995, -4519, -4519, -10509, -2505,
...@@ -185,17 +295,17 @@ TEST_F(TestDeconvInt8, PostAddTest1) { ...@@ -185,17 +295,17 @@ TEST_F(TestDeconvInt8, PostAddTest1) {
int32_t right_shift; int32_t right_shift;
QuantizeRoundParameter(multiplier, &quant_multiplier, &left_shift, &right_shift); QuantizeRoundParameter(multiplier, &quant_multiplier, &left_shift, &right_shift);
int32_t zp = 83; int32_t zp = 83;
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, -128, 127); PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, -128, 127);
CompareOutputData(out, co, 50, 1); CompareOutputData(out, co, 50, 1);
int8_t co_relu[] = {0, 11, 99, 0, 8, 0, 0, 0, 112, 124, 0, 85, 0, 28, 0, 0, 0, 37, 0, 65, 0, 91, 0, 0, 0, int8_t co_relu[] = {0, 11, 99, 0, 8, 0, 0, 0, 112, 124, 0, 85, 0, 28, 0, 0, 0, 37, 0, 65, 0, 91, 0, 0, 0,
0, 67, 90, 4, 0, 0, 0, 47, 0, 114, 125, 0, 100, 0, 0, 37, 0, 31, 0, 0, 26, 0, 0, 0, 100}; 0, 67, 90, 4, 0, 0, 0, 47, 0, 114, 125, 0, 100, 0, 0, 37, 0, 31, 0, 0, 26, 0, 0, 0, 100};
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 127); PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, 0, 127);
CompareOutputData(out, co_relu, 50, 1); CompareOutputData(out, co_relu, 50, 1);
int8_t co_relu6[] = {0, 6, 6, 0, 6, 0, 0, 0, 6, 6, 0, 6, 0, 6, 0, 0, 0, 6, 0, 6, 0, 6, 0, 0, 0, int8_t co_relu6[] = {0, 6, 6, 0, 6, 0, 0, 0, 6, 6, 0, 6, 0, 6, 0, 0, 0, 6, 0, 6, 0, 6, 0, 0, 0,
0, 6, 6, 4, 0, 0, 0, 6, 0, 6, 6, 0, 6, 0, 0, 6, 0, 6, 0, 0, 6, 0, 0, 0, 6}; 0, 6, 6, 4, 0, 0, 0, 6, 0, 6, 6, 0, 6, 0, 0, 6, 0, 6, 0, 0, 6, 0, 0, 0, 6};
PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 6); PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, 0, 6);
CompareOutputData(out, co_relu6, 50, 1); CompareOutputData(out, co_relu6, 50, 1);
} }
...@@ -247,7 +357,7 @@ TEST_F(TestDeconvInt8, DeConvInt8Test1) { ...@@ -247,7 +357,7 @@ TEST_F(TestDeconvInt8, DeConvInt8Test1) {
std::vector<lite::tensor::Tensor *> outputs_; std::vector<lite::tensor::Tensor *> outputs_;
auto deconv_param = new ConvParameter(); auto deconv_param = new ConvParameter();
lite::Context *ctx = new lite::Context; lite::Context *ctx = new lite::Context;
ctx->thread_num_ = 2; ctx->thread_num_ = 1;
int8_t *correct; int8_t *correct;
int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct); int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct);
mindspore::kernel::DeConvInt8CPUKernel *deconv = new mindspore::kernel::DeConvInt8CPUKernel( mindspore::kernel::DeConvInt8CPUKernel *deconv = new mindspore::kernel::DeConvInt8CPUKernel(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册