diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc index c5ab89807ffcc4b8ba3328709e9433a60fab356e..18865d5650383c34d050d49d8e3a0fca510ed437 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.cc @@ -15,7 +15,6 @@ */ #include "src/runtime/kernel/arm/int8/arithmetic_int8.h" -#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" #include "src/runtime/kernel/arm/nnacl/arithmetic_common.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" @@ -42,7 +41,7 @@ int ArithmeticsInt8Launch(int thread_id, LiteParallelGroupEnv *penv, void *cdata auto error_code = arithmetic_kernel->DoArithmetic(thread_id); if (error_code != RET_OK) { MS_LOG(ERROR) << "ArithmeticsRun error thread_id[" << thread_id << "] error_code[" << error_code << "]"; - return RET_ERROR; + return error_code; } return RET_OK; } @@ -79,28 +78,43 @@ ArithmeticInt8CPUKernel::~ArithmeticInt8CPUKernel() { int ArithmeticInt8CPUKernel::Init() { switch (op_parameter_->type_) { case PrimitiveType_Equal: - arithmetic_run_ = ElementEqual; + arithmetic_run_ = ElementEqualInt8; break; case PrimitiveType_NotEqual: - arithmetic_run_ = ElementNotEqual; + arithmetic_run_ = ElementNotEqualInt8; break; case PrimitiveType_Less: - arithmetic_run_ = ElementLess; + arithmetic_run_ = ElementLessInt8; break; case PrimitiveType_LessEqual: - arithmetic_run_ = ElementLessEqual; + arithmetic_run_ = ElementLessEqualInt8; break; case PrimitiveType_Greater: - arithmetic_run_ = ElementGreater; + arithmetic_run_ = ElementGreaterInt8; break; case PrimitiveType_GreaterEqual: - arithmetic_run_ = ElementGreaterEqual; + arithmetic_run_ = ElementGreaterEqualInt8; break; default: MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; arithmetic_run_ = nullptr; return RET_PARAM_INVALID; } + + auto *input0_tensor = in_tensors_.at(0); + auto in0_quant_args = input0_tensor->GetQuantParams(); + quant_args_.in0_args_.scale_ = in0_quant_args.front().scale; + quant_args_.in0_args_.zp_ = in0_quant_args.front().zeroPoint; + + auto *input1_tensor = in_tensors_.at(1); + auto in1_quant_args = input1_tensor->GetQuantParams(); + quant_args_.in1_args_.scale_ = in1_quant_args.front().scale; + quant_args_.in1_args_.zp_ = in1_quant_args.front().zeroPoint; + + auto *out_tensor = out_tensors_.at(kOutputIndex); + auto out_quant_args = out_tensor->GetQuantParams(); + quant_args_.out_args_.scale_ = out_quant_args.front().scale; + quant_args_.out_args_.zp_ = out_quant_args.front().zeroPoint; if (!InferShapeDone()) { return RET_OK; } @@ -142,16 +156,16 @@ int ArithmeticInt8CPUKernel::DoArithmetic(int thread_id) { } int error_code = arithmetic_run_(tile_data0_ + stride * thread_id, tile_data1_ + stride * thread_id, - output_data + stride * thread_id, count); + output_data + stride * thread_id, count, &quant_args_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Arithmetic run fail! ret: " << error_code; - return RET_ERROR; + return error_code; } } else if (arithmetic_run_ != nullptr) { - int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num); + int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num, &quant_args_); if (error_code != RET_OK) { MS_LOG(ERROR) << "Arithmetic run fail!ret: " << error_code; - return RET_ERROR; + return error_code; } } else { MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h index bd8053c782f3e76cdbd0e97d071b0cbac2cc71ac..186e0623936b8c4c6949ed037029a0bde008ded7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_int8.h @@ -20,10 +20,12 @@ #include #include "src/lite_kernel.h" #include "schema/model_generated.h" +#include "src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h" namespace mindspore::kernel { class ArithmeticInt8CPUKernel : public LiteKernel { - typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size); + typedef int (*ArithmeticRunInt8)(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); public: ArithmeticInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, @@ -39,10 +41,10 @@ class ArithmeticInt8CPUKernel : public LiteKernel { private: void FreeTileData(); - int thread_count_; int8_t *tile_data0_; int8_t *tile_data1_; ArithmeticRunInt8 arithmetic_run_; + ArithmeticQuantArg quant_args_; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc index a7be697b1d8623b3623158d70e8f2005e1770d8b..ef1ca3cc456584c5267a2bfeb8647cd78d3d2c44 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/arithmetic.cc @@ -17,6 +17,8 @@ #include "nnacl/fp32/arithmetic.h" #include +#define ACCURACY_DATA 0.00000001 + int ElementMul(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; @@ -549,6 +551,14 @@ int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *ti return ElementMinimum(tile_input0, tile_input1, output, element_size); } +float FloatNotEqualCheck(float in0, float in1) { + float minus = in0 - in1; + if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) { + return (float)false; + } + return (float)true; +} + int ElementNotEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; @@ -563,10 +573,10 @@ int ElementNotEqual(float *input0, float *input1, float *output, int element_siz float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vfalse, vtrue); vst1q_f32(output, vout); #else - output[0] = (float)(input0[0] != input1[0]); - output[1] = (float)(input0[1] != input1[1]); - output[2] = (float)(input0[2] != input1[2]); - output[3] = (float)(input0[3] != input1[3]); + output[0] = FloatNotEqualCheck(input0[0], input1[0]); + output[1] = FloatNotEqualCheck(input0[1], input1[1]); + output[2] = FloatNotEqualCheck(input0[2], input1[2]); + output[3] = FloatNotEqualCheck(input0[3], input1[3]); #endif input0 += C4NUM; input1 += C4NUM; @@ -584,6 +594,14 @@ int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *t return ElementNotEqual(tile_input0, tile_input1, output, element_size); } +float FloatEqualCheck(float in0, float in1) { + float minus = in0 - in1; + if (minus <= ACCURACY_DATA && minus >= -ACCURACY_DATA) { + return (float)true; + } + return (float)false; +} + int ElementEqual(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; int block_c4 = element_size - block_mod; @@ -598,10 +616,10 @@ int ElementEqual(float *input0, float *input1, float *output, int element_size) float32x4_t vout = vbslq_f32(vceqq_f32(vin0, vin1), vtrue, vfalse); vst1q_f32(output, vout); #else - output[0] = (float)(input0[0] == input1[0]); - output[1] = (float)(input0[1] == input1[1]); - output[2] = (float)(input0[2] == input1[2]); - output[3] = (float)(input0[3] == input1[3]); + output[0] = FloatEqualCheck(input0[0], input1[0]); + output[1] = FloatEqualCheck(input0[1], input1[1]); + output[2] = FloatEqualCheck(input0[2], input1[2]); + output[3] = FloatEqualCheck(input0[3], input1[3]); #endif input0 += C4NUM; input1 += C4NUM; @@ -758,3 +776,5 @@ int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, floa TileDimensions(input0, input1, tile_input0, tile_input1, param); return ElementGreaterEqual(tile_input0, tile_input1, output, element_size); } + +#undef ACCURACY_DATA diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc index 893af2bb1566ae904576738e48b44cc3be120a6a..56ad1369aaf7f185599780b82c1830fae4c26f37 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.cc @@ -20,44 +20,102 @@ #endif #include "nnacl/errorcode.h" -int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +#define ACCURACY_DATA 0.00000001 + +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] != input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + float out_real = (float)true; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = (float)false; + } + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } -int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] == input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float minus_inputs = in0_real - in1_real; + float out_real = (float)false; + if (minus_inputs >= -ACCURACY_DATA && minus_inputs <= ACCURACY_DATA) { + out_real = (float)true; + } + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } -int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] < input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float out_real = (float)(in0_real < in1_real); + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } -int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] <= input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float out_real = (float)(in0_real <= in1_real); + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } -int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] > input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float out_real = (float)(in0_real > in1_real); + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } -int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size) { +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg) { + float in0_bias = -quant_arg->in0_args_.zp_ * quant_arg->in0_args_.scale_; + float in1_bias = -quant_arg->in1_args_.zp_ * quant_arg->in1_args_.scale_; + float output_inverse_scale = 1.f / quant_arg->out_args_.scale_; + float out_zp = quant_arg->out_args_.zp_; for (int index = 0; index < element_size; ++index) { - output[index] = (int8_t)(input0[index] >= input1[index]); + float in0_real = input0[index] * quant_arg->in0_args_.scale_ + in0_bias; + float in1_real = input1[index] * quant_arg->in1_args_.scale_ + in1_bias; + float out_real = (float)(in0_real >= in1_real); + output[index] = (int8_t)(out_real * output_inverse_scale + out_zp); } return NNACL_OK; } + +#undef ACCURACY_DATA diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h index 01562a9e903b7751ad2d88d9da96b4437b8d6586..1cd20f1469be90c77237bf86e8179fa94f9f4318 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/arithmetic_int8.h @@ -17,16 +17,21 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ #include "nnacl/op_base.h" +#include "nnacl/quantization/quantize.h" -int ElementNotEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementNotEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); -int ElementEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementLess(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementLessInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, ArithmeticQuantArg *quant_arg); -int ElementLessEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementLessEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); -int ElementGreater(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementGreaterInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); -int ElementGreaterEqual(int8_t *input0, int8_t *input1, int8_t *output, int element_size); +int ElementGreaterEqualInt8(int8_t *input0, int8_t *input1, int8_t *output, int element_size, + ArithmeticQuantArg *quant_arg); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_ARITHMETIC_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h index ada5292fc4a8ca6ce1797d2d222dba13c76c6a9a..4286ed897964c88bc0514071e6f8c1727ba6ad02 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h @@ -193,6 +193,12 @@ typedef struct SubQuantArg { int right_shift_out_; } SubQuantArg; +typedef struct ArithmeticQuantArg { + QuantArg in0_args_; + QuantArg in1_args_; + QuantArg out_args_; +} ArithmeticQuantArg; + void QuantizeMultiplier(double double_multiplier, int32_t *quantized_multiplier, int *shift); inline void QuantizeMultiplierSmallerThanOne(double double_multiplier, int32_t *quantized_multiplier,