提交 a86b78b9 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4085 fix bug of winograd

Merge pull request !4085 from fuzhiye/tmp
......@@ -130,8 +130,9 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() {
memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size);
/*=============================tmp_out_============================*/
size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ *
tile_num * sizeof(float16_t);
int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM;
size_t tmp_out_size =
oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t);
tmp_out_ = reinterpret_cast<float16_t *>(malloc(tmp_out_size));
if (tmp_out_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_out_ failed.";
......@@ -278,7 +279,7 @@ int Convolution3x3FP16CPUKernel::Run() {
auto out_tensor = outputs_.at(kOutputIndex);
auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
for (int j = 0; j < out_tensor->ElementsNum(); ++j) {
output_addr[j] = (reinterpret_cast<float *>(fp16_out_))[j];
output_addr[j] = static_cast<float >(fp16_out_[j]);
}
return RET_OK;
}
......
......@@ -77,7 +77,6 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w
int oc8_block = j / 8;
int oc8_res = j % 8;
int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res;
// todo nc4hw4 -> nhwc
int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res;
for (int n = 0; n < step; n++) {
......@@ -169,6 +168,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int thread_count = conv_param->thread_num_;
int tile_num = 16;
int output_unit = 4;
int k_plane = 36;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
......@@ -181,6 +181,9 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int out_h_block = UP_DIV(conv_param->output_h_, C4NUM);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM;
int block_unit_buffer_offset = k_plane * C4NUM;
int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
......@@ -188,14 +191,16 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16
int start_index = thread_id * tile_num;
int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
Conv3x3Fp16InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block,
conv_param);
Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
IndirectGemmFp16_16x8(tmp_dst_buffer, tile_buffer, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset,
tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM,
oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0);
Conv3x3Fp16OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block,
conv_param);
Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
real_cal_num, out_w_block, conv_param);
}
}
......
......@@ -207,7 +207,7 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
int real_y_start = origin_y > 0 ? 0 : -origin_y;
int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y);
int src_plane_offset = input_channel * (origin_y * input_width + origin_x);
int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x);
int dst_plane_offset = cal_id * C4NUM;
for (int ic = 0; ic < ic4; ic++) {
// clear tmp buffer
......@@ -216,10 +216,10 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp
// get real input block with padding
int src_ic4_offset = src_plane_offset + ic * C4NUM;
for (int interval = real_y_start; interval < real_y_end; interval++) {
int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel;
int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM;
int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM;
for (int j = 0; j < (real_x_end - real_x_start); j++) {
int src_x_offset = src_y_offset + j * input_channel;
int src_x_offset = src_y_offset + j * ic4 * C4NUM;
int dst_x_offset = dst_y_offset + j * C4NUM;
float16_t *src_addr = (float16_t *)(input_data) + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
......@@ -511,7 +511,7 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data,
int output_w = conv_param->output_w_;
int output_h = conv_param->output_h_;
int oc8 = UP_DIV(output_channel, C8NUM);
// todo outputw --> out_w_block * out_unit
for (int i = 0; i < real_cal_num; i++) {
int out_w_index = (start_index + i) % out_w_block;
int out_h_index = (start_index + i) / out_w_block;
......
......@@ -203,19 +203,20 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
// clear tmp buffer before compute
memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
// todo
size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t);
memset(tmp_dst, 0, tmp_dst_size);
size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
if (real_cal_num == tile_n) {
int8_t *gemm_output = output_data + out_offset;
IndirectGemmInt8(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
input_sum, conv_param);
IndirectGemmInt8(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
out_channel, input_sum, conv_param);
} else {
// res part
IndirectGemmInt8(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
input_sum, conv_param);
IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
out_channel, input_sum, conv_param);
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
}
}
......@@ -257,19 +258,20 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
// clear tmp buffer before compute
memset(gemm_input, (int8_t)input_zp, unit_size * tile_n);
int out_offset = thread_id * tile_n * out_channel + out_batch_offset;
// todo
size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t);
memset(tmp_dst, 0, tmp_dst_size);
size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t);
int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_;
memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size);
Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param);
if (real_cal_num == tile_n) {
int8_t *gemm_output = output_data + out_offset;
IndirectGemmInt8Opt(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
input_sum, conv_param, gemm_func);
IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4,
kernel_plane, out_channel, input_sum, conv_param, gemm_func);
} else {
// res part
IndirectGemmInt8Opt(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel,
input_sum, conv_param, gemm_func);
IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane,
out_channel, input_sum, conv_param, gemm_func);
memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel);
}
}
......@@ -290,6 +292,10 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int oc4 = UP_DIV(output_channel, C4NUM);
int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
int block_unit_buffer_offset = 16 * C8NUM;
int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
......@@ -297,13 +303,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
Conv3x3Uint8InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block,
conv_param);
Conv3x3Uint8InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
Conv3x3Uint8Gemm(tmp_dst_buffer, tile_buffer, transed_weight, output_channel, ic8, real_cal_num);
Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, output_channel, ic8, real_cal_num);
Conv3x3Uint8OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block,
conv_param);
Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index,
real_cal_num, out_w_block, conv_param);
}
}
......
......@@ -136,7 +136,7 @@ inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { retu
// Quantize a real value into the int8 domain (widened to int32):
//   q = round(real_value / scale + zp)
// No clamping is performed here — the result is not saturated to [-128, 127].
inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) {
  const float shifted = real_value / scale + zp;
  return round(shifted);
}
inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, int32_t scale, int *mini,
inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini,
int *maxi) {
int32_t min = std::numeric_limits<int8_t>::min();
int32_t max = std::numeric_limits<int8_t>::max();
......
......@@ -584,7 +584,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t01, t02), -0.3), vmulq_n_f32(vaddq_f32(t03, t04), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t05, t06), -0.533333333333));
float32x4_t m04 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t03, t04), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t04, t03), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t05, t06), 0.533333333333));
float32x4_t m05 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 0.03333333), vmulq_n_f32(t02, 0.0222222)),
......@@ -618,7 +618,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t11, t12), -0.3), vmulq_n_f32(vaddq_f32(t13, t14), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t15, t16), -0.533333333333));
float32x4_t m14 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t13, t14), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t14, t13), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t15, t16), 0.533333333333));
float32x4_t m15 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 0.03333333), vmulq_n_f32(t12, 0.0222222)),
......@@ -652,7 +652,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t21, t22), -0.3), vmulq_n_f32(vaddq_f32(t23, t24), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t25, t26), -0.533333333333));
float32x4_t m24 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t23, t24), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t24, t23), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t25, t26), 0.533333333333));
float32x4_t m25 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 0.03333333), vmulq_n_f32(t22, 0.0222222)),
......@@ -686,7 +686,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t31, t32), -0.3), vmulq_n_f32(vaddq_f32(t33, t34), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t35, t36), -0.533333333333));
float32x4_t m34 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t33, t34), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t34, t33), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t35, t36), 0.533333333333));
float32x4_t m35 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 0.03333333), vmulq_n_f32(t32, 0.0222222)),
......@@ -720,7 +720,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t41, t42), -0.3), vmulq_n_f32(vaddq_f32(t43, t44), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t45, t46), -0.533333333333));
float32x4_t m44 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t43, t44), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t44, t43), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t45, t46), 0.533333333333));
float32x4_t m45 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 0.03333333), vmulq_n_f32(t42, 0.0222222)),
......@@ -754,7 +754,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t51, t52), -0.3), vmulq_n_f32(vaddq_f32(t53, t54), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t55, t56), -0.533333333333));
float32x4_t m54 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t53, t54), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t54, t53), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t55, t56), 0.533333333333));
float32x4_t m55 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 0.03333333), vmulq_n_f32(t52, 0.0222222)),
......@@ -788,7 +788,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t61, t62), -0.3), vmulq_n_f32(vaddq_f32(t63, t64), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t65, t66), -0.533333333333));
float32x4_t m64 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t63, t64), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t64, t63), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t65, t66), 0.533333333333));
float32x4_t m65 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 0.03333333), vmulq_n_f32(t62, 0.0222222)),
......@@ -822,7 +822,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step,
vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t71, t72), -0.3), vmulq_n_f32(vaddq_f32(t73, t74), 1.33333333333)),
vmulq_n_f32(vaddq_f32(t75, t76), -0.533333333333));
float32x4_t m74 =
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t73, t74), 1.33333333333)),
vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t74, t73), 1.33333333333)),
vmulq_n_f32(vsubq_f32(t75, t76), 0.533333333333));
float32x4_t m75 =
vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 0.03333333), vmulq_n_f32(t72, 0.0222222)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册