Commit 203a4d2e authored by zhanyuan

Fix a bug in matmul_int8's pre-processing

Parent b346f0b3
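
The int8 matmul pre-process computes two correction vectors before the GEMM proper: input_sums (the per-row sums of the input, scaled by the weight zero point) and weight_bias_sums (bias plus the deep*input_zp*weight_zp cross term, minus the input-zero-point-scaled column sums of the weight). The old CalcInputSums/CalcWeightBiasSums always indexed their source buffer as row-major, so whenever MatmulInt8CPUKernel ran with a_transpose_ or b_transpose_ set, the sums were taken over the wrong elements. This commit threads a DataOrder flag through both helpers so each call site declares its buffer layout, and additionally lets AwareQuantizer handle FullConnection nodes alongside Conv2D/DepthwiseConv2D.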
@@ -225,12 +225,15 @@ void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16)
 }
 
 // dst: weight_zp * input_row_sums
-void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
   for (int r = 0; r < row; ++r) {
     int sum = 0;
     for (int c = 0; c < col; ++c) {
-      int src_idx = r * col + c;
-      sum += input[src_idx];
+      if (order == RowMajor) {
+        sum += input[r * col + c];
+      } else {
+        sum += input[c * row + r];
+      }
     }
     sum *= weight_zp;
     dst[r] = sum;
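
As a sanity check on the new flag, here is a minimal standalone sketch. The function body is copied from the hunk above; the test harness around it is hypothetical and not part of this commit. The same logical 2x3 matrix, stored once row-major and once column-major, must yield identical row sums:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

typedef enum DataOrder { RowMajor, ColMajor } DataOrder;

/* Local copy of the patched CalcInputSums, for illustration only. */
void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order) {
  for (int r = 0; r < row; ++r) {
    int sum = 0;
    for (int c = 0; c < col; ++c) {
      if (order == RowMajor) {
        sum += input[r * col + c];
      } else {
        sum += input[c * row + r];
      }
    }
    sum *= weight_zp;
    dst[r] = sum;
  }
}

int main(void) {
  /* Logical 2x3 matrix A = [[1, 2, 3], [4, 5, 6]]. */
  int8_t row_major[] = {1, 2, 3, 4, 5, 6}; /* A stored row by row */
  int8_t col_major[] = {1, 4, 2, 5, 3, 6}; /* A stored column by column, i.e. a transposed buffer */
  int sums_rm[2], sums_cm[2];
  CalcInputSums(row_major, 2, 3, /*weight_zp=*/7, sums_rm, RowMajor);
  CalcInputSums(col_major, 2, 3, /*weight_zp=*/7, sums_cm, ColMajor);
  assert(sums_rm[0] == 7 * (1 + 2 + 3) && sums_rm[0] == sums_cm[0]);
  assert(sums_rm[1] == 7 * (4 + 5 + 6) && sums_rm[1] == sums_cm[1]);
  printf("row sums: %d %d\n", sums_rm[0], sums_rm[1]); /* 42 105 */
  return 0;
}
```

The old code, which always indexed with `r * col + c`, would have summed {1, 4, 2} as the first "row" of the column-major copy; that is exactly the wrong-sum bug being fixed.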
@@ -238,12 +241,16 @@ void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst) {
 }
 
 // dst: bias + depth*input_zp*weight_zp - input_zp*weight_col_sums
-void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst) {
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order) {
   for (int c = 0; c < col; ++c) {
     int sum = 0;
     for (int r = 0; r < row; ++r) {
-      int src_idx = r * col + c;
-      sum += weight[src_idx];
+      if (order == RowMajor) {
+        sum += weight[r * col + c];
+      } else {
+        sum += weight[c * row + r];
+      }
     }
     dst[c] = row * input_zp * weight_zp - input_zp * sum;
     if (bias) {
......
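
The two `dst:` comments above are the standard zero-point decomposition of a quantized matmul: sum_k (a[r][k] - a_zp)(b[k][c] - b_zp) + bias[c] = sum_k a[r][k]*b[k][c] - input_sums[r] + weight_bias_sums[c], where input_sums[r] = weight_zp * sum_k a[r][k] and weight_bias_sums[c] = bias[c] + deep*input_zp*weight_zp - input_zp * sum_k b[k][c]. A numeric spot check of that identity for one output element (all names are local to this sketch; it assumes only the two comment formulas, not the MatmulInt8 internals):

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
  enum { DEEP = 4 };
  const int8_t a[DEEP] = {3, -1, 7, 2}; /* row r of the input */
  const int8_t b[DEEP] = {5, 4, -2, 6}; /* column c of the weight */
  const int a_zp = 3, b_zp = -2, bias = 100;

  int raw_dot = 0, a_sum = 0, b_sum = 0, reference = 0;
  for (int k = 0; k < DEEP; ++k) {
    raw_dot += a[k] * b[k];
    a_sum += a[k];
    b_sum += b[k];
    reference += (a[k] - a_zp) * (b[k] - b_zp); /* true zero-point-corrected product */
  }
  reference += bias;

  int input_sum = b_zp * a_sum;                                    /* CalcInputSums formula      */
  int weight_bias_sum = bias + DEEP * a_zp * b_zp - a_zp * b_sum;  /* CalcWeightBiasSums formula */
  assert(raw_dot - input_sum + weight_bias_sum == reference);
  printf("ok: %d\n", reference); /* 68 */
  return 0;
}
```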
@@ -37,8 +37,9 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col);
 void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16);
 void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16);
-void CalcInputSums(int8_t *a, int row, int col, int b_zp, int *dst);
-void CalcWeightBiasSums(int8_t *b, int row, int col, int a_zp, int b_zp, int *bias, int *dst);
+void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order);
+void CalcWeightBiasSums(int8_t *weight, int row, int col, int input_zp, int weight_zp, int *bias, int *dst,
+                        DataOrder order);
 void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min,
                 int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16,
                 int stride);
......
@@ -55,6 +55,11 @@ typedef enum LiteDataType {
   kDataTypeInt8,
 } LiteDataType;
 
+typedef enum DataOrder {
+  RowMajor,
+  ColMajor,
+} DataOrder;
+
 typedef struct OpParameter {
   char name_[100];
   int type_;
......
@@ -90,7 +90,7 @@ int FullconnectionInt8CPUKernel::ReSize() {
                               quant_params_.output.zp_, quant_params_.output.scale_, &quant_params_.out_act_min,
                               &quant_params_.out_act_max);
   CalcWeightBiasSums(weight_data, fc_param_->deep_, fc_param_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
-                     bias_ptr_, weight_bias_sums_);
+                     bias_ptr_, weight_bias_sums_, ColMajor);
   return RET_OK;
 }
@@ -136,7 +136,7 @@ int FullconnectionInt8CPUKernel::Run() {
   }
   auto input_ptr = reinterpret_cast<int8_t *>(in_tensors_[0]->Data());
   RowMajor2Row4x16Major(input_ptr, fc_param_->row_, fc_param_->deep_, a_r4x16_ptr_, d16_);
-  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_);
+  CalcInputSums(input_ptr, fc_param_->row_, fc_param_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
   ParallelLaunch(THREAD_POOL_DEFAULT, FcInt8Run, this, thread_count_);
   return RET_OK;
 }
......
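
In the fullconnection kernel the layouts are fixed, so the flags are constants: the activation is a row-major row_ x deep_ matrix (RowMajor), while the weight keeps one output channel per row, i.e. it is stored col_ x deep_; viewed as the deep_ x col_ operand of the matmul, that buffer is column-major, hence ColMajor. (The weight layout is inferred from these call sites; the allocation code is outside this diff.)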
@@ -140,18 +140,21 @@ int MatmulInt8CPUKernel::Run() {
     if (params_->a_transpose_) {
       RowMajor2Col16x4Major(cur_a_ptr, params_->deep_, params_->row_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
     } else {
       RowMajor2Row4x16Major(cur_a_ptr, params_->row_, params_->deep_, a_r4x16_ptr_, d16_);
+      CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, quant_params_.weight.zp_, input_sums_, RowMajor);
     }
     if (params_->b_transpose_) {
       RowMajor2Row4x16Major(cur_b_ptr, params_->col_, params_->deep_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, ColMajor);
     } else {
       RowMajor2Col16x4Major(cur_b_ptr, params_->deep_, params_->col_, b_c16x4_ptr_, d16_);
+      CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, quant_params_.input.zp_, quant_params_.weight.zp_,
+                         NULL, weight_bias_sums_, RowMajor);
     }
     c_ptr_ = c_ptr + i * c_stride;
-    auto &q = quant_params_;
-    CalcInputSums(cur_a_ptr, params_->row_, params_->deep_, q.weight.zp_, input_sums_);
-    CalcWeightBiasSums(cur_b_ptr, params_->deep_, params_->col_, q.input.zp_, q.weight.zp_, NULL, weight_bias_sums_);
     ret = ParallelLaunch(THREAD_POOL_DEFAULT, MatmulInt8Run, this, thread_count_);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "MatmulInt8Run error: [" << ret << "]";
......
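
Here the transpose flags pick the order: with a_transpose_ set, cur_a_ptr holds A stored deep_ x row_, so the logical row sums of A walk a stored column (ColMajor); otherwise the buffer is plain row-major. The same reasoning gives ColMajor for a transposed B (stored col_ x deep_) and RowMajor otherwise. The three deleted lines at the bottom of the hunk are the old unconditional calls; they assumed row-major layout regardless of the flags, which was the bug.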
@@ -226,7 +226,8 @@ STATUS AwareQuantizer::DoQuantize() {
     }
     STATUS status;
     if (GetCNodeTType(*node) == schema::PrimitiveType_Conv2D ||
-        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) {
+        GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D ||
+        GetCNodeTType(*node) == schema::PrimitiveType_FullConnection) {
       auto inputIndexes = node->inputIndex;
       if (inputIndexes.size() < 2) {
         MS_LOG(ERROR) << node->name.c_str()
......