diff --git a/mindspore/lite/nnacl/int8/div_int8.c b/mindspore/lite/nnacl/int8/div_int8.c index f3b8d86b66723701f623b0825f2d364dd6003df7..1f852cbb39b80af733daba999d92c2bbb308e2d1 100644 --- a/mindspore/lite/nnacl/int8/div_int8.c +++ b/mindspore/lite/nnacl/int8/div_int8.c @@ -29,8 +29,8 @@ int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64 } int recip_shift; - const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift) - : -ComputerReciproal(-input1_val, 31, &recip_shift); + const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift) + : -ComputerReciprocal(-input1_val, 31, &recip_shift); const int leading_bits = CountLeadingSignBits(input0_val); const int32_t raw_data = SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); diff --git a/mindspore/lite/nnacl/int8/softmax_int8.c b/mindspore/lite/nnacl/int8/softmax_int8.c index 0ffa437d8b7aefce9a53cfdebf249aeed32066c8..7979cf09e64e976d6981a304dbfa9f6652aac396 100644 --- a/mindspore/lite/nnacl/int8/softmax_int8.c +++ b/mindspore/lite/nnacl/int8/softmax_int8.c @@ -58,7 +58,7 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp int axis_offset = outter_offset + i * inner_size; for (int c = 0; c < inner_size; ++c) { int num_bits_over_unit; - int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit); + int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit); int unsat_output = RoundingDivideByPOT( SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8); diff --git a/mindspore/lite/nnacl/quantization/fixed_point.c b/mindspore/lite/nnacl/quantization/fixed_point.c index c12bac9111f229436cf6f754f2c5ca3b0fa7b471..77339764544fce51ae24c0fc493c1c8691e8dc87 100644 --- a/mindspore/lite/nnacl/quantization/fixed_point.c +++ b/mindspore/lite/nnacl/quantization/fixed_point.c @@ -54,76 +54,34 @@ int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t lef return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); } -int FractionsBits(int kIntegerBits) { - const int totalBits = 8 * sizeof(int32_t) - 1; - return totalBits - kIntegerBits; -} +inline int FractionsBits(int integer_bits) { return 8 * sizeof(int32_t) - 1 - integer_bits; } -int FixedPoint_One(int kIntegerBits, int kFractionsBits) { - return (kIntegerBits == 0 ? INT32_MAX : ((1) << (uint32_t)(kIntegerBits == 0 ? 0 : kFractionsBits))); +inline int FixedPoint_One(int integer_bits, int fractions_bits) { + return (integer_bits == 0 ? INT32_MAX : ((1) << (uint32_t)(integer_bits == 0 ? 0 : fractions_bits))); } -int RoundingHalfSum(int a, int b) { - int64_t a64 = a; - int64_t b64 = b; - int64_t sum = a64 + b64; - int64_t sign = sum > 0 ? 1 : -1; - return (int32_t)((sum + sign) / 2); +int RoundingHalfSum(int32_t a, int32_t b) { + int64_t sum = (int64_t)a + (int64_t)b; + return (int32_t)((sum + (sum > 0 ? 1 : -1)) / 2); } -int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } +inline int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } -int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } +inline int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } -int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } +inline int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } -int32_t BitNot(int32_t a) { return ~(uint32_t)a; } +inline int32_t BitNot(int32_t a) { return ~(uint32_t)a; } -int SelectUsingMask(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); } +inline int BitsSelect(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); } -int32_t MaskNonZero(int32_t a) { - const int32_t zreo = 0; - return a ? BitNot(zreo) : zreo; -} +inline int ConstantPOT(int fractional_bits, int exponent) { return (1 << (uint32_t)(fractional_bits + exponent)); } -static inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) { - if (Exponent > 0) { - const int min = INT32_MIN; - const int max = INT32_MAX; - const int scalar_int_bits = 8 * sizeof(int32_t); - const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - Exponent)) - 1); - const int postive_mask = MaskNonZero(x > thresold); - const int negative_mask = MaskNonZero(x < -thresold); - int result = x * ((int32_t)(1) << (uint32_t)Exponent); - result = SelectUsingMask(postive_mask, max, result); - result = SelectUsingMask(negative_mask, min, result); - return result; - } else if (Exponent < 0) { - return RoundingDivideByPOT(x, -Exponent); - } else { - return x; - } -} +inline int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; } -int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) { - int kExponent = kIntegerBitsSrc - kIntegerBitsDst; - int result = SaturatingRoundingMultiplyByPOT(x, kExponent); - return result; -} +inline int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); } -int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) { - int one = FixedPoint_One(0, FractionsBits(0)); - int half_denominator = RoundingHalfSum(a, one); - const int constant_48_over_17 = 1515870810; - const int constant_neg_32_over_17 = -1010580540; - int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17); - for (int i = 0; i < 3; i++) { - int half_denominator_times_x = SaturatingRoundingDoublingHighMul(half_denominator, x); - int one_minus_half_denominator_times_x = FixedPoint_One(2, FractionsBits(2)) - half_denominator_times_x; - x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_denominator_times_x), 2 + 2, 2); - } - return Rescale(x, 2 - 1, 0); -} +inline int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); } int CountLeadingZeroBits(uint32_t x) { #if defined(__GUNC__) @@ -150,75 +108,97 @@ int CountLeadingSignBits(int32_t x) { #endif } -int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) { +int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) { + if (exponent > 0) { + const int min = INT32_MIN; + const int max = INT32_MAX; + const int scalar_int_bits = 8 * sizeof(int32_t); + const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1); + const int postive_mask = x > thresold ? BitNot(0) : 0; + const int negative_mask = x < -thresold ? BitNot(0) : 0; + int result = x * ((int32_t)(1) << (uint32_t)exponent); + result = BitsSelect(postive_mask, max, result); + result = BitsSelect(negative_mask, min, result); + return result; + } else if (exponent < 0) { + return RoundingDivideByPOT(x, -exponent); + } else { + return x; + } +} + +int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) { + int exponent = integer_bits_src - integer_bits_dst; + return SaturatingRoundingMultiplyByPOT(x, exponent); +} + +int32_t reciprocal_on_interval_between_0_1(int32_t a) { + int one = FixedPoint_One(0, FractionsBits(0)); + int half_sum = RoundingHalfSum(a, one); + const int constant_48_over_17 = 1515870810; + const int constant_neg_32_over_17 = -1010580540; + int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_sum, constant_neg_32_over_17); + for (int i = 0; i < 3; i++) { + int half_sum_times_x = SaturatingRoundingDoublingHighMul(half_sum, x); + int one_minus_half_sum_times_x = FixedPoint_One(2, FractionsBits(2)) - half_sum_times_x; + x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_sum_times_x), 2 + 2, 2); + } + return Rescale(x, 2 - 1, 0); +} + +int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift) { int leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x); *recip_shift = x_digits - leading_zreos_plus_one; const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31)); - const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one); + const int32_t shifted_scaled = reciprocal_on_interval_between_0_1(shifted_minus_one); return shifted_scaled; } -int ConstantPOT(int fractional_bits, int exponent) { - int offset = fractional_bits + exponent; - return (1 << (uint32_t)offset); -} - -int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; } -int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); } - -int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); } - -int exp_on_interval_between_negative_one_quarter_and_0_excl(int a) { - const int constant_term = 1895147668; +int exp_on_interval_values(int a) { + const int constant_neg_1_over_8 = 1895147668; const int constant_1_over_3 = 715827883; - // We're evaluating a Taylor expansion around -1/8, so we do the change of - // variable: x = a + 1/8. - // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. - int kFractionalBits = FractionsBits(0); - int x = a + ConstantPOT(kFractionalBits, -3); + int fractional_bits = FractionsBits(0); + int x = a + ConstantPOT(fractional_bits, -3); int x2 = SaturatingRoundingDoublingHighMul(x, x); int x3 = SaturatingRoundingDoublingHighMul(x2, x); int x4 = SaturatingRoundingDoublingHighMul(x2, x2); int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2); int x4_over_24_plus_x3_over_6_plus_x2_over_2 = SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1); - return constant_term + - SaturatingRoundingDoublingHighMul(constant_term, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); -} - -int exp_on_negative_values(int a, const int tIntegerBits) { - int kIntegerBits = tIntegerBits; - int kFractionalBits = FractionsBits(tIntegerBits); - const int kOneQuarter = ConstantPOT(kFractionalBits, -2); - int mask = kOneQuarter - 1; - int a_mod_quarter_minus_one_quarter = ((unsigned)(a)&mask) - kOneQuarter; - int result = - exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale(a_mod_quarter_minus_one_quarter, tIntegerBits, 0)); - int remainder = a_mod_quarter_minus_one_quarter - a; + return constant_neg_1_over_8 + + SaturatingRoundingDoublingHighMul(constant_neg_1_over_8, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); +} -#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ - if (kIntegerBits > Exponent) { \ - const int kMultiplier = FixedPointMultiplier; \ - int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ - result = SelectUsingMask(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)kShiftAmount))), \ - SaturatingRoundingDoublingHighMul(result, kMultiplier), result); \ - } - GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); - GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); - GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); - GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); - GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); - GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); - GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); -#undef GEMMLOWP_EXP_BARREL_SHIFTER - - int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0; - if (kIntegerBits > 5) { - const int clamp = -(1 << (uint32_t)clampB); - result = SelectUsingMask(MaskIfLessThan(a, clamp), 0, result); +void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder, + int *result) { + if (integer_bits > exponent) { + int total_shift = integer_bits > exponent ? fractional_bits + exponent : 0; + *result = BitsSelect(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift))), + SaturatingRoundingDoublingHighMul(*result, muliplier), *result); } +} - result = SelectUsingMask(MaskIfZero(a), FixedPoint_One(0, kFractionalBits), result); +int exp_on_negative_values(int a, const int integer_bits) { + int fractional_bits = FractionsBits(integer_bits); + const int one_quarter = ConstantPOT(fractional_bits, -2); + int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter; + int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0)); + int remainder = a_mod_quarter_minus_one_quarter - a; + + exp_barrel_shifter(-2, 1672461947, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(-1, 1302514674, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+0, 790015084, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+1, 290630308, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+2, 39332535, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+3, 720401, integer_bits, fractional_bits, remainder, &result); + exp_barrel_shifter(+4, 242, integer_bits, fractional_bits, remainder, &result); + + int clamp_bits = integer_bits > 5 ? 36 - integer_bits : 0; + if (integer_bits > 5) { + const int clamp = -(1 << (uint32_t)clamp_bits); + result = BitsSelect(MaskIfLessThan(a, clamp), 0, result); + } + result = BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result); return result; } diff --git a/mindspore/lite/nnacl/quantization/fixed_point.h b/mindspore/lite/nnacl/quantization/fixed_point.h index e64d76c08f7956ef50d826e7b17a84ce2713b649..8a2fe1602dadbae6df1c9ec306f51f5cbaf42d54 100644 --- a/mindspore/lite/nnacl/quantization/fixed_point.h +++ b/mindspore/lite/nnacl/quantization/fixed_point.h @@ -42,46 +42,14 @@ int RoundingDivideByPOT(int x, int exponent); int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); -int FractionsBits(int kIntegerBits); - -int FixedPoint_One(int kIntegerBits, int kFractionsBits); - -int RoundingHalfSum(int a, int b); - -int32_t BitAnd(int32_t a, int32_t b); - -int32_t BitOr(int32_t a, int32_t b); - -int32_t BitXor(int32_t a, int32_t b); - -int32_t BitNot(int32_t a); - -int SelectUsingMask(int mask, int bound, int val); - -int32_t MaskNonZero(int32_t a); - int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst); -int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a); - -int CountLeadingZeroBits(uint32_t x); - int CountLeadingSignBits(int32_t x); -int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift); +int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift); int exp_on_negative_values(int a, const int tIntegerBits); -int ConstantPOT(int fractional_bits, int exponent); - -int32_t MaskIfNonZero(int32_t a); - -int32_t MaskIfZero(int32_t a); - -int32_t MaskIfLessThan(int32_t a, int32_t b); - -int exp_on_interval_between_negative_one_quarter_and_0_excl(int a); - #ifdef __cplusplus } #endif