diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc index 576c5f37525109cd469497e7d3d3c62f087a148f..6deb79a83cbc1ffa3fdbb4d36d5a49b012a950de 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_3x3_fp16.cc @@ -130,8 +130,9 @@ int Convolution3x3FP16CPUKernel::InitTmpBuffer() { memset(tmp_dst_buffer_, 0, tmp_dst_buffer_size); /*=============================tmp_out_============================*/ - size_t tmp_out_size = oC8 * C8NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * - tile_num * sizeof(float16_t); + int new_out_plane = UP_DIV(conv_param_->output_h_, C4NUM) * UP_DIV(conv_param_->output_w_, C4NUM) * C4NUM * C4NUM; + size_t tmp_out_size = + oC8 * C8NUM * conv_param_->output_batch_ * new_out_plane * sizeof(float16_t); tmp_out_ = reinterpret_cast(malloc(tmp_out_size)); if (tmp_out_ == nullptr) { MS_LOG(ERROR) << "malloc tmp_out_ failed."; @@ -278,7 +279,7 @@ int Convolution3x3FP16CPUKernel::Run() { auto out_tensor = outputs_.at(kOutputIndex); auto output_addr = reinterpret_cast(out_tensor->Data()); for (int j = 0; j < out_tensor->ElementsNum(); ++j) { - output_addr[j] = (reinterpret_cast(fp16_out_))[j]; + output_addr[j] = static_cast(fp16_out_[j]); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc index 0d3da09fe3bcaa42c3478909e55f419e268f81c8..6f4291433e6e1f163c8c794e5e0deae26712a1a3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/conv_fp16.cc @@ -77,7 +77,6 @@ void IndirectGemmFp16_16x8_tmp(float16_t *output, float16_t *input, float16_t *w int oc8_block = j / 8; int oc8_res = j % 8; int weight_oc_offset = oc8_block * 36 * ic4 * C4NUM * 8 + oc8_res; - // todo nc4hw4 -> nhwc int out_oc_offset = output_tile_offset + oc8_block * 36 * C8NUM + oc8_res; for (int n = 0; n < step; n++) { @@ -169,6 +168,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 int thread_count = conv_param->thread_num_; int tile_num = 16; int output_unit = 4; + int k_plane = 36; int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); @@ -181,6 +181,9 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 int out_h_block = UP_DIV(conv_param->output_h_, C4NUM); int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, tile_num); + int tile_buffer_offset = tile_num * k_plane * ic4 * C4NUM; + int block_unit_buffer_offset = k_plane * C4NUM; + int tmp_dst_buffer_offset = tile_num * k_plane * oc8 * C8NUM; int input_batch = conv_param->input_batch_; for (int batch = 0; batch < input_batch; batch++) { @@ -188,14 +191,16 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 int start_index = thread_id * tile_num; int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num; - Conv3x3Fp16InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, - conv_param); + Conv3x3Fp16InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); - IndirectGemmFp16_16x8(tmp_dst_buffer, tile_buffer, transed_weight, NULL, 36, ic4, oc8 * C8NUM, + IndirectGemmFp16_16x8(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, + tile_buffer + task_id * tile_buffer_offset, transed_weight, NULL, 36, ic4, oc8 * C8NUM, oc8 * C8NUM * 36 * sizeof(float16_t), 1, 1, 0, 0); - Conv3x3Fp16OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block, - conv_param); + Conv3x3Fp16OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, + real_cal_num, out_w_block, conv_param); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc index 4e586db1d9b31471789925215dd243c727de2f80..78ec1631032ad540fa2b6cadb6d843d68c09945c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp16/winograd_transform_fp16.cc @@ -207,7 +207,7 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp int real_y_start = origin_y > 0 ? 0 : -origin_y; int real_y_end = (origin_y + 6) < input_height ? 6 : (input_height - origin_y); - int src_plane_offset = input_channel * (origin_y * input_width + origin_x); + int src_plane_offset = ic4 * C4NUM * (origin_y * input_width + origin_x); int dst_plane_offset = cal_id * C4NUM; for (int ic = 0; ic < ic4; ic++) { // clear tmp buffer @@ -216,10 +216,10 @@ void Conv3x3Fp16InputTransform(const float16_t *input_data, float16_t *trans_inp // get real input block with padding int src_ic4_offset = src_plane_offset + ic * C4NUM; for (int interval = real_y_start; interval < real_y_end; interval++) { - int src_y_offset = src_ic4_offset + interval * input_width * input_channel + real_x_start * input_channel; + int src_y_offset = src_ic4_offset + (interval * input_width + real_x_start) * ic4 * C4NUM; int dst_y_offset = interval * 6 * C4NUM + real_x_start * C4NUM; for (int j = 0; j < (real_x_end - real_x_start); j++) { - int src_x_offset = src_y_offset + j * input_channel; + int src_x_offset = src_y_offset + j * ic4 * C4NUM; int dst_x_offset = dst_y_offset + j * C4NUM; float16_t *src_addr = (float16_t *)(input_data) + src_x_offset; float16_t *dst_addr = tmp_data + dst_x_offset; @@ -511,7 +511,7 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, int output_w = conv_param->output_w_; int output_h = conv_param->output_h_; int oc8 = UP_DIV(output_channel, C8NUM); - +// todo outputw --> out_w_block * out_unit for (int i = 0; i < real_cal_num; i++) { int out_w_index = (start_index + i) % out_w_block; int out_h_index = (start_index + i) / out_w_block; diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc index 81c1f8f30de53cbb2598f654f12c0a72d9f3e866..1d32f887975c86e2fe8964a5fe6df2a9ea0bc2fd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/conv_int8.cc @@ -203,19 +203,20 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c // clear tmp buffer before compute memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); int out_offset = thread_id * tile_n * out_channel + out_batch_offset; - // todo - size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t); - memset(tmp_dst, 0, tmp_dst_size); + + size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t); + int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_; + memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size); Im2ColPackUnitInt8(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); if (real_cal_num == tile_n) { int8_t *gemm_output = output_data + out_offset; - IndirectGemmInt8(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, - input_sum, conv_param); + IndirectGemmInt8(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param); } else { // res part - IndirectGemmInt8(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, - input_sum, conv_param); + IndirectGemmInt8(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param); memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); } } @@ -257,19 +258,20 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight // clear tmp buffer before compute memset(gemm_input, (int8_t)input_zp, unit_size * tile_n); int out_offset = thread_id * tile_n * out_channel + out_batch_offset; - // todo - size_t tmp_dst_size = thread_count * tile_n * conv_param->output_channel_ * sizeof(int32_t); - memset(tmp_dst, 0, tmp_dst_size); + + size_t tmp_dst_size = tile_n * conv_param->output_channel_ * sizeof(int32_t); + int tmp_dst_offset = task_id * tile_n * conv_param->output_channel_; + memset(tmp_dst + tmp_dst_offset, 0, tmp_dst_size); Im2ColPackUnitInt8Opt(input_data + in_batch_offset, gemm_input, real_cal_num, start_index, input_sum, conv_param); if (real_cal_num == tile_n) { int8_t *gemm_output = output_data + out_offset; - IndirectGemmInt8Opt(gemm_output, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, - input_sum, conv_param, gemm_func); + IndirectGemmInt8Opt(gemm_output, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, + kernel_plane, out_channel, input_sum, conv_param, gemm_func); } else { // res part - IndirectGemmInt8Opt(tmp_out, tmp_dst, gemm_input, packed_weight, bias_data, ic4, kernel_plane, out_channel, - input_sum, conv_param, gemm_func); + IndirectGemmInt8Opt(tmp_out, tmp_dst + tmp_dst_offset, gemm_input, packed_weight, bias_data, ic4, kernel_plane, + out_channel, input_sum, conv_param, gemm_func); memcpy(output_data + out_offset, tmp_out, real_cal_num * out_channel); } } @@ -290,6 +292,10 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT); int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, TILE_NUM); + int oc4 = UP_DIV(output_channel, C4NUM); + int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM; + int block_unit_buffer_offset = 16 * C8NUM; + int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM; int input_batch = conv_param->input_batch_; for (int batch = 0; batch < input_batch; batch++) { @@ -297,13 +303,15 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi int start_index = thread_id * TILE_NUM; int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM; - Conv3x3Uint8InputTransform(input_data, tile_buffer, block_unit_buffer, start_index, real_cal_num, out_w_block, - conv_param); + Conv3x3Uint8InputTransform(input_data, tile_buffer + task_id * tile_buffer_offset, + block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num, + out_w_block, conv_param); - Conv3x3Uint8Gemm(tmp_dst_buffer, tile_buffer, transed_weight, output_channel, ic8, real_cal_num); + Conv3x3Uint8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset, + transed_weight, output_channel, ic8, real_cal_num); - Conv3x3Uint8OutputTransform(tmp_dst_buffer, tmp_out, bias_data, start_index, real_cal_num, out_w_block, - conv_param); + Conv3x3Uint8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out, bias_data, start_index, + real_cal_num, out_w_block, conv_param); } } diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h index 8c6065cf57838b5d8cd2105ab522ce6416a88827..c4fe6984b192e6a89182ffced44b489f70a4c8cc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h @@ -136,7 +136,7 @@ inline uint8_t QuantizeToUint8(float real_value, float scale, int32_t zp) { retu inline int32_t QuantizeToInt8(float real_value, float scale, int32_t zp) { return round(real_value / scale + zp); } -inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, int32_t scale, int *mini, +inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32_t zp, float scale, int *mini, int *maxi) { int32_t min = std::numeric_limits::min(); int32_t max = std::numeric_limits::max(); diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc index 7081c664af23916d65c4c168b3086100774bcc3d..44707c1f66b5095361367e8cfa1d33d7dc353ebe 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/winograd_utils.cc @@ -584,7 +584,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t01, t02), -0.3), vmulq_n_f32(vaddq_f32(t03, t04), 1.33333333333)), vmulq_n_f32(vaddq_f32(t05, t06), -0.533333333333)); float32x4_t m04 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t03, t04), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t01, t02), 0.3), vmulq_n_f32(vsubq_f32(t04, t03), 1.33333333333)), vmulq_n_f32(vsubq_f32(t05, t06), 0.533333333333)); float32x4_t m05 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t01, 0.03333333), vmulq_n_f32(t02, 0.0222222)), @@ -618,7 +618,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t11, t12), -0.3), vmulq_n_f32(vaddq_f32(t13, t14), 1.33333333333)), vmulq_n_f32(vaddq_f32(t15, t16), -0.533333333333)); float32x4_t m14 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t13, t14), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t11, t12), 0.3), vmulq_n_f32(vsubq_f32(t14, t13), 1.33333333333)), vmulq_n_f32(vsubq_f32(t15, t16), 0.533333333333)); float32x4_t m15 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t11, 0.03333333), vmulq_n_f32(t12, 0.0222222)), @@ -652,7 +652,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t21, t22), -0.3), vmulq_n_f32(vaddq_f32(t23, t24), 1.33333333333)), vmulq_n_f32(vaddq_f32(t25, t26), -0.533333333333)); float32x4_t m24 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t23, t24), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t21, t22), 0.3), vmulq_n_f32(vsubq_f32(t24, t23), 1.33333333333)), vmulq_n_f32(vsubq_f32(t25, t26), 0.533333333333)); float32x4_t m25 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t21, 0.03333333), vmulq_n_f32(t22, 0.0222222)), @@ -686,7 +686,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t31, t32), -0.3), vmulq_n_f32(vaddq_f32(t33, t34), 1.33333333333)), vmulq_n_f32(vaddq_f32(t35, t36), -0.533333333333)); float32x4_t m34 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t33, t34), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t31, t32), 0.3), vmulq_n_f32(vsubq_f32(t34, t33), 1.33333333333)), vmulq_n_f32(vsubq_f32(t35, t36), 0.533333333333)); float32x4_t m35 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t31, 0.03333333), vmulq_n_f32(t32, 0.0222222)), @@ -720,7 +720,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t41, t42), -0.3), vmulq_n_f32(vaddq_f32(t43, t44), 1.33333333333)), vmulq_n_f32(vaddq_f32(t45, t46), -0.533333333333)); float32x4_t m44 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t43, t44), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t41, t42), 0.3), vmulq_n_f32(vsubq_f32(t44, t43), 1.33333333333)), vmulq_n_f32(vsubq_f32(t45, t46), 0.533333333333)); float32x4_t m45 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t41, 0.03333333), vmulq_n_f32(t42, 0.0222222)), @@ -754,7 +754,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t51, t52), -0.3), vmulq_n_f32(vaddq_f32(t53, t54), 1.33333333333)), vmulq_n_f32(vaddq_f32(t55, t56), -0.533333333333)); float32x4_t m54 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t53, t54), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t51, t52), 0.3), vmulq_n_f32(vsubq_f32(t54, t53), 1.33333333333)), vmulq_n_f32(vsubq_f32(t55, t56), 0.533333333333)); float32x4_t m55 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t51, 0.03333333), vmulq_n_f32(t52, 0.0222222)), @@ -788,7 +788,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t61, t62), -0.3), vmulq_n_f32(vaddq_f32(t63, t64), 1.33333333333)), vmulq_n_f32(vaddq_f32(t65, t66), -0.533333333333)); float32x4_t m64 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t63, t64), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t61, t62), 0.3), vmulq_n_f32(vsubq_f32(t64, t63), 1.33333333333)), vmulq_n_f32(vsubq_f32(t65, t66), 0.533333333333)); float32x4_t m65 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t61, 0.03333333), vmulq_n_f32(t62, 0.0222222)), @@ -822,7 +822,7 @@ void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, vaddq_f32(vaddq_f32(vmulq_n_f32(vaddq_f32(t71, t72), -0.3), vmulq_n_f32(vaddq_f32(t73, t74), 1.33333333333)), vmulq_n_f32(vaddq_f32(t75, t76), -0.533333333333)); float32x4_t m74 = - vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t73, t74), 1.33333333333)), + vaddq_f32(vaddq_f32(vmulq_n_f32(vsubq_f32(t71, t72), 0.3), vmulq_n_f32(vsubq_f32(t74, t73), 1.33333333333)), vmulq_n_f32(vsubq_f32(t75, t76), 0.533333333333)); float32x4_t m75 = vaddq_f32(vaddq_f32(vsubq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t71, 0.03333333), vmulq_n_f32(t72, 0.0222222)),