提交 98cb4faa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5398 Modify fixed point functions

Merge pull request !5398 from wangminggui/dev
...@@ -29,8 +29,8 @@ int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64 ...@@ -29,8 +29,8 @@ int DivInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64
} }
int recip_shift; int recip_shift;
const int32_t input1_inv = (input1_val > 0) ? ComputerReciproal(input1_val, 31, &recip_shift) const int32_t input1_inv = (input1_val > 0) ? ComputerReciprocal(input1_val, 31, &recip_shift)
: -ComputerReciproal(-input1_val, 31, &recip_shift); : -ComputerReciprocal(-input1_val, 31, &recip_shift);
const int leading_bits = CountLeadingSignBits(input0_val); const int leading_bits = CountLeadingSignBits(input0_val);
const int32_t raw_data = const int32_t raw_data =
SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv); SaturatingRoundingDoublingHighMul(input0_val * (1 << (unsigned int)leading_bits), input1_inv);
......
...@@ -58,7 +58,7 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp ...@@ -58,7 +58,7 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp
int axis_offset = outter_offset + i * inner_size; int axis_offset = outter_offset + i * inner_size;
for (int c = 0; c < inner_size; ++c) { for (int c = 0; c < inner_size; ++c) {
int num_bits_over_unit; int num_bits_over_unit;
int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit); int shifted_scale = ComputerReciprocal(sum_data[c], 12, &num_bits_over_unit);
int unsat_output = RoundingDivideByPOT( int unsat_output = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8); SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
......
...@@ -54,76 +54,34 @@ int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t lef ...@@ -54,76 +54,34 @@ int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t lef
return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift); return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(value * (1 << left_shift), multiplier), -right_shift);
} }
int FractionsBits(int kIntegerBits) { inline int FractionsBits(int integer_bits) { return 8 * sizeof(int32_t) - 1 - integer_bits; }
const int totalBits = 8 * sizeof(int32_t) - 1;
return totalBits - kIntegerBits;
}
int FixedPoint_One(int kIntegerBits, int kFractionsBits) { inline int FixedPoint_One(int integer_bits, int fractions_bits) {
return (kIntegerBits == 0 ? INT32_MAX : ((1) << (uint32_t)(kIntegerBits == 0 ? 0 : kFractionsBits))); return (integer_bits == 0 ? INT32_MAX : ((1) << (uint32_t)(integer_bits == 0 ? 0 : fractions_bits)));
} }
int RoundingHalfSum(int a, int b) { int RoundingHalfSum(int32_t a, int32_t b) {
int64_t a64 = a; int64_t sum = (int64_t)a + (int64_t)b;
int64_t b64 = b; return (int32_t)((sum + (sum > 0 ? 1 : -1)) / 2);
int64_t sum = a64 + b64;
int64_t sign = sum > 0 ? 1 : -1;
return (int32_t)((sum + sign) / 2);
} }
int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; } inline int32_t BitAnd(int32_t a, int32_t b) { return (uint32_t)a & (uint32_t)b; }
int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; } inline int32_t BitOr(int32_t a, int32_t b) { return (uint32_t)a | (uint32_t)b; }
int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; } inline int32_t BitXor(int32_t a, int32_t b) { return (uint32_t)a ^ (uint32_t)b; }
int32_t BitNot(int32_t a) { return ~(uint32_t)a; } inline int32_t BitNot(int32_t a) { return ~(uint32_t)a; }
int SelectUsingMask(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); } inline int BitsSelect(int mask, int bound, int val) { return BitXor(BitAnd(mask, bound), BitAnd(BitNot(mask), val)); }
int32_t MaskNonZero(int32_t a) { inline int ConstantPOT(int fractional_bits, int exponent) { return (1 << (uint32_t)(fractional_bits + exponent)); }
const int32_t zreo = 0;
return a ? BitNot(zreo) : zreo;
}
static inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) { inline int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
if (Exponent > 0) {
const int min = INT32_MIN;
const int max = INT32_MAX;
const int scalar_int_bits = 8 * sizeof(int32_t);
const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - Exponent)) - 1);
const int postive_mask = MaskNonZero(x > thresold);
const int negative_mask = MaskNonZero(x < -thresold);
int result = x * ((int32_t)(1) << (uint32_t)Exponent);
result = SelectUsingMask(postive_mask, max, result);
result = SelectUsingMask(negative_mask, min, result);
return result;
} else if (Exponent < 0) {
return RoundingDivideByPOT(x, -Exponent);
} else {
return x;
}
}
int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) { inline int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
int kExponent = kIntegerBitsSrc - kIntegerBitsDst;
int result = SaturatingRoundingMultiplyByPOT(x, kExponent);
return result;
}
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) { inline int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
int one = FixedPoint_One(0, FractionsBits(0));
int half_denominator = RoundingHalfSum(a, one);
const int constant_48_over_17 = 1515870810;
const int constant_neg_32_over_17 = -1010580540;
int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_denominator, constant_neg_32_over_17);
for (int i = 0; i < 3; i++) {
int half_denominator_times_x = SaturatingRoundingDoublingHighMul(half_denominator, x);
int one_minus_half_denominator_times_x = FixedPoint_One(2, FractionsBits(2)) - half_denominator_times_x;
x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_denominator_times_x), 2 + 2, 2);
}
return Rescale(x, 2 - 1, 0);
}
int CountLeadingZeroBits(uint32_t x) { int CountLeadingZeroBits(uint32_t x) {
#if defined(__GUNC__) #if defined(__GUNC__)
...@@ -150,75 +108,97 @@ int CountLeadingSignBits(int32_t x) { ...@@ -150,75 +108,97 @@ int CountLeadingSignBits(int32_t x) {
#endif #endif
} }
int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) { int SaturatingRoundingMultiplyByPOT(int32_t x, int exponent) {
if (exponent > 0) {
const int min = INT32_MIN;
const int max = INT32_MAX;
const int scalar_int_bits = 8 * sizeof(int32_t);
const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - exponent)) - 1);
const int postive_mask = x > thresold ? BitNot(0) : 0;
const int negative_mask = x < -thresold ? BitNot(0) : 0;
int result = x * ((int32_t)(1) << (uint32_t)exponent);
result = BitsSelect(postive_mask, max, result);
result = BitsSelect(negative_mask, min, result);
return result;
} else if (exponent < 0) {
return RoundingDivideByPOT(x, -exponent);
} else {
return x;
}
}
int32_t Rescale(int x, int integer_bits_src, int integer_bits_dst) {
int exponent = integer_bits_src - integer_bits_dst;
return SaturatingRoundingMultiplyByPOT(x, exponent);
}
int32_t reciprocal_on_interval_between_0_1(int32_t a) {
int one = FixedPoint_One(0, FractionsBits(0));
int half_sum = RoundingHalfSum(a, one);
const int constant_48_over_17 = 1515870810;
const int constant_neg_32_over_17 = -1010580540;
int x = constant_48_over_17 + SaturatingRoundingDoublingHighMul(half_sum, constant_neg_32_over_17);
for (int i = 0; i < 3; i++) {
int half_sum_times_x = SaturatingRoundingDoublingHighMul(half_sum, x);
int one_minus_half_sum_times_x = FixedPoint_One(2, FractionsBits(2)) - half_sum_times_x;
x = x + Rescale(SaturatingRoundingDoublingHighMul(x, one_minus_half_sum_times_x), 2 + 2, 2);
}
return Rescale(x, 2 - 1, 0);
}
int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift) {
int leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x); int leading_zreos_plus_one = CountLeadingZeroBits((uint32_t)x);
*recip_shift = x_digits - leading_zreos_plus_one; *recip_shift = x_digits - leading_zreos_plus_one;
const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31)); const int32_t shifted_minus_one = (int32_t)(((uint32_t)x << leading_zreos_plus_one) - ((uint32_t)(1) << 31));
const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one); const int32_t shifted_scaled = reciprocal_on_interval_between_0_1(shifted_minus_one);
return shifted_scaled; return shifted_scaled;
} }
int ConstantPOT(int fractional_bits, int exponent) {
int offset = fractional_bits + exponent;
return (1 << (uint32_t)offset);
}
int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); } int exp_on_interval_values(int a) {
const int constant_neg_1_over_8 = 1895147668;
int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a) {
const int constant_term = 1895147668;
const int constant_1_over_3 = 715827883; const int constant_1_over_3 = 715827883;
// We're evaluating a Taylor expansion around -1/8, so we do the change of int fractional_bits = FractionsBits(0);
// variable: x = a + 1/8. int x = a + ConstantPOT(fractional_bits, -3);
// In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
int kFractionalBits = FractionsBits(0);
int x = a + ConstantPOT(kFractionalBits, -3);
int x2 = SaturatingRoundingDoublingHighMul(x, x); int x2 = SaturatingRoundingDoublingHighMul(x, x);
int x3 = SaturatingRoundingDoublingHighMul(x2, x); int x3 = SaturatingRoundingDoublingHighMul(x2, x);
int x4 = SaturatingRoundingDoublingHighMul(x2, x2); int x4 = SaturatingRoundingDoublingHighMul(x2, x2);
int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2); int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2);
int x4_over_24_plus_x3_over_6_plus_x2_over_2 = int x4_over_24_plus_x3_over_6_plus_x2_over_2 =
SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1); SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1);
return constant_term + return constant_neg_1_over_8 +
SaturatingRoundingDoublingHighMul(constant_term, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2)); SaturatingRoundingDoublingHighMul(constant_neg_1_over_8, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
} }
int exp_on_negative_values(int a, const int tIntegerBits) {
int kIntegerBits = tIntegerBits;
int kFractionalBits = FractionsBits(tIntegerBits);
const int kOneQuarter = ConstantPOT(kFractionalBits, -2);
int mask = kOneQuarter - 1;
int a_mod_quarter_minus_one_quarter = ((unsigned)(a)&mask) - kOneQuarter;
int result =
exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale(a_mod_quarter_minus_one_quarter, tIntegerBits, 0));
int remainder = a_mod_quarter_minus_one_quarter - a;
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ void exp_barrel_shifter(int exponent, int muliplier, int integer_bits, int fractional_bits, int remainder,
if (kIntegerBits > Exponent) { \ int *result) {
const int kMultiplier = FixedPointMultiplier; \ if (integer_bits > exponent) {
int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ int total_shift = integer_bits > exponent ? fractional_bits + exponent : 0;
result = SelectUsingMask(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)kShiftAmount))), \ *result = BitsSelect(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)total_shift))),
SaturatingRoundingDoublingHighMul(result, kMultiplier), result); \ SaturatingRoundingDoublingHighMul(*result, muliplier), *result);
}
GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
#undef GEMMLOWP_EXP_BARREL_SHIFTER
int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
if (kIntegerBits > 5) {
const int clamp = -(1 << (uint32_t)clampB);
result = SelectUsingMask(MaskIfLessThan(a, clamp), 0, result);
} }
}
result = SelectUsingMask(MaskIfZero(a), FixedPoint_One(0, kFractionalBits), result); int exp_on_negative_values(int a, const int integer_bits) {
int fractional_bits = FractionsBits(integer_bits);
const int one_quarter = ConstantPOT(fractional_bits, -2);
int a_mod_quarter_minus_one_quarter = ((unsigned)(a) & (one_quarter - 1)) - one_quarter;
int result = exp_on_interval_values(Rescale(a_mod_quarter_minus_one_quarter, integer_bits, 0));
int remainder = a_mod_quarter_minus_one_quarter - a;
exp_barrel_shifter(-2, 1672461947, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(-1, 1302514674, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(+0, 790015084, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(+1, 290630308, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(+2, 39332535, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(+3, 720401, integer_bits, fractional_bits, remainder, &result);
exp_barrel_shifter(+4, 242, integer_bits, fractional_bits, remainder, &result);
int clamp_bits = integer_bits > 5 ? 36 - integer_bits : 0;
if (integer_bits > 5) {
const int clamp = -(1 << (uint32_t)clamp_bits);
result = BitsSelect(MaskIfLessThan(a, clamp), 0, result);
}
result = BitsSelect(MaskIfZero(a), FixedPoint_One(0, fractional_bits), result);
return result; return result;
} }
......
...@@ -42,46 +42,14 @@ int RoundingDivideByPOT(int x, int exponent); ...@@ -42,46 +42,14 @@ int RoundingDivideByPOT(int x, int exponent);
int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift); int MultiplyByQuantizedMultiplier(int32_t value, int32_t multiplier, int32_t left_shift, int32_t right_shift);
int FractionsBits(int kIntegerBits);
int FixedPoint_One(int kIntegerBits, int kFractionsBits);
int RoundingHalfSum(int a, int b);
int32_t BitAnd(int32_t a, int32_t b);
int32_t BitOr(int32_t a, int32_t b);
int32_t BitXor(int32_t a, int32_t b);
int32_t BitNot(int32_t a);
int SelectUsingMask(int mask, int bound, int val);
int32_t MaskNonZero(int32_t a);
int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst); int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst);
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
int CountLeadingZeroBits(uint32_t x);
int CountLeadingSignBits(int32_t x); int CountLeadingSignBits(int32_t x);
int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift); int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift);
int exp_on_negative_values(int a, const int tIntegerBits); int exp_on_negative_values(int a, const int tIntegerBits);
int ConstantPOT(int fractional_bits, int exponent);
int32_t MaskIfNonZero(int32_t a);
int32_t MaskIfZero(int32_t a);
int32_t MaskIfLessThan(int32_t a, int32_t b);
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册