diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp
index 6b9bd5c970590d2405c5f58a5f7016be5949a511..5384faf2b8ae0e0fe6aed1b6c0cd7d4d16978ac9 100644
--- a/src/operators/kernel/arm/conv_kernel.cpp
+++ b/src/operators/kernel/arm/conv_kernel.cpp
@@ -55,10 +55,10 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
       param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
     param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
     // transform weight
-    framework::Tensor *transformed_weight = new framework::Tensor;
+    framework::Tensor transformed_weight;
     operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
-                                                     transformed_weight);
-    param->Filter() = transformed_weight;
+                                                     &transformed_weight);
+    framework::TensorCopy(transformed_weight, param->Filter());
 #endif
   } else {
     param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp
index e1b16b2732a1828144cbcd27444e2fe92f28c1ea..4c6e6452c236f0085c4d478d5bc0f06f0455d87b 100644
--- a/src/operators/kernel/arm/quantize_kernel.cpp
+++ b/src/operators/kernel/arm/quantize_kernel.cpp
@@ -20,6 +20,9 @@ limitations under the License. */
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include <arm_neon.h>
 
+namespace paddle_mobile {
+namespace operators {
+
 #ifndef __aarch64__
 inline float32_t vmaxvq_f32(float32x4_t r) {
   float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
@@ -27,9 +30,13 @@ inline float32_t vmaxvq_f32(float32x4_t r) {
 }
 #endif
 
-inline int32x4_t vrnd_towards_zero(float32x4_t r) { return vcvtq_s32_f32(r); }
+template <RoundType R>
+inline int32x4_t vround_f32(float32x4_t r) {
+  return vcvtq_s32_f32(r);
+}
 
-inline int32x4_t vrnd_away_zero(float32x4_t r) {
+template <>
+inline int32x4_t vround_f32<ROUND_NEAREST_AWAY_ZERO>(float32x4_t r) {
   float32x4_t plus = vdupq_n_f32(0.5);
   float32x4_t minus = vdupq_n_f32(-0.5);
   float32x4_t zero = vdupq_n_f32(0);
@@ -40,31 +47,13 @@ inline int32x4_t vrnd_away_zero(float32x4_t r) {
   return ret;
 }
 
-inline int32x4_t vrnd_to_even(float32x4_t r) {
-#if 0
-  int32x4_t ret;
-  float value[4];
-  vst1q_f32(value, r);
-  for (int i = 0; i < 4; ++i) {
-    float v = round(value[i]);
-    int32_t q = (int32_t)v;
-    if (abs(abs(v - value[i]) - 0.5) > 0) {
-      ret[i] = q;
-    } else {
-      if (abs(q) % 2 == 0) {
-        ret[i] = q;
-      } else {
-        ret[i] = q + ((q > 0) ? -1 : 1);
-      }
-    }
-  }
-  return ret;
-#else
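+// ROUND_NEAREST_TO_EVEN ("banker's rounding"): values exactly half-way
+// between two integers are rounded to the nearest even integer, so the
+// quantization error is not biased in either direction.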
+template <>
+inline int32x4_t vround_f32<ROUND_NEAREST_TO_EVEN>(float32x4_t r) {
   float32x4_t point5 = vdupq_n_f32(0.5);
   int32x4_t one = vdupq_n_s32(1);
   int32x4_t zero = vdupq_n_s32(0);
-  int32x4_t rnd = vrnd_away_zero(r);
+  int32x4_t rnd = vround_f32<ROUND_NEAREST_AWAY_ZERO>(r);
   float32x4_t frnd = vcvtq_f32_s32(rnd);
   frnd = vsubq_f32(frnd, r);
   frnd = vabsq_f32(frnd);
@@ -82,117 +71,39 @@ inline int32x4_t vrnd_to_even(float32x4_t r) {
   smask = vsubq_s32(smask, one);
   rnd = vaddq_s32(rnd, smask);
   return rnd;
-#endif
 }
-
-namespace paddle_mobile {
-namespace operators {
-
-static float find_abs_max(const Tensor *input) {
-  float max_abs = 0.f;
-  const float *x = input->data<float>();
-  size_t size = input->numel();
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  size_t loop = size >> 4;
-  size_t remain = size & 0xF;
-  for (size_t i = 0; i < loop; ++i) {
-    float32x4_t max;
-    float32x4_t r0 = vld1q_f32(x);
-    float32x4_t r1 = vld1q_f32(x + 4);
-    float32x4_t r2 = vld1q_f32(x + 8);
-    float32x4_t r3 = vld1q_f32(x + 12);
-    r0 = vabsq_f32(r0);
-    r1 = vabsq_f32(r1);
-    r2 = vabsq_f32(r2);
-    r3 = vabsq_f32(r3);
-    max[0] = vmaxvq_f32(r0);
-    max[1] = vmaxvq_f32(r1);
-    max[2] = vmaxvq_f32(r2);
-    max[3] = vmaxvq_f32(r3);
-    max[0] = vmaxvq_f32(max);
-    if (max[0] > max_abs) {
-      max_abs = max[0];
-    }
-    x += 16;
-  }
-  size = remain;
 #endif
-  for (size_t i = 0; i < size; ++i) {
-    float value = std::abs(x[i]);
-    if (value > max_abs) {
-      max_abs = value;
-    }
-  }
-  return max_abs;
+
+template <RoundType R>
+inline int8_t Round(const float &x) {
+  return static_cast<int8_t>(x);
 }
 
-#ifdef __aarch64__
-static void quantize_round_to_even(const Tensor *input, const float scale,
-                                   const std::vector<int> &paddings,
-                                   const int8_t padding_val, Tensor *output) {
-  const float *x = input->data<float>();
-  int8_t *y = output->mutable_data<int8_t>();
-  size_t size = input->numel();
-#if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  size_t loop = size >> 4;
-  size_t remain = size & 0xF;
+template <>
+inline int8_t Round<ROUND_NEAREST_AWAY_ZERO>(const float &x) {
+  return std::round(x);
+}
 
- #pragma omp parallel for
-  for (size_t i = 0; i < loop; ++i) {
-    const float *local_x = x + (i << 4);
-    int8_t *local_y = y + (i << 4);
-    float32x4_t r0 = vld1q_f32(local_x);
-    float32x4_t r1 = vld1q_f32(local_x + 4);
-    float32x4_t r2 = vld1q_f32(local_x + 8);
-    float32x4_t r3 = vld1q_f32(local_x + 12);
-    r0 = vmulq_n_f32(r0, scale);
-    r1 = vmulq_n_f32(r1, scale);
-    r2 = vmulq_n_f32(r2, scale);
-    r3 = vmulq_n_f32(r3, scale);
-    int32x4_t q0 = vrnd_to_even(r0);
-    int32x4_t q1 = vrnd_to_even(r1);
-    int32x4_t q2 = vrnd_to_even(r2);
-    int32x4_t q3 = vrnd_to_even(r3);
-    int16x4_t d0 = vmovn_s32(q0);
-    int16x4_t d1 = vmovn_s32(q1);
-    int16x4_t d2 = vmovn_s32(q2);
-    int16x4_t d3 = vmovn_s32(q3);
-    int16x8_t q5 = vcombine_s16(d0, d1);
-    int16x8_t q6 = vcombine_s16(d2, d3);
-    int8x8_t d5 = vmovn_s16(q5);
-    int8x8_t d6 = vmovn_s16(q6);
-    vst1_s8(local_y, d5);
-    vst1_s8(local_y + 8, d6);
-  }
-  size = remain;
-  x += (loop << 4);
-  y += (loop << 4);
-#endif
-  for (size_t i = 0; i < size; ++i) {
-    float value = x[i] * scale;
-    float v = round(value);
-    int32_t q = (int32_t)v;
-    if (abs(abs(q - value) - 0.5) > 0) {
-      y[i] = q;
-    } else {
-      if (abs(q) % 2 == 0) {
-        y[i] = q;
-      } else {
-        y[i] = q + ((q > 0) ? -1 : 1);
-      }
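+// scalar counterpart of the vectorized round-to-even above; used for the
+// tail elements that do not fill a whole NEON register (and on non-NEON
+// builds)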
+template <>
+inline int8_t Round<ROUND_NEAREST_TO_EVEN>(const float &x) {
+  float v = std::round(x);
+  int32_t q = static_cast<int32_t>(v);
+  if (abs(abs(q - v) - 0.5) <= 0) {
+    if (abs(q) % 2 != 0) {
+      q = q + ((q > 0) ? -1 : 1);
     }
   }
+  return static_cast<int8_t>(q);
 }
 
-static void quantize_round_to_zero(const Tensor *input, const float scale,
-                                   const std::vector<int> &paddings,
-                                   const int8_t padding_val, Tensor *output) {
+template <RoundType R>
+static void Quantize(const Tensor *input, const float scale, Tensor *output) {
   const float *x = input->data<float>();
   int8_t *y = output->mutable_data<int8_t>();
-  size_t size = input->numel();
+  size_t remain = input->numel();
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  size_t loop = size >> 4;
-  size_t remain = size & 0xF;
+  size_t loop = remain >> 4;
+  remain = remain & 0xF;
 
  #pragma omp parallel for
   for (size_t i = 0; i < loop; ++i) {
@@ -206,10 +117,10 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
     r1 = vmulq_n_f32(r1, scale);
     r2 = vmulq_n_f32(r2, scale);
     r3 = vmulq_n_f32(r3, scale);
-    int32x4_t q0 = vrnd_towards_zero(r0);
-    int32x4_t q1 = vrnd_towards_zero(r1);
-    int32x4_t q2 = vrnd_towards_zero(r2);
-    int32x4_t q3 = vrnd_towards_zero(r3);
+    int32x4_t q0 = vround_f32<R>(r0);
+    int32x4_t q1 = vround_f32<R>(r1);
+    int32x4_t q2 = vround_f32<R>(r2);
+    int32x4_t q3 = vround_f32<R>(r3);
     int16x4_t d0 = vmovn_s32(q0);
     int16x4_t d1 = vmovn_s32(q1);
     int16x4_t d2 = vmovn_s32(q2);
@@ -221,563 +132,44 @@ static void quantize_round_to_zero(const Tensor *input, const float scale,
     vst1_s8(local_y, d5);
     vst1_s8(local_y + 8, d6);
   }
-  size = remain;
   x += (loop << 4);
   y += (loop << 4);
 #endif
-  for (size_t i = 0; i < size; ++i) {
-    y[i] = static_cast<int8_t>(x[i] * scale);
+  for (size_t i = 0; i < remain; ++i) {
+    y[i] = Round<R>(x[i] * scale);
   }
 }
 
-static void quantize_round_to_nearest(const Tensor *input, const float scale,
-                                      const std::vector<int> &paddings,
-                                      const int8_t padding_val,
-                                      Tensor *output) {
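+// find_abs_max keeps four partial maxima in NEON registers per iteration and
+// performs a single cross-lane reduction (vmaxvq_f32) after the loop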
+float find_abs_max(const Tensor *input) {
+  float max_abs = 0.f;
   const float *x = input->data<float>();
-  int8_t *y = output->mutable_data<int8_t>();
-  size_t size = input->numel();
+  size_t remain = input->numel();
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
-  size_t loop = size >> 4;
-  size_t remain = size & 0xF;
+  size_t loop = remain >> 4;
+  remain = remain & 0xF;
+  float32x4_t __max = {0.f, 0.f, 0.f, 0.f};
 
- #pragma omp parallel for
-  for (size_t i = 0; i < loop; ++i) {
-    const float *local_x = x + (i << 4);
-    int8_t *local_y = y + (i << 4);
-    float32x4_t r0 = vld1q_f32(local_x);
-    float32x4_t r1 = vld1q_f32(local_x + 4);
-    float32x4_t r2 = vld1q_f32(local_x + 8);
-    float32x4_t r3 = vld1q_f32(local_x + 12);
-    r0 = vmulq_n_f32(r0, scale);
-    r1 = vmulq_n_f32(r1, scale);
-    r2 = vmulq_n_f32(r2, scale);
-    r3 = vmulq_n_f32(r3, scale);
-    int32x4_t q0 = vrnd_away_zero(r0);
-    int32x4_t q1 = vrnd_away_zero(r1);
-    int32x4_t q2 = vrnd_away_zero(r2);
-    int32x4_t q3 = vrnd_away_zero(r3);
-    int16x4_t d0 = vmovn_s32(q0);
-    int16x4_t d1 = vmovn_s32(q1);
-    int16x4_t d2 = vmovn_s32(q2);
-    int16x4_t d3 = vmovn_s32(q3);
-    int16x8_t q5 = vcombine_s16(d0, d1);
-    int16x8_t q6 = vcombine_s16(d2, d3);
-    int8x8_t d5 = vmovn_s16(q5);
-    int8x8_t d6 = vmovn_s16(q6);
-    vst1_s8(local_y, d5);
-    vst1_s8(local_y + 8, d6);
+  for (size_t i = 0; i < loop; ++i, x += 16) {
+    float32x4_t r0 = vld1q_f32(x);
+    float32x4_t r1 = vld1q_f32(x + 4);
+    float32x4_t r2 = vld1q_f32(x + 8);
+    float32x4_t r3 = vld1q_f32(x + 12);
+    r0 = vabsq_f32(r0);
+    r1 = vabsq_f32(r1);
+    r2 = vabsq_f32(r2);
+    r3 = vabsq_f32(r3);
+    r0 = vmaxq_f32(r0, r1);
+    r1 = vmaxq_f32(r2, r3);
+    r0 = vmaxq_f32(r0, r1);
+    __max = vmaxq_f32(r0, __max);
   }
-  size = remain;
-  x += (loop << 4);
-  y += (loop << 4);
+  max_abs = vmaxvq_f32(__max);
 #endif
-  for (size_t i = 0; i < size; ++i) {
-    y[i] = round(x[i] * scale);
-  }
-}
-#else  // __aarch64__
-
-static void quantize_round_to_even(const Tensor *input, const float scale,
-                                   const std::vector<int> &paddings,
-                                   const int8_t padding_val, Tensor *output) {}
-
-static void quantize_round_to_nearest(const Tensor *input, const float scale,
-                                      const std::vector<int> &paddings,
-                                      const int8_t padding_val,
-                                      Tensor *output) {}
-
-static void quantize_round_to_zero(const Tensor *input, const float scale,
-                                   const std::vector<int> &paddings,
-                                   const int8_t padding_val, Tensor *output) {
-  int channels = input->dims()[1];
-  int input_h = input->dims()[2];
-  int input_w = input->dims()[3];
-  int output_h = output->dims()[2];
-  int output_w = output->dims()[3];
-  int input_spatial_size = input_h * input_w;
-  int output_spatial_size = output_h * output_w;
-  const float *x = input->data<float>();
-  int8_t *y = output->mutable_data<int8_t>();
-  // valid area start
-  int start = paddings[0] * output_w + paddings[1];
-
-  for (int batch = 0; batch < input->dims()[0]; ++batch) {
-    #pragma omp parallel for
-    for (int c = 0; c < channels - 3; c += 4) {
-      const float *input0 = x + (batch * channels + c) * input_spatial_size;
-      const float *input1 = input0 + input_spatial_size;
-      const float *input2 = input1 + input_spatial_size;
-      const float *input3 = input2 + input_spatial_size;
-      size_t offset = (batch * channels + c) * output_spatial_size;
-      for (int h = 0; h < 2; ++h) {
-        int8_t *y0 =
-            y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]);
-        int8_t *y1 = y0 + output_spatial_size;
-        int8_t *y2 = y1 + output_spatial_size;
-        int8_t *y3 = y2 + output_spatial_size;
-        int loop = start >> 4;
-        int remain = start & 0xF;
-        asm volatile(
-            "vdup.s8 q0, %[val] \n"
-            "cmp %[loop], #0 \n"
-            "ble start_remain_%= \n"
-
-            "store_16w_%=: \n"
-            "vst1.32 {q0}, [%[y0]]! \n"
-            "vst1.32 {q0}, [%[y1]]! \n"
-            "vst1.32 {q0}, [%[y2]]! \n"
-            "vst1.32 {q0}, [%[y3]]! \n"
-            "subs %[loop], #1 \n"
-            "bne store_16w_%= \n"
-
-            "start_remain_%=: \n"
-            "cmp %[remain], #8 \n"
-            "blt store_4w_%= \n"
-            "vst1.32 {d0}, [%[y0]]! \n"
-            "vst1.32 {d0}, [%[y1]]! \n"
-            "vst1.32 {d0}, [%[y2]]! \n"
-            "vst1.32 {d0}, [%[y3]]! \n"
-            "sub %[remain], #8 \n"
-
-            "store_4w_%=: \n"
-            "cmp %[remain], #4 \n"
-            "blt store_2w_%= \n"
-            "vst1.32 {d0[0]}, [%[y0]]! \n"
-            "vst1.32 {d0[0]}, [%[y1]]! \n"
-            "vst1.32 {d0[0]}, [%[y2]]! \n"
-            "vst1.32 {d0[0]}, [%[y3]]! \n"
-            "sub %[remain], #4 \n"
-
-            "store_2w_%=: \n"
-            "cmp %[remain], #4 \n"
-            "blt store_1w_%= \n"
-            "vst1.16 {d0[0]}, [%[y0]]! \n"
-            "vst1.16 {d0[0]}, [%[y1]]! \n"
-            "vst1.16 {d0[0]}, [%[y2]]! \n"
-            "vst1.16 {d0[0]}, [%[y3]]! \n"
-            "sub %[remain], #2 \n"
-
-            "store_1w_%=: \n"
-            "cmp %[remain], #1 \n"
-            "blt end_%= \n"
-            "vst1.8 {d0[0]}, [%[y0]]! \n"
-            "vst1.8 {d0[0]}, [%[y1]]! \n"
-            "vst1.8 {d0[0]}, [%[y2]]! \n"
-            "vst1.8 {d0[0]}, [%[y3]]! 
\n" - "end_%=: \n" - : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [loop] "+r"(loop), [remain] "+r"(remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0"); - } - // quantize valid area - int8_t *y0 = y + offset + start; - int8_t *y1 = y0 + output_spatial_size; - int8_t *y2 = y1 + output_spatial_size; - int8_t *y3 = y2 + output_spatial_size; - for (int h = 0; h < input_h; ++h) { - const float *x0 = input0 + h * input_w; - const float *x1 = input1 + h * input_w; - const float *x2 = input2 + h * input_w; - const float *x3 = input3 + h * input_w; - int loop = input_w >> 4; - int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 - int pad_remain = (paddings[1] << 1) & 0x3; - int remain_steps = remain; - asm volatile( - "vdup.f32 q0, %[scale] \n" - "cmp %[loop], #0 \n" - "ble quantize_remain_%= \n" - - "loop_quantize_%=: \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d18, q1 \n" - "vmovn.s16 d20, q2 \n" - "vmovn.s16 d22, q3 \n" - "vmovn.s16 d24, q4 \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d19, q1 \n" - "vmovn.s16 d21, q2 \n" - "vmovn.s16 d23, q3 \n" - "vmovn.s16 d25, q4 \n" - "vst1.32 {q9}, [%[y0]]! \n" - "vst1.32 {q10}, [%[y1]]! \n" - "vst1.32 {q11}, [%[y2]]! \n" - "vst1.32 {q12}, [%[y3]]! \n" - - "subs %[loop], #1 \n" - "bne loop_quantize_%= \n" - - "quantize_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble end_%= \n" - - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vld1.32 {q3, q4}, [%[x1]]! \n" - "vld1.32 {q5, q6}, [%[x2]]! \n" - "vld1.32 {q7, q8}, [%[x3]]! 
\n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d18, q1 \n" - "vmovn.s16 d20, q2 \n" - "vmovn.s16 d22, q3 \n" - "vmovn.s16 d24, q4 \n" - "vld1.32 {q1, q2}, [%[x0]] \n" - "vld1.32 {q3, q4}, [%[x1]] \n" - "vld1.32 {q5, q6}, [%[x2]] \n" - "vld1.32 {q7, q8}, [%[x3]] \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vmul.f32 q5, q5, q0 \n" - "vmul.f32 q6, q6, q0 \n" - "vmul.f32 q7, q7, q0 \n" - "vmul.f32 q8, q8, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vcvt.s32.f32 q3, q3 \n" - "vcvt.s32.f32 q4, q4 \n" - "vcvt.s32.f32 q5, q5 \n" - "vcvt.s32.f32 q6, q6 \n" - "vcvt.s32.f32 q7, q7 \n" - "vcvt.s32.f32 q8, q8 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s32 d4, q3 \n" - "vmovn.s32 d5, q4 \n" - "vmovn.s32 d6, q5 \n" - "vmovn.s32 d7, q6 \n" - "vmovn.s32 d8, q7 \n" - "vmovn.s32 d9, q8 \n" - "vmovn.s16 d19, q1 \n" - "vmovn.s16 d21, q2 \n" - "vmovn.s16 d23, q3 \n" - "vmovn.s16 d25, q4 \n" - - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d18}, [%[y0]]! \n" - "vst1.32 {d20}, [%[y1]]! \n" - "vst1.32 {d22}, [%[y2]]! \n" - "vst1.32 {d24}, [%[y3]]! \n" - "vmov.32 d18, d19 \n" - "vmov.32 d20, d21 \n" - "vmov.32 d22, d23 \n" - "vmov.32 d24, d25 \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d18[0]}, [%[y0]]! \n" - "vst1.32 {d20[0]}, [%[y1]]! \n" - "vst1.32 {d22[0]}, [%[y2]]! \n" - "vst1.32 {d24[0]}, [%[y3]]! \n" - "vext.32 d18, d18, d18, #1 \n" - "vext.32 d20, d20, d20, #1 \n" - "vext.32 d22, d22, d22, #1 \n" - "vext.32 d24, d24, d24, #1 \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1w_%= \n" - "vst1.16 {d18[0]}, [%[y0]]! \n" - "vst1.16 {d20[0]}, [%[y1]]! \n" - "vst1.16 {d22[0]}, [%[y2]]! \n" - "vst1.16 {d24[0]}, [%[y3]]! \n" - "vext.16 d18, d18, d18, #1 \n" - "vext.16 d20, d20, d20, #1 \n" - "vext.16 d22, d22, d22, #1 \n" - "vext.16 d24, d24, d24, #1 \n" - "sub %[remain], #2 \n" - - "store_1w_%=:" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d18[0]}, [%[y0]]! \n" - "vst1.8 {d20[0]}, [%[y1]]! \n" - "vst1.8 {d22[0]}, [%[y2]]! \n" - "vst1.8 {d24[0]}, [%[y3]]! \n" - - "end_%=: \n" - : [x0] "+r"(x0), [x1] "+r"(x1), [x2] "+r"(x2), [x3] "+r"(x3), - [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [loop] "+r"(loop), [remain] "+r"(remain) - : [scale] "r"(scale) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12"); - asm volatile( - "vdup.s8 d0, %[val] \n" - "cmp %[pad_loop], #0 \n" - "ble store_pad_2w_%= \n" - "loop_pad_4w_%=: \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "vst1.32 {d0[0]}, [%[y1]]! \n" - "vst1.32 {d0[0]}, [%[y2]]! \n" - "vst1.32 {d0[0]}, [%[y3]]! \n" - "subs %[pad_loop], #1 \n" - "bne loop_pad_4w_%= \n" - - "store_pad_2w_%=: \n" - "cmp %[pad_remain], #2 \n" - "blt store_pad_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! 
\n" - "vst1.16 {d0[0]}, [%[y1]]! \n" - "vst1.16 {d0[0]}, [%[y2]]! \n" - "vst1.16 {d0[0]}, [%[y3]]! \n" - "sub %[pad_remain], #2 \n" - - "store_pad_1w_%=: \n" - "cmp %[pad_remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "vst1.8 {d0[0]}, [%[y1]]! \n" - "vst1.8 {d0[0]}, [%[y2]]! \n" - "vst1.8 {d0[0]}, [%[y3]]! \n" - "end_%=: \n" - : [y0] "+r"(y0), [y1] "+r"(y1), [y2] "+r"(y2), [y3] "+r"(y3), - [pad_loop] "+r"(pad_loop), [pad_remain] "+r"(pad_remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12"); - } - } - for (int c = (channels & 0xFFFC); c < channels; ++c) { - const float *input0 = x + (batch * channels + c) * input_spatial_size; - size_t offset = (batch * channels + c) * output_spatial_size; - for (int h = 0; h < 2; ++h) { - int8_t *y0 = - y + offset + h * ((input_h + paddings[0]) * output_w - paddings[1]); - int loop = start >> 4; - int remain = start & 0xF; - asm volatile( - "vdup.s8 q0, %[val] \n" - "cmp %[loop], #0 \n" - "ble start_remain_%= \n" - - "store_16w_%=: \n" - "vst1.32 {q0}, [%[y0]]! \n" - "subs %[loop], #1 \n" - "bne store_16w_%= \n" - - "start_remain_%=: \n" - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d0}, [%[y0]]! \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! \n" - "sub %[remain], #2 \n" - - "store_1w_%=: \n" - "cmp %[remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "end_%=: \n" - : [y0] "+r"(y0), [loop] "+r"(loop), [remain] "+r"(remain) - : [val] "r"(padding_val) - : "cc", "memory", "q0"); - } - // quantize valid area - int8_t *y0 = y + offset + start; - for (int h = 0; h < input_h; ++h) { - const float *x0 = input0 + h * input_w; - int loop = input_w >> 4; - int remain = input_w & 0xF; - int pad_loop = paddings[1] >> 1; // (paddings[1] << 1) >> 2 - int pad_remain = (paddings[1] << 1) & 0x3; - asm volatile( - "vdup.f32 q0, %[scale] \n" - "cmp %[loop], #0 \n" - "ble quantize_remain_%= \n" - - "loop_quantize_%=: \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d18, q1 \n" - "vld1.32 {q1, q2}, [%[x0]]! \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d19, q1 \n" - "vst1.32 {q9}, [%[y0]]! \n" - - "subs %[loop], #1 \n" - "bne loop_quantize_%= \n" - - "quantize_remain_%=: \n" - "cmp %[remain], #0 \n" - "ble start_pad_%= \n" - - "vldm %[x0], {d2-d9} \n" - "vmul.f32 q1, q1, q0 \n" - "vmul.f32 q2, q2, q0 \n" - "vcvt.s32.f32 q1, q1 \n" - "vcvt.s32.f32 q2, q2 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d18, q1 \n" - "vmul.f32 q3, q3, q0 \n" - "vmul.f32 q4, q4, q0 \n" - "vcvt.s32.f32 q1, q3 \n" - "vcvt.s32.f32 q2, q4 \n" - "vmovn.s32 d2, q1 \n" - "vmovn.s32 d3, q2 \n" - "vmovn.s16 d19, q1 \n" - - "cmp %[remain], #8 \n" - "blt store_4w_%= \n" - "vst1.32 {d18}, [%[y0]]! \n" - "vmov.32 d18, d19 \n" - "sub %[remain], #8 \n" - - "store_4w_%=: \n" - "cmp %[remain], #4 \n" - "blt store_2w_%= \n" - "vst1.32 {d18[0]}, [%[y0]]! 
\n" - "vext.32 d18, d18, d18, #1 \n" - "sub %[remain], #4 \n" - - "store_2w_%=: \n" - "cmp %[remain], #2 \n" - "blt store_1w_%= \n" - "vst1.16 {d18[0]}, [%[y0]]! \n" - "vext.16 d18, d18, d18, #1 \n" - "sub %[remain], #2 \n" - - "store_1w_%=:" - "cmp %[remain], #1 \n" - "blt start_pad_%= \n" - "vst1.8 {d18[0]}, [%[y0]]! \n" - - "start_pad_%=: \n" - "vdup.s8 d0, %[val] \n" - "cmp %[pad_loop], #0 \n" - "ble pad_remain_%= \n" - "loop_pad_4w_%=: \n" - "vst1.32 {d0[0]}, [%[y0]]! \n" - "subs %[pad_loop], #1 \n" - "bne loop_pad_4w_%= \n" - - "pad_remain_%=: \n" - "cmp %[pad_remain], #2 \n" - "blt store_pad_1w_%= \n" - "vst1.16 {d0[0]}, [%[y0]]! \n" - "sub %[pad_remain], #2 \n" - - "store_pad_1w_%=: \n" - "cmp %[pad_remain], #1 \n" - "blt end_%= \n" - "vst1.8 {d0[0]}, [%[y0]]! \n" - "end_%=: \n" - : [x0] "+r"(x0), [y0] "+r"(y0), [loop] "+r"(loop), - [remain] "+r"(remain), [pad_loop] "+r"(pad_loop), - [pad_remain] "+r"(pad_remain) - : [scale] "r"(scale), [val] "r"(padding_val) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q9"); - } - } + for (size_t i = 0; i < remain; ++i) { + max_abs = std::max(max_abs, std::abs(x[i])); } + return max_abs; } -#endif // __aarch64__ -#endif // ARM_NEON template <> bool QuantizeKernel::Init(QuantizeParam *param) { @@ -799,19 +191,15 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { // only support int8 currently float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; - const auto &paddings = param.paddings_; - // std::vector paddings = {0, 0}; - // const auto padding_val = param.padding_val_; - int8_t padding_val = 0; switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - quantize_round_to_even(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; case ROUND_NEAREST_TOWARDS_ZERO: - quantize_round_to_zero(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; case ROUND_NEAREST_AWAY_ZERO: - quantize_round_to_nearest(input, scale, paddings, padding_val, output); + Quantize(input, scale, output); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index ce111ed78f7b81affffc646b49a00e6d15cbb697..00cb4dfb045a8f67e5e0d33fe3faf794a2e0cac1 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -170,31 +170,21 @@ template inline void DepthwiseConv3x3(const ConvParam ¶m) { const Tensor *input = param.Input(); const Tensor *filter = param.Filter(); + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = input->dims()[0]; Tensor *output = param.Output(); output->mutable_data(); - const std::vector &paddings = param.Paddings(); - const std::vector &strides = param.Strides(); - const int batch_size = static_cast(input->dims()[0]); - Tensor input_pad; - math::PadFunctor pad; for (int i = 0; i < batch_size; i++) { Tensor in_batch = input->Slice(i, i + 1); Tensor out_batch = output->Slice(i, i + 1); - if (paddings[0] || paddings[1]) { - framework::DDim pad_shape = in_batch.dims(); - pad_shape[2] += 2 * paddings[0]; - pad_shape[3] += 2 * paddings[1]; - input_pad.mutable_data(pad_shape); - pad(in_batch, paddings[0], paddings[0], paddings[1], paddings[1], - &input_pad); - } else { - input_pad = in_batch; - } if (strides[0] == 1) { - math::DepthwiseConv3x3s1(input_pad, *filter, &out_batch); + 
+      math::DepthwiseConv3x3s1<Itype, Otype>(in_batch, *filter, paddings,
+                                             &out_batch);
     } else if (strides[0] == 2) {
-      math::DepthwiseConv3x3s2<Itype, Otype>(input_pad, *filter, &out_batch);
+      math::DepthwiseConv3x3s2<Itype, Otype>(in_batch, *filter, paddings,
+                                             &out_batch);
     } else {
       // math::DepthwiseConv3x3<Itype, Otype>(input_pad, *filter,
       //                                      &out_batch);
diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp
index e74659ab4f0cd86c5a6f742a8313bbfb06dc51d3..a4466a52fac228812e8df205a61bdb594775d327 100644
--- a/src/operators/math/depthwise_conv3x3.cpp
+++ b/src/operators/math/depthwise_conv3x3.cpp
@@ -1278,7 +1278,10 @@ void DepthwiseConv3x3s2p1v2(const framework::Tensor *input,
   const float *input_data = input->data<float>();
   const float *filter_data = filter->data<float>();
   float *output_data = output->data<float>();
-  const float *bias_data = bias->data<float>();
+  const float *bias_data = nullptr;
+  if (if_bias) {
+    bias_data = bias->data<float>();
+  }
 
   const int in_h = static_cast<int>(input->dims()[2]);
   const int in_w = static_cast<int>(input->dims()[3]);
diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h
index 34e68e42664a65f9203a30562c2780210c05a42e..35d6c7d3f0cabc4e25fb22388349a7b45f93fc64 100644
--- a/src/operators/math/depthwise_conv3x3.h
+++ b/src/operators/math/depthwise_conv3x3.h
@@ -70,16 +70,19 @@ void DepthwiseConv3x3s2p0(const framework::Tensor *input,
 // void DepthwiseConv3x3(const framework::Tensor *input,
 //                       const framework::Tensor *filter,
 //                       const std::vector<int> &strides,
+//                       const std::vector<int> &paddings,
 //                       framework::Tensor *output);
 
 template <typename Itype, typename Otype>
 void DepthwiseConv3x3s1(const framework::Tensor &input,
                         const framework::Tensor &filter,
+                        const std::vector<int> &paddings,
                         framework::Tensor *output);
 
 template <typename Itype, typename Otype>
 void DepthwiseConv3x3s2(const framework::Tensor &input,
                         const framework::Tensor &filter,
+                        const std::vector<int> &paddings,
                         framework::Tensor *output);
 
 }  // namespace math
diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp
index ddd8f79f7ce350e048585917f96d82639d4ea951..38081ea6bb8fb092ba7a5cbda8f63bd287ffea2a 100644
--- a/src/operators/math/depthwise_conv3x3_int8.cpp
+++ b/src/operators/math/depthwise_conv3x3_int8.cpp
@@ -29,6 +29,7 @@ namespace math {
 template <>
 void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
                                          const framework::Tensor &filter,
+                                         const std::vector<int> &paddings,
                                          framework::Tensor *output) {
   const int8_t *input_data = input.data<int8_t>();
   const int8_t *filter_data = filter.data<int8_t>();
@@ -751,6 +752,7 @@ void DepthwiseConv3x3s1<int8_t, int32_t>(const framework::Tensor &input,
 template <>
 void DepthwiseConv3x3s2<int8_t, int32_t>(const framework::Tensor &input,
                                          const framework::Tensor &filter,
+                                         const std::vector<int> &paddings,
                                          framework::Tensor *output) {
   const int8_t *input_data = input.data<int8_t>();
   const int8_t *filter_data = filter.data<int8_t>();
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 7be35f81f2e7052e32a93531c325d716ed81c2ec..b1c3028fb089894e641bde4d015b13b5dc351db2 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -405,9 +405,9 @@ class ConvParam : public OpParam {
 
   const RType *Input() const { return input_; }
 
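+  // Init() copies transformed Winograd weights back into the existing
+  // filter tensor rather than rebinding the pointer, so value-returning
+  // getters suffice here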
-  RType *&Filter() const { return filter_; }
+  RType *Filter() const { return filter_; }
 
-  RType *&Output() const { return output_; }
+  RType *Output() const { return output_; }
 
   const vector<int> &Strides() const { return strides_; }
 
@@ -441,8 +441,8 @@
 
  private:
   RType *input_;
-  mutable RType *output_;
-  mutable RType *filter_;
+  RType *output_;
+  RType *filter_;
   vector<int> strides_;
   vector<int> paddings_;
   vector<int> dilations_;
diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp
index 9988661bcb898daa5e79b6d22d65d90cfa03c668..50c0e7bd05da7f7a5ee1fd6912be0eff2f6e2958 100644
--- a/test/operators/test_quantize_op.cpp
+++ b/test/operators/test_quantize_op.cpp
@@ -44,25 +44,19 @@ struct Round {
 template <>
 struct Round<ROUND_NEAREST_TO_EVEN> {
   int8_t operator()(float x) {
-    int8_t ret = 0;
     float v = std::round(x);
-    int32_t q = (int32_t)v;
-    if (abs(abs(q - x) - 0.5) > 0) {
-      ret = q;
-    } else {
-      if (abs(q) % 2 == 0) {
-        ret = q;
-      } else {
-        ret = q + ((q > 0) ? -1 : 1);
+    int32_t q = static_cast<int32_t>(v);
+    if (abs(abs(q - v) - 0.5) <= 0) {
+      if (abs(q) % 2 != 0) {
+        q = q + ((q > 0) ? -1 : 1);
       }
     }
-    return ret;
+    return static_cast<int8_t>(q);
   }
 };
 
 template <RoundType T>
-static void quantize(const Tensor *input, const float scale, const int pad,
-                     const int8_t pad_val, Tensor *output) {
+static void quantize(const Tensor *input, const float scale, Tensor *output) {
   int batch_size = input->dims()[0];
   int channels = input->dims()[1];
   int input_h = input->dims()[2];
@@ -77,29 +71,9 @@ static void quantize(const Tensor *input, const float scale, const int pad,
   for (int nc = 0; nc < batch_size * channels; ++nc) {
     const float *xh = x + nc * input_spatial;
     int8_t *yh = y + nc * output_spatial;
-    // pad top
-    for (int h = 0; h < pad; ++h, yh += output_w) {
-      for (int w = 0; w < output_w; ++w) {
-        yh[w] = pad_val;
-      }
-    }
     for (int h = 0; h < input_h; ++h, yh += output_w, xh += input_w) {
-      // pad left
-      for (int w = 0; w < pad; ++w) {
-        yh[w] = pad_val;
-      }
       for (int w = 0; w < input_w; ++w) {
-        yh[w + pad] = Round<T>()(xh[w] * scale);
-      }
-      // pad right
-      for (int w = 0; w < pad; ++w) {
-        yh[pad + input_w + w] = pad_val;
-      }
-    }
-    // pad bottom
-    for (int h = 0; h < pad; ++h, yh += output_w) {
-      for (int w = 0; w < output_w; ++w) {
-        yh[w] = pad_val;
+        yh[w] = Round<T>()(xh[w] * scale);
       }
     }
   }
@@ -120,19 +94,14 @@ static float find_abs_max(const Tensor *input) {
 
 int TestQuqntizeOp(int argc, char *argv[]) {
   if (argc < 5) {
-    std::cout
-        << "Usage: ./test-quantize-op batch_size channel height width [pad]"
-        << std::endl;
+    std::cout << "Usage: ./test-quantize-op batch_size channel height width"
+              << std::endl;
     return 1;
   }
-  int pad = 0;
   int batch_size = atoi(argv[1]);
   int channel = atoi(argv[2]);
   int height = atoi(argv[3]);
   int width = atoi(argv[4]);
-  if (argc == 6) {
-    pad = atoi(argv[5]);
-  }
   std::cout << "batch_size: " << batch_size << ", channel: " << channel
             << ", height: " << height << ", width: " << width << std::endl;
   framework::DDim dim =
@@ -153,7 +122,6 @@ int TestQuqntizeOp(int argc, char *argv[]) {
   auto output_scale_var = scope.get()->Var("output_scale");
 
   framework::AttributeMap attrs;
-  attrs["paddings"].Set<std::vector<int>>(std::vector<int>({pad, pad}));
   auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
                                                    attrs, scope);
   op->InferShape();
@@ -172,9 +140,9 @@ int TestQuqntizeOp(int argc, char *argv[]) {
   framework::Tensor output_cmp;
   output_cmp.Resize(output->dims());
   float scale = 127 / output_scale_cmp;
-  // quantize<ROUND_NEAREST_TO_EVEN>(input, scale, pad, 0, &output_cmp);
-  // quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, pad, 0, &output_cmp);
-  quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, pad, 0, &output_cmp);
+  // quantize<ROUND_NEAREST_TO_EVEN>(input, scale, &output_cmp);
+  // quantize<ROUND_NEAREST_AWAY_ZERO>(input, scale, &output_cmp);
+  quantize<ROUND_NEAREST_TOWARDS_ZERO>(input, scale, &output_cmp);
   int8_t *output_cmp_data = output_cmp.data<int8_t>();
   for (int i = 0; i < output->numel(); ++i) {
     PADDLE_MOBILE_ENFORCE(output_data[i] == output_cmp_data[i],