diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc index f39293b313a5b35850536e7005b88fb402e8c6fc..61a10a0144adec8c8bbd14766a2cf3aea27735dd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.cc @@ -27,7 +27,10 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_DeConv2D; namespace mindspore::kernel { -DeConvInt8CPUKernel::~DeConvInt8CPUKernel() { FreeTmpBuffer(); } +DeConvInt8CPUKernel::~DeConvInt8CPUKernel() { + FreeTmpBuffer(); + ConvolutionBaseCPUKernel::FreeQuantParam(); +} void DeConvInt8CPUKernel::FreeTmpBuffer() { if (weight_ptr_ != nullptr) { @@ -46,20 +49,18 @@ void DeConvInt8CPUKernel::FreeTmpBuffer() { free(tmp_output_); tmp_output_ = nullptr; } - ConvolutionBaseCPUKernel::FreeQuantParam(); + if (input_sum_ != nullptr) { + free(input_sum_); + input_sum_ = nullptr; + } + return; } int DeConvInt8CPUKernel::ReSize() { FreeTmpBuffer(); ConvolutionBaseCPUKernel::Init(); - int error_code = ConvolutionBaseCPUKernel::SetQuantParam(); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "deconv int8 SetQuantParam error!"; - return error_code; - } - - error_code = InitParam(); + int error_code = InitParam(); if (error_code != RET_OK) { MS_LOG(ERROR) << "deconv int8 InitParam error!"; return error_code; @@ -79,76 +80,117 @@ int DeConvInt8CPUKernel::ReSize() { return RET_OK; } +int DeConvInt8CPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + + CheckSupportOptimize(); + + int error_code = ConvolutionBaseCPUKernel::SetQuantParam(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "deconv int8 SetQuantParam error!"; + return error_code; + } + return ReSize(); +} + +void DeConvInt8CPUKernel::CheckSupportOptimize() { + matmul_func_ = nullptr; + support_optimize_ = false; + +#ifdef ENABLE_ARM64 + /* todo */ +#endif + + support_optimize_ = true; + 
matmul_func_ = MatMulOptR4Int8; +} + int DeConvInt8CPUKernel::InitParam() { - fc_param_ = new MatMulParameter(); - fc_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; - fc_param_->deep_ = conv_param_->input_channel_; - fc_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; - fc_param_->row_8_ = UP_ROUND(fc_param_->row_, C8NUM); - fc_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_; - - size_t oc8 = UP_DIV(conv_param_->output_channel_, C8NUM); - thread_count_ = MSMIN(op_parameter_->thread_num_, oc8); - thread_stride_ = UP_DIV(oc8, thread_count_) * C8NUM; + matmul_param_ = new MatMulParameter(); + matmul_param_->row_ = conv_param_->input_h_ * conv_param_->input_w_; + matmul_param_->deep_ = conv_param_->input_channel_; + matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_; + + if (support_optimize_) { + input_trans_func_ = RowMajor2Row16x4MajorInt8; + size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM); + thread_count_ = MSMIN(op_parameter_->thread_num_, oc4); + thread_stride_ = UP_DIV(oc4, thread_count_); + } else { + /*todo */ + } return RET_OK; } int DeConvInt8CPUKernel::InitBiasWeight() { + size_t size = UP_ROUND(conv_param_->output_channel_, C4NUM) * sizeof(int32_t); + bias_data_ = malloc(size); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!"; + return RET_ERROR; + } + memset(bias_data_, 0, size); if (in_tensors_.size() == 3) { - size_t size = UP_ROUND(conv_param_->output_channel_, C8NUM) * sizeof(int32_t); - bias_data_ = malloc(size); - if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "deconv int8 malloc bias_data_ error!"; - return RET_ERROR; - } - memset(bias_data_, 0, size); memcpy(bias_data_, in_tensors_[0]->Data(), conv_param_->output_channel_ * sizeof(int32_t)); - } else { - bias_data_ = nullptr; } - /* weight: ichwoc(nhwc) -> oc8 * h * w * inc * 
8 */ - size_t size = conv_param_->kernel_w_ * conv_param_->kernel_h_ * UP_ROUND(conv_param_->output_channel_, C8NUM) * - conv_param_->input_channel_ * sizeof(int8_t); + size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(conv_param_->input_channel_, C16NUM) * + conv_param_->kernel_w_ * conv_param_->kernel_h_ * sizeof(int8_t); weight_ptr_ = reinterpret_cast(malloc(size)); if (weight_ptr_ == nullptr) { MS_LOG(ERROR) << "deconv int8 malloc weight_ptr_ error!"; return RET_ERROR; } - memset(weight_ptr_, 0, size); - PackNHWCToC8HWN8Int8(in_tensors_[1]->Data(), weight_ptr_, conv_param_->input_channel_, - conv_param_->kernel_h_ * conv_param_->kernel_w_, conv_param_->output_channel_); + memset(weight_ptr_, static_cast(conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_), size); + DeConvWeightTransInt8(reinterpret_cast(in_tensors_[1]->Data()), weight_ptr_, conv_param_->input_channel_, + conv_param_->output_channel_, conv_param_->kernel_h_ * conv_param_->kernel_w_, + support_optimize_); + + size = UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->kernel_h_ * conv_param_->kernel_w_; + weight_sum_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (weight_sum_ == nullptr) { + MS_LOG(ERROR) << "deconv int8 malloc weight_sum_ error!"; + return RET_ERROR; + } + memset(weight_sum_, 0, size * sizeof(int32_t)); + DeConvPackWeightSum(weight_ptr_, weight_sum_, conv_param_->conv_quant_arg_.input_quant_args_[0].zp_, + conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, UP_ROUND(matmul_param_->deep_, C16NUM), + size, support_optimize_); + return RET_OK; } int DeConvInt8CPUKernel::InitData() { - int size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * conv_param_->input_channel_; + int size = + UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C4NUM) * UP_ROUND(conv_param_->input_channel_, C16NUM); input_ptr_ = reinterpret_cast(malloc(size * sizeof(int8_t))); if (input_ptr_ == nullptr) { return RET_MEMORY_FAILED; } - 
memset(input_ptr_, 0, size * sizeof(int8_t)); + memset(input_ptr_, static_cast(conv_param_->conv_quant_arg_.input_quant_args_[0].zp_), size * sizeof(int8_t)); - size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C8NUM) * - UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_; + size = UP_ROUND(conv_param_->input_h_ * conv_param_->input_w_, C4NUM) * + UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->kernel_w_ * conv_param_->kernel_h_; tmp_buffer_ = reinterpret_cast(malloc(size * sizeof(int32_t))); if (tmp_buffer_ == nullptr) { return RET_MEMORY_FAILED; } - size = UP_ROUND(conv_param_->output_channel_, C8NUM) * conv_param_->output_h_ * conv_param_->output_w_; + size = UP_ROUND(conv_param_->output_channel_, C4NUM) * conv_param_->output_h_ * conv_param_->output_w_; tmp_output_ = reinterpret_cast(malloc(size * sizeof(int32_t))); if (tmp_output_ == nullptr) { return RET_MEMORY_FAILED; } - return RET_OK; -} -int DeConvInt8CPUKernel::Init() { - if (!InferShapeDone()) { - return RET_OK; + size = UP_ROUND(matmul_param_->row_, C4NUM); + input_sum_ = reinterpret_cast(malloc(size * sizeof(int32_t))); + if (input_sum_ == nullptr) { + return RET_MEMORY_FAILED; } - return ReSize(); + + return RET_OK; } int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { @@ -161,46 +203,26 @@ int DeConvInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) { return RET_OK; } -int DeConvInt8PostFuncRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { - auto deconv = reinterpret_cast(cdata); - auto error_code = deconv->DoPostFunc(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "DeConvInt8PostFuncRun error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} - int DeConvInt8CPUKernel::DoDeconv(int task_id) { - int cur_oc = MSMIN(thread_stride_, UP_ROUND(conv_param_->output_channel_, C8NUM) - task_id * thread_stride_); + int cur_oc = 
MSMIN(thread_stride_, UP_DIV(conv_param_->output_channel_, C4NUM) - task_id * thread_stride_); + int cur_oc_res = MSMIN(thread_stride_ * C4NUM, conv_param_->output_channel_ - task_id * thread_stride_ * C4NUM); if (cur_oc <= 0) { return RET_OK; } - int input_plane = conv_param_->input_h_ * conv_param_->input_w_; - int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; - - DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * kernel_plane * conv_param_->input_channel_, - tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, fc_param_->row_8_, - cur_oc * kernel_plane, fc_param_->deep_, conv_param_); - - return RET_OK; - } - - int DeConvInt8CPUKernel::DoPostFunc(int task_id) { int input_plane = conv_param_->input_h_ * conv_param_->input_w_; int kernel_plane = conv_param_->kernel_w_ * conv_param_->kernel_h_; int output_plane = conv_param_->output_h_ * conv_param_->output_w_; - int cur_oc = MSMIN(thread_stride_, conv_param_->output_channel_ - task_id * thread_stride_); - if (cur_oc <= 0) { - return RET_OK; - } + DeConvInt8(input_ptr_, weight_ptr_ + task_id * thread_stride_ * C4NUM * kernel_plane * conv_param_->input_channel_, + tmp_buffer_ + task_id * thread_stride_ * C4NUM * input_plane * kernel_plane, weight_sum_, input_sum_, + UP_ROUND(matmul_param_->row_, C4NUM), cur_oc * C4NUM * kernel_plane, + UP_ROUND(matmul_param_->deep_, C16NUM), conv_param_, matmul_func_); - DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * input_plane * kernel_plane, - reinterpret_cast(bias_data_) + task_id * thread_stride_, - tmp_output_ + task_id * thread_stride_ * output_plane, output_ptr_ + task_id * thread_stride_, cur_oc, - conv_param_); + DeConvPostInt8(tmp_buffer_ + task_id * thread_stride_ * C4NUM * input_plane * kernel_plane, + reinterpret_cast(bias_data_) + task_id * thread_stride_ * C4NUM, + tmp_output_ + task_id * thread_stride_ * C4NUM * output_plane, + output_ptr_ + task_id * thread_stride_ * C4NUM, cur_oc_res, conv_param_, 
support_optimize_); return RET_OK; } @@ -214,20 +236,18 @@ int DeConvInt8CPUKernel::Run() { int8_t *src_out = reinterpret_cast(out_tensors_[0]->Data()); for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { - RowMajor2Col8MajorInt8(src_in + batch_index * fc_param_->row_ * conv_param_->input_channel_, input_ptr_, - fc_param_->row_, fc_param_->deep_); - output_ptr_ = src_out + batch_index * fc_param_->col_; + input_trans_func_(src_in + batch_index * matmul_param_->row_ * conv_param_->input_channel_, input_ptr_, + matmul_param_->row_, matmul_param_->deep_); + output_ptr_ = src_out + batch_index * matmul_param_->col_; + + DeConvPackInputSum(input_ptr_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, + UP_ROUND(matmul_param_->row_, C4NUM), UP_ROUND(matmul_param_->deep_, C16NUM), support_optimize_); int error_code = LiteBackendParallelLaunch(DeConvInt8Run, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "deconv int8 run error! error_code[" << error_code << "]"; return RET_ERROR; } - error_code = LiteBackendParallelLaunch(DeConvInt8PostFuncRun, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "deconv int8 post run error! 
error_code[" << error_code << "]"; - return RET_ERROR; - } } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h index d5f1cc989cb55718eff4c5b6a7639e5ad2b9397d..30ebc9851fef4f6a901630078d94c6a40c22480c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_int8.h @@ -23,6 +23,7 @@ #include "include/errorcode.h" #include "src/runtime/kernel/arm/nnacl/matmul_parameter.h" #include "src/runtime/kernel/arm/nnacl/int8/deconv.h" +#include "src/runtime/kernel/arm/nnacl/int8/common_func.h" #include "src/runtime/kernel/arm/nnacl/int8/matmul_int8.h" #include "src/runtime/kernel/arm/base/layout_transform.h" #include "src/runtime/kernel/arm/base/convolution_base.h" @@ -43,23 +44,28 @@ class DeConvInt8CPUKernel : public ConvolutionBaseCPUKernel { public: int DoDeconv(int task_id); - int DoPostFunc(int task_id); private: + void FreeTmpBuffer(); int InitData(); int InitParam(); int InitBiasWeight(); + void CheckSupportOptimize(); private: - void FreeTmpBuffer(); - MatMulParameter *fc_param_ = nullptr; - int8_t *weight_ptr_ = nullptr; - int8_t *input_ptr_ = nullptr; /* record c8 input*/ int32_t *tmp_buffer_ = nullptr; /* record matmul result */ int32_t *tmp_output_ = nullptr; /* record post c8 result */ + int32_t *input_sum_ = nullptr; /* record in * w_zp */ + int32_t *weight_sum_ = nullptr; /* record w_v * in_zp - in_zp * w_zp */ + int8_t *input_ptr_ = nullptr; /* packed input */ + int8_t *weight_ptr_ = nullptr; /* packed weight */ int8_t *output_ptr_ = nullptr; - size_t thread_count_; - size_t thread_stride_; + size_t thread_count_ = 1; + size_t thread_stride_ = 0; + MATMUL_OPT_R4_FUNC matmul_func_; + MAT_TRANS_FUNC input_trans_func_; + MatMulParameter *matmul_param_ = nullptr; + bool support_optimize_ = true; }; } // namespace mindspore::kernel #endif // 
MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DECONVOLUTION_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc index 4256f1bfdf50366867f1ac29e4454eee87891d8e..e41f7b56b9063003f1dae9fc46bd230e6c0a5cfd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc @@ -129,8 +129,8 @@ int FullconnectionInt8CPUKernel::Run() { auto &p = quant_params_; RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_); LiteBackendParallelLaunch(FcInt8Run, this, thread_count_); - PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_, - p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max); + PostFuncInt8C8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, p.quant_multiplier, p.left_shift, + p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h index 361bedbf7353584db49ec922c0aee7c70c5ae8da..c4db0aa640466ff959360372533900cb34825aa9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h @@ -21,6 +21,7 @@ #include "include/context.h" #include "src/runtime/kernel/arm/nnacl/quantization/quantize.h" #include "src/runtime/kernel/arm/base/fullconnection_base.h" +#include "src/runtime/kernel/arm/nnacl/int8/common_func.h" using mindspore::lite::Context; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c index 556042b1c086172944ef13357cc79195130149ed..de5b59cddf1e35c9907b24480972caad4ad0ae7e 100644 --- 
a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.c @@ -228,27 +228,6 @@ void IndirectGemmFp32_Comm(float *output, const float *input, const float *weigh return; } -void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, - int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { - /* (int32_t)row8x8-major * multiplier + bias => (int8)relu => (int8_t)row-major */ - for (int r = 0; r < plane; r++) { - for (int c = 0; c < oc; c++) { - int c8div = c / 8, c8mod = c % 8; - int src_index = c8div * plane8 * 8 + r * 8 + c8mod; - int dst_index = r * oc + c; - int32_t value = in[src_index]; - if (bias != NULL) { - value = in[src_index] + bias[c]; - } - value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; - value = MSMIN(maxi, value); - value = MSMAX(mini, value); - out[dst_index] = (int8_t)value; - } - } - return; -} - void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp) { /* (int32_t)row8x8-major * multiplier => (int8_t)row-major */ @@ -265,4 +244,3 @@ void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane } } } - diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h index e72966bd2a18b63a6dca2b7612eedc96363625c7..1099ac7902188921471aeae7a3632c485174b987 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/common_func.h @@ -31,8 +31,6 @@ int8_t MinInt8(int8_t a, int8_t b); int8_t MaxInt8(int8_t a, int8_t b); void ReluFp32(float *data, float *dst, int ele_num); void Relu6Fp32(float *data, float *dst, int ele_num); -void PostFuncInt8(const int *in, const int *bias, int8_t *out, int oc, int plane, int plane8, int32_t 
multiplier, - int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); void SimplePostFuncInt8(const int *in, int8_t *out, int oc, int plane, int plane8, int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp); void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight, const float *bias, size_t step, diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.c new file mode 100644 index 0000000000000000000000000000000000000000..500ebdea792ac65fecf5c24575966a9a47b99dc3 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/int8/common_func.h" + +void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane, + size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int8_t mini, int8_t maxi, + int32_t left_shift, int32_t right_shift, int32_t zp, int size) { + if (size == 0) { + return; + } + for (int r = 0; r < plane; r++) { + for (int c = 0; c < oc; c++) { + int c8div = c / size, c8mod = c % size; + int src_index = c8div * in_plane_stride + r * size + c8mod; + int dst_index = r * out_oc_stride + c; + int32_t value = in[src_index]; + if (bias != NULL) { + value = in[src_index] + bias[c]; + } + value = MultiplyByQuantizedMultiplier(value, multiplier, left_shift, right_shift) + zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + out[dst_index] = (int8_t)value; + } + } + return; +} + +void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { + /* ((int32_t)row8x8-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ + PostConvFuncCommInt8(in, out, bias, oc, plane, oc, UP_ROUND(plane, C8NUM) * C8NUM, multiplier, mini, maxi, left_shift, + right_shift, zp, C8NUM); + return; +} + +void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) { + /* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */ + PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi, + left_shift, right_shift, zp, C4NUM); + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h index 
35c6ada091d21f30fc0b0a862582f3f47a29f48d..f141bff51162595eeb646e84c01b45fe0df4a3ea 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/common_func.h @@ -27,6 +27,11 @@ extern "C" { #endif +void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); +void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier, + int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi); + #ifdef ENABLE_ARM void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8, size_t oc4, size_t offset); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c index e46d9e9047730f0465c869ea58b68a15b3918a52..3ad6cf904db89e0f1dfc58e54614839e46d550af 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.c @@ -16,17 +16,10 @@ #include "nnacl/int8/deconv.h" #include "nnacl/int8/matmul_int8.h" - -int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, - ConvParameter *conv_param) { - MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.input_quant_args_[0].zp_, - conv_param->conv_quant_arg_.filter_quant_args_[0].zp_); - return NNACL_OK; -} - -int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, - ConvParameter *conv_param) { - /* row8x8-major(ih*iw x oc*kh*kw) -> row8x8-major(oh*ow x oc) */ +#include "nnacl/int8/common_func.h" +int DeConvPostInt8C8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param) { + /* 
row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */ size_t input_plane = conv_param->input_w_ * conv_param->input_h_; size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; size_t output_plane = conv_param->output_w_ * conv_param->output_h_; @@ -63,9 +56,161 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t } /*ih*/ } /*oc8*/ - PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8), - conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], - conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + PostFuncInt8C8(tmp, bias, out, output_channel, output_plane, conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0]); return NNACL_OK; } + +int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param) { + /* row4x4-major(ih*iw x oc*kh*kw) -> row4-major(oh*ow x oc) */ + size_t input_plane = conv_param->input_w_ * conv_param->input_h_; + size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_; + size_t output_plane = conv_param->output_w_ * conv_param->output_h_; + int oc4 = UP_DIV(output_channel, C4NUM); + int in_plane4 = UP_ROUND(input_plane, C4NUM); + + int src_iw_stride = C4NUM; + int src_ih_stride = conv_param->input_w_ * C4NUM; + int src_kw_stride = input_plane * C4NUM; + int src_kh_stride = input_plane * conv_param->kernel_w_ * C4NUM; + int dst_oh_stride = conv_param->output_w_ * C4NUM; + int dst_ow_stride = C4NUM; + int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C4NUM; 
+ int dst_kw_stride = conv_param->dilation_w_ * C4NUM; + + for (int c = 0; c < oc4; c++) { + int32_t *dst_ptr = tmp + c * output_plane * C4NUM; + const int32_t *src_ptr = src + c * in_plane4 * kernel_plane * C4NUM; + memset(dst_ptr, 0, output_plane * C4NUM * sizeof(int32_t)); + + for (int ih = 0; ih < conv_param->input_h_; ih++) { + for (int iw = 0; iw < conv_param->input_w_; iw++) { + int oh = ih * conv_param->stride_h_ - conv_param->pad_h_; + int ow = iw * conv_param->stride_w_ - conv_param->pad_w_; + + int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_)); + int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_)); + int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_)); + int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_)); + for (int kh = kh_start; kh < kh_end; kh++) { + for (int kw = kw_start; kw < kw_end; kw++) { + int src_index = ih * src_ih_stride + iw * src_iw_stride + kh * src_kh_stride + kw * src_kw_stride; + int dst_index = oh * dst_oh_stride + ow * dst_ow_stride + kh * dst_kh_stride + kw * dst_kw_stride; + int32_t *tmp_dst = dst_ptr + dst_index; + const int32_t *tmp_src = src_ptr + src_index; +#ifndef ENABLE_ARM64 + for (int i = 0; i < C4NUM; i++) { + tmp_dst[i] += tmp_src[i]; + } +#else + asm volatile( + "mov x0, %[tmp_src] \n" + "mov x1, %[tmp_dst] \n" + + "ld1 {v0.4s}, [x0] \n" + "ld1 {v1.4s}, [x1] \n" + + "add v0.4s, v0.4s, v1.4s \n" + + "st1 {v0.4s}, [x1] \n" + + : + : [ tmp_src ] "r"(tmp_src), [ tmp_dst ] "r"(tmp_dst) + : "x0", "x1", "v0", "v1"); +#endif + } /*kw*/ + } /*kh*/ + } /*iw*/ + } /*ih*/ + } /*oc*/ + + PostFuncInt8C4(tmp, bias, out, output_channel, output_plane, conv_param->output_channel_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + 
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + return NNACL_OK; +} + +void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_) { + if (support_optimize_) { + int ic16 = UP_ROUND(input_channel, C16NUM); + int oc4 = UP_ROUND(output_channel, C4NUM); + for (int ic = 0; ic < input_channel; ic++) { + int ic16div = ic / C16NUM, ic16mod = ic % C16NUM; + for (int oc = 0; oc < output_channel; oc++) { + int oc4div = oc / C4NUM, oc4mod = oc % C4NUM; + for (int hw = 0; hw < plane; hw++) { + int src_index = ic * output_channel * plane + hw * output_channel + oc; + int dst_index = + hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod; + dst[dst_index] = src[src_index]; + } + } + } + } else { + /* todo normal int8 deconv */ + } + return; +} + +void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4, + bool suppport_opt) { + if (suppport_opt) { + for (int c = 0; c < col4; c++) { + int c4div = c / C4NUM, c4mod = c % C4NUM; + int32_t value = 0; + for (int r = 0; r < deep16; r++) { + int r16div = r / 16, r16mod = r % 16; + int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod; + value += weight[src_index]; + } + weight_sum[c] = filter_zp * input_zp * deep16 - value * input_zp; + } + } else { + /* todo normal int8 deconv */ + } + return; +} + +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt) { + if (suppport_opt) { + for (int r = 0; r < row4; r++) { + int32_t tmp_value = 0; + for (int c = 0; c < col16; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM; + int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + tmp_value += src[src_index]; + } + dst[r] = tmp_value * filter_zp; + } + } 
else { + /* todo normal int8 deconv */ + } + return; +} + +int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32_t *weight_sum, int32_t *input_sum, + size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param, + MATMUL_OPT_R4_FUNC matmul_func) { + if (matmul_func != NULL) { + matmul_func(output, input, weight, weight_sum, input_sum, act_row, act_col, act_deep); + } else { + /* todo normal int8 deconv */ + } + return NNACL_OK; +} + +int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, + ConvParameter *conv_param, bool support_optimize) { + int error_code = NNACL_OK; + if (support_optimize) { + error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param); + } else { + /* todo normal int8 deconv post */ + } + return error_code; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h index 4664a7ef9c05e0649ce84a61066fc1f5085a4a97..c81104660270736b2a324b6163ef63fef32cad85 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/deconv.h @@ -22,16 +22,22 @@ #include "nnacl/errorcode.h" #include "nnacl/conv_parameter.h" #include "nnacl/common_func.h" +#include "nnacl/int8/matmul_int8.h" #ifdef __cplusplus extern "C" { #endif +void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4, + bool suppport_opt); +void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt); +void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane, + bool support_optimize_); -int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep, - ConvParameter *conv_param); - +int DeConvInt8(const int8_t *input, const int8_t *weight, 
int32_t *output, int32_t *weight_sum, int32_t *input_sum, + size_t act_row, size_t act_col, size_t act_deep, ConvParameter *conv_param, + MATMUL_OPT_R4_FUNC matmul_func); int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel, - ConvParameter *conv_param); + ConvParameter *conv_param, bool support_optimize); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c index 30185a750deb74d5a27b2338414f0b139a99cb34..bca1f639fc8a057e502ac0df22ecbd41488c1eb8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.c @@ -17,7 +17,6 @@ #include "nnacl/int8/matmul_int8.h" #include #include "nnacl/quantization/fixed_point.h" - void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { for (int r = 0; r < row; r++) { int8_t *src = src_ptr + r * col; @@ -29,6 +28,23 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) } } +void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { + /* Row-major to row16x4-major (block row-major) */ + int col16 = UP_ROUND(col, C16NUM); + for (int r = 0; r < row; r++) { + int r4div = r / C4NUM; + int r4mod = r % C4NUM; + for (int c = 0; c < col; c++) { + int c16div = c / C16NUM; + int c16mod = c % C16NUM; + int src_index = r * col + c; + int dst_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod; + ((int8_t *)dst_ptr)[dst_index] = ((int8_t *)src_ptr)[src_index]; + } + } + return; +} + void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { for (int r = 0; r < row; r++) { int rd8 = r / 8; @@ -57,3 +73,26 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co } } } + +void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const 
int32_t *bias, const int32_t *input_sum, + size_t row_4, size_t col_4, size_t deep_16) { + /* row4x16-major * row16x4-major => row4x4-major */ + for (int r = 0; r < row_4; r++) { + for (int c = 0; c < col_4; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / C4NUM, c4mod = c % C4NUM; + size_t ci = c4div * row_4 * C4NUM + r * C4NUM + c4mod; + int32_t value = 0; + for (int d = 0; d < deep_16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + value -= input_sum[r]; + value += bias[c]; + ((int32_t *)dst)[ci] = value; + } + } + return; +} diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h index abc23634289dff054d2d0da96edf43e4683f6dbf..860863f6330b076feec0e3251f3322e44c11ab64 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/matmul_int8.h @@ -18,14 +18,19 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_ #include "nnacl/op_base.h" +#include "nnacl/matmul_parameter.h" #ifdef __cplusplus extern "C" { #endif void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep, const int32_t a_zp, const int32_t b_zp); +void MatMulOptR4Int8(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, const int32_t *input_sum, + size_t row_4, size_t col_4, size_t deep_16); + void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); #ifdef __cplusplus } #endif diff --git 
a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h index b2b24064d354af1d296d8778503402c0fd3d6000..54c6d1c9a6de74ccaf0d8bfcfec5434fdc03546d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/matmul_parameter.h @@ -19,6 +19,11 @@ #include "nnacl/op_base.h" +typedef void (*MATMUL_OPT_R4_FUNC)(int32_t *dst, const int8_t *a, const int8_t *b, const int32_t *bias, + const int32_t *input_sum, size_t row_4, size_t col_4, size_t deep_16); + +typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col); + typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType; typedef struct MatMulParameter { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc index ff42754f05aa8b77cb9b0bf3e23ae99212e15bea..a4c32b1778c500acc7827862129a258fd10dc1f7 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/deconv_int8_tests.cc @@ -36,15 +36,6 @@ class TestDeconvInt8 : public mindspore::CommonTest { TestDeconvInt8() {} }; -void FloatToInt8(float *fptr, int8_t *iptr, size_t size, int32_t zp, double scale) { - for (int i = 0; i < size; i++) { - int32_t value = round(fptr[i] / scale + zp); - value = MSMIN(value, INT8_MAX); - value = MSMAX(value, INT8_MIN); - iptr[i] = (int8_t)value; - } -} - TEST_F(TestDeconvInt8, PackWeight1) { int8_t in[] = {-8, 11, 99, -80, 8, -12, 37, -45, 31, -69, -66, 26, 112, 124, -109, 85, -24, 28, -46, 100, 72, -36, -82, 64, -110, 37, -72, 65, -124, 91, -43, 99, 3, 100, 19, 51, -14, -81, 67, 90, @@ -164,6 +155,125 @@ TEST_F(TestDeconvInt8, MatMulTest1) { CompareOutputData(out_row_major, co_row_major_10_18, 180, 1); } +TEST_F(TestDeconvInt8, MatMulOptTest1) { + int8_t a_src_ptr[] = {-6, 76, 32, 80, -73, 8, 
-85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, + 88, 105, 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, + 57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 40, 103, 112, 109, -41, -68, + 57, 61, 55, -20, 3, 2, 17, -16, -31, 58, -4, 67, -4, -95, -5, -72, 81, 15, + -7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 19, 8, -57, -35, -79, 118, + 29, 21, 37, -48, 83, 7, 124, 113, -5, 15, -8, 107, -65, -88, 50, -47, -80, -84, + 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, 55, 10}; + int32_t input_zp = 15; + int8_t b_src_ptr[] = { + 92, 27, 22, 52, -112, -20, -57, -2, 89, 32, 93, -66, -25, -54, 94, -97, -119, -98, 101, -99, + 77, -83, 76, 95, 59, 97, 8, 40, -109, -20, 67, -107, 37, -6, -54, -20, -30, 36, -106, -103, + -3, -86, -82, 59, 4, -75, -50, -106, 55, 104, -117, -71, -20, -85, -77, 16, -25, -58, 4, 80, + -75, 94, 32, -68, 2, 40, 56, -103, 11, -98, -70, -69, 0, 57, -6, 82, 66, -112, -61, 33, + -77, -53, 95, -38, 87, -46, -3, 81, -47, 43, 21, 26, -45, -57, 50, -24, -82, -114, 61, 46, + -53, 78, -24, 31, -7, 37, 29, 38, 45, 106, 52, -42, 31, -6, -61, -87, 2, 79, -5, -42, + 43, -106, -104, 7, 91, -63, 58, 97, -15, 74, -96, 15, -23, -3, -47, -97, 100, -54, 26, -46, + 35, 26, 100, -80, 34, -25, 96, -67, -80, -27, 66, 41, 41, -43, -43, -38, -4, -64, 31, 7, + -8, 6, -2, 39, -119, 53, 75, -91, -44, 77, -62, 22, -44, 78, -67, -48, -115, -4, 43, 81, + 40, -20, -5, -89, 60, -62, -4, -48, 66, -64, -69, 62, 17, -89, 1, 87, 81, 32, -29, 51, + 40, 27, 66, 67, 11, -69, 85, -79, -106, 55, 22, -23, 62, 69, -74, 49}; + int32_t filter_zp = -20; + + /* + * ---------------------- pack input ------------------------- */ + int8_t packed_a[12 * 16] = {0}; + memset(packed_a, static_cast(input_zp), 12 * 16); + int8_t correct_packed_a[] = { + -6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111, + 88, 105, 68, 105, -74, 13, 15, 15, 15, 15, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65, + 15, 15, 15, 15, 
57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 15, 15, 15, 15, 40, 103, + 112, 109, -41, -68, 57, 61, 55, -20, 3, 2, 15, 15, 15, 15, 17, -16, -31, 58, -4, 67, -4, -95, + -5, -72, 81, 15, 15, 15, 15, 15, -7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 15, 15, + 15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15, + -8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67, + 55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + }; + RowMajor2Row16x4MajorInt8(a_src_ptr, packed_a, 10, 12); + CompareOutputData(packed_a, correct_packed_a, 16 * 12, 0); + + /* + * ---------------------- pack weight ------------------------- */ + int8_t packed_b[16 * 3 * 8] = {0}; + memset(packed_b, static_cast(filter_zp), 16 * 3 * 8); + int8_t correct_packed_b[] = { + 92, 101, -30, -77, 0, 21, 45, 58, 34, -2, 40, -29, -20, -20, -20, -20, 27, -99, 36, 16, 57, + 26, 106, 97, -25, 39, -20, 51, -20, -20, -20, -20, 22, 77, -106, -25, -6, -45, 52, -15, 96, -119, + -5, 40, -20, -20, -20, -20, 52, -83, -103, -58, 82, -57, -42, 74, -67, 53, -89, 27, -20, -20, -20, + -20, -112, 76, -3, 4, 66, 50, 31, -96, -80, 75, 60, 66, -20, -20, -20, -20, -20, 95, -86, 80, + -112, -24, -6, 15, -27, -91, -62, 67, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -57, 59, -82, -75, -61, -82, -61, -23, 66, -44, -4, 11, -20, -20, -20, -20, -2, 97, 59, + 94, 33, -114, -87, -3, 41, 77, -48, -69, -20, -20, -20, -20, 89, 8, 4, 32, -77, 61, 2, -47, + 41, -62, 66, 85, -20, -20, -20, -20, 32, 40, -75, -68, -53, 46, 79, -97, -43, 22, -64, -79, -20, + -20, -20, -20, 93, -109, -50, 2, 95, -53, -5, 100, -43, -44, -69, -106, -20, -20, -20, -20, -66, -20, + -106, 40, -38, 78, -42, -54, -38, 78, 62, 55, -20, 
-20, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -20, -20, -25, 67, 55, 56, 87, -24, 43, 26, -4, -67, 17, 22, -20, -20, -20, -20, -54, + -107, 104, -103, -46, 31, -106, -46, -64, -48, -89, -23, -20, -20, -20, -20, 94, 37, -117, 11, -3, -7, + -104, 35, 31, -115, 1, 62, -20, -20, -20, -20, -97, -6, -71, -98, 81, 37, 7, 26, 7, -4, 87, + 69, -20, -20, -20, -20, -119, -54, -20, -70, -47, 29, 91, 100, -8, 43, 81, -74, -20, -20, -20, -20, + -98, -20, -85, -69, 43, 38, -63, -80, 6, 81, 32, 49, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, -20, + -20, -20, -20, -20, -20, -20}; + DeConvWeightTransInt8(b_src_ptr, packed_b, 12, 6, 3, true); + /* kernel : 12x1x3x6 nhwc */ + CompareOutputData(packed_b, correct_packed_b, 16 * 3 * 8, 0); + + /* + * ---------------------- calculate input_sum ------------------------- */ + int32_t input_sum[12] = {0}; + int32_t correct_input_sum[] = {-7100, -4780, 580, -4880, -9460, -1420, -3120, -3260, -1840, -6960, -4800, -4800}; + DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true); + CompareOutputData(input_sum, correct_input_sum, 12, 0); + + for (int i = 0; i < 12; i++) { + if (input_sum[i] != correct_input_sum[i]) { + printf("%d %d %d\n", i, input_sum[i], correct_input_sum[i]); + } + } + + /* + * ---------------------- calculate weight_sum ------------------------- */ + int32_t weight_sum[3 * 8] = {0}; + int32_t correct_weight_sum[] = {-7395, -8265, -3090, -435, -5655, -1035, 0, 0, 1695, -4770, -6630, 300, + -765, -2835, 0, 0, -7395, 4665, -2475, -4170, -2880, -1110, 0, 0}; + DeConvPackWeightSum(packed_b, weight_sum, input_zp, filter_zp, 16, 24, true); + CompareOutputData(weight_sum, correct_weight_sum, 3 * 8, 0); + + /* + * ---------------------- do matmul ------------------------- */ + int32_t tmp_output[12 * 
24] = {0}; + int32_t correct_tmp_output[] = { + -1624, -19061, 1795, -17119, 14706, 417, 7306, 1357, 9653, -44022, 19414, -36187, -2041, 6874, + -5766, 3072, 9842, 2395, 12464, -18826, -12267, -17853, 4617, -19468, -15734, -6112, 2122, 14259, + 11098, -9520, 12407, -15239, 10309, -34271, 9740, -14607, -5027, 12313, -508, -10808, 0, 0, + 0, 0, 0, 0, 0, 0, 1604, 14898, 0, 0, -8212, 9471, 0, 0, + -23430, 6343, 0, 0, 4020, -3740, 0, 0, -9730, 22378, 0, 0, 4702, 4740, + 0, 0, -7541, 5461, 0, 0, -6633, 8356, 0, 0, -16854, 9147, 0, 0, + -4018, -11524, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17194, 28501, + 13376, -9359, 21454, 22425, -21049, 6603, 23479, -658, 12866, 9739, -12173, -7558, 3862, 10238, + 4110, 31945, 10069, -7376, -1948, -20322, 16439, 3260, 1712, 12743, -8132, -27744, 7633, -33916, + 18755, 11300, 3686, 9222, 10103, 26102, 17, 13135, 785, -6305, 0, 0, 0, 0, + 0, 0, 0, 0, -27325, 14957, 0, 0, -12191, -21866, 0, 0, -21690, -18554, + 0, 0, 8737, 14529, 0, 0, -1774, -19575, 0, 0, -12761, 13286, 0, 0, + 20523, 2488, 0, 0, -12782, 12688, 0, 0, -1194, -10523, 0, 0, -4044, -9671, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -4671, -4173, 8675, -8560, + -1597, -4946, -20214, -6752, -11439, 5138, 11119, -17661, -6690, -17301, -5541, -4356, 22347, -11778, + 2389, -22030, -5176, -242, 8786, -994, 9104, -7208, 24117, 3724, -13648, -1840, 12265, 10347, + -10325, 7184, 19374, -29001, 3979, -6704, -23278, -8124, 0, 0, 0, 0, 0, 0, + 0, 0, -9132, 8560, 0, 0, 19264, -10169, 0, 0, -15133, -13678, 0, 0, + 7894, -51, 0, 0, -4775, -29785, 0, 0, -12597, 4088, 0, 0, -17420, 1815, + 0, 0, 15796, 3101, 0, 0, -37969, -10818, 0, 0, 12714, -7827, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + MatMulOptR4Int8(tmp_output, packed_a, packed_b, weight_sum, input_sum, 12, 24, 16); + CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0); +} + TEST_F(TestDeconvInt8, PostAddTest1) { int32_t in[] = { -4956, -3923, 868, -8880, -4089, -5179, -4526, -4527, -10464, 99, -5826, -2995, -4519, -4519, -10509, -2505, @@ -185,17 
+295,17 @@ TEST_F(TestDeconvInt8, PostAddTest1) { int32_t right_shift; QuantizeRoundParameter(multiplier, &quant_multiplier, &left_shift, &right_shift); int32_t zp = 83; - PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, -128, 127); + PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, -128, 127); CompareOutputData(out, co, 50, 1); int8_t co_relu[] = {0, 11, 99, 0, 8, 0, 0, 0, 112, 124, 0, 85, 0, 28, 0, 0, 0, 37, 0, 65, 0, 91, 0, 0, 0, 0, 67, 90, 4, 0, 0, 0, 47, 0, 114, 125, 0, 100, 0, 0, 37, 0, 31, 0, 0, 26, 0, 0, 0, 100}; - PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 127); + PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, 0, 127); CompareOutputData(out, co_relu, 50, 1); int8_t co_relu6[] = {0, 6, 6, 0, 6, 0, 0, 0, 6, 6, 0, 6, 0, 6, 0, 0, 0, 6, 0, 6, 0, 6, 0, 0, 0, 0, 6, 6, 4, 0, 0, 0, 6, 0, 6, 6, 0, 6, 0, 0, 6, 0, 6, 0, 0, 6, 0, 0, 0, 6}; - PostFuncInt8(in, bias, out, 10, 5, 8, quant_multiplier, left_shift, right_shift, zp, 0, 6); + PostFuncInt8C8(in, bias, out, 10, 5, quant_multiplier, left_shift, right_shift, zp, 0, 6); CompareOutputData(out, co_relu6, 50, 1); } @@ -247,7 +357,7 @@ TEST_F(TestDeconvInt8, DeConvInt8Test1) { std::vector outputs_; auto deconv_param = new ConvParameter(); lite::Context *ctx = new lite::Context; - ctx->thread_num_ = 2; + ctx->thread_num_ = 1; int8_t *correct; int total_size = DeConvInt8TestInit1(&inputs_, &outputs_, deconv_param, &correct); mindspore::kernel::DeConvInt8CPUKernel *deconv = new mindspore::kernel::DeConvInt8CPUKernel(