提交 5c8be15d 编写于 作者: S sunsuodong

Fix the output address used by the int8 Add kernel

上级 3b7df4e5
......@@ -39,11 +39,7 @@ int ReluInt8CPUKernel::Init() {
quant_arg_.output_arg.zp_ = output->GetQuantParams().front().zeroPoint;
const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
QuantizeMultiplierSmallerThanOne(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.input_shift_);
int left_shift = -quant_arg_.input_shift_ > 0 ? -quant_arg_.input_shift_ : 0;
quant_arg_.right_shift_ = -quant_arg_.input_shift_ > 0 ? 0 : quant_arg_.input_shift_;
quant_arg_.left_shift_result_ = (1 << left_shift);
QuantizeRoundParameter(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_, &quant_arg_.right_shift_);
return RET_OK;
}
......
......@@ -83,7 +83,7 @@ void AddInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data,
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
vst1_s8(output_data, res_u8_n0);
vst1_s8(output_data + *index, res_u8_n0);
}
}
#endif
......@@ -110,13 +110,8 @@ void AddInt8(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int6
para->output_multiplier_),
para->right_shift_out_) +
para->output_offset_;
if (raw_output > para->output_activation_max_) {
output_data[index] = para->output_activation_max_;
} else if (raw_output < para->output_activation_min_) {
output_data[index] = para->output_activation_min_;
} else {
output_data[index] = (int8_t)raw_output;
}
output_data[index] = (int8_t)MSMAX(para->output_activation_min_, MSMIN(raw_output, para->output_activation_max_));
}
return;
}
......
......@@ -25,9 +25,8 @@ struct ReluQuantArg {
QuantArg input_arg;
QuantArg output_arg;
int input_multiplier_;
int input_shift_;
int left_shift_;
int right_shift_;
int left_shift_result_;
};
inline void ReluInt8(const int8_t *src, int length, int8_t *dst, ReluQuantArg *arg) {
......@@ -38,7 +37,7 @@ inline void ReluInt8(const int8_t *src, int length, int8_t *dst, ReluQuantArg *a
}
const int32_t input_val = src[i] - arg->input_arg.zp_;
const int32_t scaled_input = SaturatingRoundingDoublingHighMul(input_val, arg->input_multiplier_);
const int32_t shifted_input = RoundingDivideByPOT(scaled_input * arg->left_shift_result_, -arg->right_shift_);
const int32_t shifted_input = RoundingDivideByPOT(scaled_input * (1 << arg->left_shift_), -arg->right_shift_);
const int32_t output = shifted_input + arg->output_arg.zp_;
dst[i] = (int8_t)output;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册