提交 29c04de8 编写于 作者: 王明贵

Modify the quantization operation of softmax op

上级 1ca715c7
......@@ -37,17 +37,24 @@ int SoftmaxInt8CPUKernel::Init() {
auto in_quant_args = input_tensor->GetQuantParams();
quant_params_.in_quant_args_.scale_ = in_quant_args.front().scale;
quant_params_.in_quant_args_.zp_ = in_quant_args.front().zeroPoint;
quant_params_.in_quant_args_.zp_ = -in_quant_args.front().zeroPoint;
auto *out_tensor = out_tensors_.at(kOutputIndex);
MS_ASSERT(out_tensor);
auto out_quant_args = out_tensor->GetQuantParams();
quant_params_.out_quant_arg_.scale_ = out_quant_args.front().scale;
quant_params_.out_quant_arg_.zp_ = out_quant_args.front().zeroPoint;
quant_params_.out_quant_arg_.zp_ = -out_quant_args.front().zeroPoint;
quant_params_.output_activation_min_ = std::numeric_limits<int8_t>::min();
quant_params_.output_activation_max_ = std::numeric_limits<int8_t>::max();
const double input_real_multiplier =
MSMIN(quant_params_.in_quant_args_.scale_ * (1 << (unsigned int)(31 - 5)), (1ll << 31) - 1.0);
int right_shift = 0;
QuantizeMultiplierSmallerThanOne(input_real_multiplier, &quant_params_.output_multiplier_, &right_shift);
quant_params_.shift_left_ = right_shift < 0 ? -right_shift : 0;
quant_params_.shift_right_ = right_shift > 0 ? right_shift : 0;
if (!InferShapeDone()) {
return RET_OK;
}
......@@ -72,12 +79,12 @@ int SoftmaxInt8CPUKernel::ReSize() {
return ret;
}
FreeTmpBuffer();
exp_data_ = reinterpret_cast<float *>(malloc(softmax_param_->element_size_ * sizeof(float)));
exp_data_ = reinterpret_cast<int *>(malloc(softmax_param_->element_size_ * sizeof(int)));
int inner_size = 1;
for (int i = softmax_param_->axis_ + 1; i < softmax_param_->n_dim_; i++) {
inner_size *= softmax_param_->input_shape_[i];
}
sum_data_ = reinterpret_cast<float *>(malloc(inner_size * sizeof(float)));
sum_data_ = reinterpret_cast<int *>(malloc(inner_size * sizeof(int)));
return RET_OK;
}
......@@ -125,12 +132,7 @@ int SoftmaxInt8CPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return RET_ERROR;
}
auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->Data());
int ele_size = softmax_param_->element_size_;
for (int i = 0; i < ele_size; i++) {
float input_scaled = ((input_ptr[i] - quant_params_.in_quant_args_.zp_) * quant_params_.in_quant_args_.scale_);
exp_data_[i] = exp(input_scaled);
}
int error_code = LiteBackendParallelLaunch(SoftmaxRun, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Softmax function error error_code[" << error_code << "]";
......
......@@ -37,8 +37,8 @@ class SoftmaxInt8CPUKernel : public SoftmaxBaseCPUKernel {
private:
void FreeTmpBuffer();
float *sum_data_ = nullptr;
float *exp_data_ = nullptr;
int *sum_data_ = nullptr;
int *exp_data_ = nullptr;
SoftmaxQuantArg quant_params_;
};
} // namespace mindspore::kernel
......
......@@ -16,17 +16,17 @@
#include "nnacl/int8/softmax_int8.h"
#include <math.h>
#include "nnacl/quantization/fixed_point.h"
#include "nnacl/quantization/quantize.h"
#include "nnacl/errorcode.h"
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter) {
int32_t axis = parameter->axis_;
int n_dim = parameter->n_dim_;
int *input_shape = parameter->input_shape_;
int axis_shape_size = input_shape[axis];
double output_scale = quant_param.out_quant_arg_.scale_;
int32_t output_zp = quant_param.out_quant_arg_.zp_;
int inner_size = 1;
for (int i = axis + 1; i < n_dim; i++) {
inner_size *= input_shape[i];
......@@ -34,22 +34,37 @@ int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *e
for (int o = 0; o < count; o++) {
int outter_offset = o * axis_shape_size * inner_size;
for (int i = 0; i < inner_size; i++) {
float sum = 0;
for (int j = 0; j < axis_shape_size; j++) {
int axis_offset = outter_offset + i + j * inner_size;
sum += exp_data[axis_offset];
for (int c = 0; c < inner_size; c++) {
int8_t max_row = quant_param.output_activation_min_;
for (int i = 0; i < axis_shape_size; ++i) {
int axis_offset = outter_offset + c + i * inner_size;
max_row = MSMAX(max_row, input_ptr[axis_offset]);
}
sum_data[i] = sum;
int32_t exp_sum = 0;
for (int i = 0; i < axis_shape_size; ++i) {
int axis_offset = outter_offset + c + i * inner_size;
const int32_t input_val = input_ptr[axis_offset] - max_row;
const int32_t input_scaled = SaturatingRoundingDoublingHighMul(
input_val * (1 << (unsigned int)quant_param.shift_left_), quant_param.output_multiplier_);
int exp_val = exp_on_negative_values(input_scaled, 5);
exp_data[axis_offset] = exp_val;
exp_sum = exp_sum + Rescale(exp_val, 0, 12);
}
sum_data[c] = exp_sum;
}
for (int j = 0; j < axis_shape_size; j++) {
int axis_offset = outter_offset + j * inner_size;
for (int i = 0; i < inner_size; i++) {
int inner_offset = axis_offset + i;
float real_output = exp_data[inner_offset] / sum_data[i];
int32_t output_scaled = round(real_output / output_scale) + output_zp;
output_ptr[inner_offset] =
MSMAX(quant_param.output_activation_min_, MSMIN(quant_param.output_activation_max_, output_scaled));
for (int i = 0; i < axis_shape_size; ++i) {
int axis_offset = outter_offset + i * inner_size;
for (int c = 0; c < inner_size; ++c) {
int num_bits_over_unit;
int shifted_scale = ComputerReciproal(sum_data[c], 12, &num_bits_over_unit);
int unsat_output = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(shifted_scale, exp_data[axis_offset + c]), num_bits_over_unit + 31 - 8);
int raw_output = unsat_output + quant_param.output_activation_min_;
output_ptr[axis_offset + c] =
(int8_t)MSMAX(quant_param.output_activation_min_, MSMIN(raw_output, quant_param.output_activation_max_));
}
}
}
......
......@@ -24,7 +24,7 @@
#ifdef __cplusplus
extern "C" {
#endif
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, float *exp_data, float *sum_data,
int SoftmaxInt8(const int8_t *input_ptr, int8_t *output_ptr, int count, int *exp_data, int *sum_data,
SoftmaxQuantArg quant_param, SoftmaxParameter *parameter);
#ifdef __cplusplus
}
......
......@@ -86,24 +86,22 @@ int32_t MaskNonZero(int32_t a) {
return a ? BitNot(zreo) : zreo;
}
int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0);
if (ExponentSign == 0) {
return x;
} else if (ExponentSign == 1) {
static inline int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent) {
if (Exponent > 0) {
const int min = INT32_MIN;
const int max = INT32_MAX;
const int thresold = ((1 << (uint32_t)(31 - Exponent)) - 1);
const int scalar_int_bits = 8 * sizeof(int32_t);
const int thresold = ((1 << (uint32_t)(scalar_int_bits - 1 - Exponent)) - 1);
const int postive_mask = MaskNonZero(x > thresold);
const int negative_mask = MaskNonZero(x < -thresold);
int result = x << Exponent;
int result = x * ((int32_t)(1) << (uint32_t)Exponent);
result = SelectUsingMask(postive_mask, max, result);
result = SelectUsingMask(negative_mask, min, result);
return result;
} else if (ExponentSign == -1) {
} else if (Exponent < 0) {
return RoundingDivideByPOT(x, -Exponent);
} else {
return 0;
return x;
}
}
......@@ -113,7 +111,7 @@ int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst) {
return result;
}
static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a) {
int one = FixedPoint_One(0, FractionsBits(0));
int half_denominator = RoundingHalfSum(a, one);
const int constant_48_over_17 = 1515870810;
......@@ -159,6 +157,71 @@ int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift) {
const int32_t shifted_scaled = one_over_one_plus_x_for_x_in_0_1(shifted_minus_one);
return shifted_scaled;
}
int ConstantPOT(int fractional_bits, int exponent) {
int offset = fractional_bits + exponent;
return (1 << (uint32_t)offset);
}
int32_t MaskIfNonZero(int32_t a) { return a ? BitNot(0) : 0; }
int32_t MaskIfZero(int32_t a) { return MaskIfNonZero(!a); }
int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero((a < b)); }
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a) {
const int constant_term = 1895147668;
const int constant_1_over_3 = 715827883;
// We're evaluating a Taylor expansion around -1/8, so we do the change of
// variable: x = a + 1/8.
// In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
int kFractionalBits = FractionsBits(0);
int x = a + ConstantPOT(kFractionalBits, -3);
int x2 = SaturatingRoundingDoublingHighMul(x, x);
int x3 = SaturatingRoundingDoublingHighMul(x2, x);
int x4 = SaturatingRoundingDoublingHighMul(x2, x2);
int x4_over_4 = SaturatingRoundingMultiplyByPOT(x4, -2);
int x4_over_24_plus_x3_over_6_plus_x2_over_2 =
SaturatingRoundingMultiplyByPOT((SaturatingRoundingDoublingHighMul((x4_over_4 + x3), constant_1_over_3) + x2), -1);
return constant_term +
SaturatingRoundingDoublingHighMul(constant_term, (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}
int exp_on_negative_values(int a, const int tIntegerBits) {
int kIntegerBits = tIntegerBits;
int kFractionalBits = FractionsBits(tIntegerBits);
const int kOneQuarter = ConstantPOT(kFractionalBits, -2);
int mask = kOneQuarter - 1;
int a_mod_quarter_minus_one_quarter = ((unsigned)(a)&mask) - kOneQuarter;
int result =
exp_on_interval_between_negative_one_quarter_and_0_excl(Rescale(a_mod_quarter_minus_one_quarter, tIntegerBits, 0));
int remainder = a_mod_quarter_minus_one_quarter - a;
#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \
if (kIntegerBits > Exponent) { \
const int kMultiplier = FixedPointMultiplier; \
int kShiftAmount = kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \
result = SelectUsingMask(MaskIfNonZero(BitAnd(remainder, (1 << (uint32_t)kShiftAmount))), \
SaturatingRoundingDoublingHighMul(result, kMultiplier), result); \
}
GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
#undef GEMMLOWP_EXP_BARREL_SHIFTER
int clampB = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
if (kIntegerBits > 5) {
const int clamp = -(1 << (uint32_t)clampB);
result = SelectUsingMask(MaskIfLessThan(a, clamp), 0, result);
}
result = SelectUsingMask(MaskIfZero(a), FixedPoint_One(0, kFractionalBits), result);
return result;
}
#ifdef ENABLE_NEON
int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) {
const int32x4_t shift_vec = vdupq_n_s32(-exponent);
......
......@@ -60,11 +60,9 @@ int SelectUsingMask(int mask, int bound, int val);
int32_t MaskNonZero(int32_t a);
int SaturatingRoundingMultiplyByPOT(int32_t x, int Exponent);
int32_t Rescale(int x, int kIntegerBitsSrc, int kIntegerBitsDst);
static int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
int32_t one_over_one_plus_x_for_x_in_0_1(int32_t a);
int CountLeadingZeroBits(uint32_t x);
......@@ -72,6 +70,18 @@ int CountLeadingSignBits(int32_t x);
int32_t ComputerReciproal(int32_t x, int x_digits, int *recip_shift);
int exp_on_negative_values(int a, const int tIntegerBits);
int ConstantPOT(int fractional_bits, int exponent);
int32_t MaskIfNonZero(int32_t a);
int32_t MaskIfZero(int32_t a);
int32_t MaskIfLessThan(int32_t a, int32_t b);
int exp_on_interval_between_negative_one_quarter_and_0_excl(int a);
#ifdef __cplusplus
}
#endif
......
......@@ -80,9 +80,8 @@ TEST_F(TestSoftmaxInt8, SoftmaxInt8) {
auto output_tensor_shape = output0_tensor.shape();
kernel->Run();
std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 121, 121, 111, 111,
-127, -127, -127, -127, -59, -59, -61, -59, 57, 57, 59, 57};
std::vector<int8_t> except_result = {-126, -126, -124, -124, -123, -124, -116, -116, 122, 122, 112, 112,
-127, -127, -127, -127, -59, -59, -61, -59, 58, 58, 59, 58};
CompareOutputData(output.data(), except_result.data(), input.size(), 0.000001);
input0_tensor.SetData(nullptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册