From 1ee5fd20a6009a986ae60ade528b88fd6ef38f2c Mon Sep 17 00:00:00 2001
From: liyin
Date: Tue, 12 Mar 2019 14:43:46 +0800
Subject: [PATCH] Regress gemv; support quantizing Gather only

---
 mace/ops/arm/q8/gemv.cc                        | 494 ++++--------
 mace/ops/arm/q8/gemv.h                         |   7 +-
 .../tools/converter_tool/base_converter.py     |   2 +-
 .../tools/converter_tool/transformer.py        |  33 +-
 4 files changed, 134 insertions(+), 402 deletions(-)
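Notes (below the "---" cut, so they stay out of the commit message):

The rewritten kernel keeps its inner loop entirely in unsigned arithmetic: it
accumulates the raw u8*u8 dot product together with the lhs and rhs element
sums, and applies the zero-point correction once per output element, using

    sum((l - zl) * (r - zr)) = sum(l*r) - zr*sum(l) - zl*sum(r) + K*zl*zr

where K is the row width. sum(rhs) is hoisted out and computed once per batch,
and the OpenMP loop now parallelizes over output rows rather than collapse(2)
row blocks. A minimal scalar sketch of the identity (reference arithmetic
only, not the NEON code; names are illustrative):

    def quantized_dot(lhs_row, rhs_col, lhs_zero, rhs_zero):
        # Factorized form used by the kernel: accumulate the raw dot product
        # and the two element sums, then correct for zero points once.
        k = len(lhs_row)
        dot = sum(l * r for l, r in zip(lhs_row, rhs_col))
        sum_lhs = sum(lhs_row)
        sum_rhs = sum(rhs_col)
        return (dot - rhs_zero * sum_lhs - lhs_zero * sum_rhs
                + k * lhs_zero * rhs_zero)

    # Sanity check against the subtract-first form the old code used:
    lhs, rhs, zl, zr = [3, 130, 255], [7, 128, 9], 128, 8
    direct = sum((l - zl) * (r - zr) for l, r in zip(lhs, rhs))
    assert quantized_dot(lhs, rhs, zl, zr) == direct == 492

This is what lets the main loop use vmovl_u8/vmlal_u16 on raw values and drop
the per-iteration vsubl_u8 zero-point subtraction of the old code.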
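For uint8 outputs, the int32 result is then rescaled with
output_multiplier_float = lhs_scale * rhs_scale / output_scale, rounded, and
saturated to [0, 255], mirroring the
Saturate<uint8_t>(std::roundf(ret * output_multiplier_float)) line below; the
fixed-point multiplier/shift pair from GetOutputMultiplierAndShift is still
computed but appears unused in the rewritten kernel. A rough sketch of that
rescaling:

    def requantize_to_uint8(acc, lhs_scale, rhs_scale, output_scale):
        # acc is in units of (lhs_scale * rhs_scale); convert to output
        # units, round to nearest, and clamp to the uint8 range.
        multiplier = lhs_scale * rhs_scale / output_scale
        return max(0, min(255, int(round(acc * multiplier))))

    # e.g. acc = 492 with scales 0.02 and 0.5 and output scale 0.1:
    # 492 * (0.02 * 0.5 / 0.1) = 49.2 -> 49
    assert requantize_to_uint8(492, 0.02, 0.5, 0.1) == 49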
diff --git a/mace/ops/arm/q8/gemv.cc b/mace/ops/arm/q8/gemv.cc
index f61062f4..d52ed48d 100644
--- a/mace/ops/arm/q8/gemv.cc
+++ b/mace/ops/arm/q8/gemv.cc
@@ -23,9 +23,7 @@
 
 #if !defined(__aarch64__)
 
-#define vmlal_high_s16(c, a, b) vmlal_s16(c, vget_high_s16(a), vget_high_s16(b))
-
-#define vaddvq_s32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
 
 #endif
 
@@ -47,17 +45,19 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
                                       Tensor *output) {
   MACE_UNUSED(context);
 
-  bool is_output_type_uint8 =
-      DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8;
   Tensor::MappingGuard lhs_guard(lhs);
   Tensor::MappingGuard rhs_guard(rhs);
   Tensor::MappingGuard bias_guard(bias);
   Tensor::MappingGuard output_guard(output);
 
+  const auto *lhs_data = lhs->data<uint8_t>();
+  const auto *rhs_data = rhs->data<uint8_t>();
+  OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
+
   float output_multiplier_float = 0.0;
   int32_t output_multiplier = 0;
   int32_t output_shift = 0;
 
-  if (is_output_type_uint8) {
+  if (is_output_type_uint8_) {
     MACE_CHECK(output->scale() > 0, "output scale must not be zero");
     output_multiplier_float = lhs->scale() * rhs->scale() / output->scale();
     GetOutputMultiplierAndShift(lhs->scale(),
@@ -66,393 +66,110 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
                                 &output_multiplier,
                                 &output_shift);
   }
-  const index_t h_block_size = 4;
-  const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
 
-#pragma omp parallel for collapse(2) schedule(runtime)
+  const int32_t lhs_zero_point = lhs->zero_point();
+  const int32_t rhs_zero_point = rhs->zero_point();
+
+  const index_t w_block_size = 16;
+  const index_t w_block_count = lhs_width / w_block_size;
+  const index_t w_block_remain = lhs_width - w_block_size * w_block_count;
+
   for (index_t b = 0; b < batch; ++b) {
-    for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
-      // TODO(liyin): it can be put it outside the loop,
-      // but openmp limits param count
-      const index_t w_block_size = 16;
-      const index_t w_block_count = lhs_width / w_block_size;
-      const index_t w_remain = lhs_width - w_block_size * w_block_count;
-
-      uint8_t lhs_zero_point = static_cast<uint8_t>(lhs->zero_point());
-      uint8_t rhs_zero_point = static_cast<uint8_t>(rhs->zero_point());
-
-      const uint8_t *lhs_data = lhs->data<uint8_t>();
-      const uint8_t *rhs_data = rhs->data<uint8_t>();
-      const int32_t *bias_data = nullptr;
-      if (bias) {
-        bias_data = bias->data<int32_t>();
+    const uint8_t *rhs_base =
+        rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
+    uint32_t sum_rhs = 0;
+    for (index_t i = 0; i < lhs_width; ++i) {
+      sum_rhs += static_cast<uint32_t>(rhs_base[i]);
+    }
+
+#pragma omp parallel for schedule(runtime)
+    for (index_t h = 0; h < lhs_height; ++h) {
+      const uint8_t *lhs_ptr = lhs_data
+          + static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+          + h * lhs_width;
+      const uint8_t *rhs_ptr = rhs_base;
+      OUTPUT_TYPE *output_ptr = output_data + b * lhs_height + h;
+
+      uint32_t dot = 0;
+      uint32_t sum_lhs = 0;
+      uint32x4_t vo0_high_u32 = vdupq_n_u32(0);
+      uint32x4_t vo0_low_u32 = vdupq_n_u32(0);
+      uint32x4_t vo1_high_u32 = vdupq_n_u32(0);
+      uint32x4_t vo1_low_u32 = vdupq_n_u32(0);
+      uint32x4_t sum_lhs_low_u32 = vdupq_n_u32(0);
+      uint32x4_t sum_lhs_high_u32 = vdupq_n_u32(0);
+
+      for (index_t w_block_idx = 0; w_block_idx < w_block_count;
+           ++w_block_idx) {
+        uint8x8_t vl0_u8 = vld1_u8(lhs_ptr);
+        uint8x8_t vl1_u8 = vld1_u8(lhs_ptr + 8);
+
+        uint8x8_t vr0_u8 = vld1_u8(rhs_ptr);
+        uint8x8_t vr1_u8 = vld1_u8(rhs_ptr + 8);
+
+        uint16x8_t vl0_u16 = vmovl_u8(vl0_u8);
+        uint16x8_t vl1_u16 = vmovl_u8(vl1_u8);
+
+        uint16x8_t vr0_u16 = vmovl_u8(vr0_u8);
+        uint16x8_t vr1_u16 = vmovl_u8(vr1_u8);
+
+        vo0_high_u32 = vmlal_u16(vo0_high_u32,
+                                 vget_high_u16(vl0_u16),
+                                 vget_high_u16(vr0_u16));
+        vo0_low_u32 = vmlal_u16(vo0_low_u32,
+                                vget_low_u16(vl0_u16),
+                                vget_low_u16(vr0_u16));
+        vo1_high_u32 = vmlal_u16(vo1_high_u32,
+                                 vget_high_u16(vl1_u16),
+                                 vget_high_u16(vr1_u16));
+        vo1_low_u32 = vmlal_u16(vo1_low_u32,
+                                vget_low_u16(vl1_u16),
+                                vget_low_u16(vr1_u16));
+
+        // This can be precalculated if lhs is const, but for this case
+        // the computation is not the bottleneck.
+        sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
+                                      vget_high_u16(vl1_u16));
+        sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
+                                     vget_low_u16(vl1_u16));
+
+        lhs_ptr += 16;
+        rhs_ptr += 16;
       }
-      OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
-      int32x4_t voutput_multiplier = vdupq_n_s32(output_multiplier);
-      int32x4_t voutput_shift_left = vdupq_n_s32(-output_shift);
+      vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
+      vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
+      vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
+      dot += vaddvq_u32(vo0_low_u32);
 
-      uint8x8_t
-          vlhs_zero_point = vdup_n_u8(lhs_zero_point);
-      uint8x8_t
-          vrhs_zero_point = vdup_n_u8(rhs_zero_point);
+      sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
+      sum_lhs = vaddvq_u32(sum_lhs_low_u32);
 
-      const uint8_t
-          *lhs_ptr = lhs_data
-          + static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
-          + lhs_width * h_block_idx * h_block_size;
-      const uint8_t *rhs_ptr =
-          rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
-      OUTPUT_TYPE
-          *ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
-
-      const index_t h_block_len =
-          std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
-      const index_t h_offset = h_block_idx * h_block_size;
-
-      if (h_block_len == 4) {
-        int32x4_t vo0 = vdupq_n_s32(0);
-        int32x4_t vo1 = vdupq_n_s32(0);
-        int32x4_t vo2 = vdupq_n_s32(0);
-        int32x4_t vo3 = vdupq_n_s32(0);
-
-        index_t r_w_block_count = w_block_count;
-        // just make compiler happy
-        MACE_UNUSED(r_w_block_count);
-
-        // Register layout: (4x16) x (16x1)
-        //
-        //                          +----+
-        //                          |d16 |
-        //                          | .  |
-        //                          | .  |
-        //                          | .  |
-        //                   Rhs    +----+
-        //                          |d17 |
-        //                          | .  |
-        //                          | .  |
-        //                          | .  |
-        //                          +----+
-        //                          |d18 |
-        //                          | .  |
-        //                          | .  |
-        //                          | .  |
-        //                          +----+
-        //                          |d19 |
-        //                          | .  |
-        //                          | .  |
-        //                          | .  |
-        //                          +----+
-        //
-        //                          |    |
-        //
-        //   Lhs                    |    |
-        //
-        //  +--------+--------+--------+--------+ - - - - +----+
-        //  | d0 ... | d1 ... | d2 ... | d3 ... |         |vo0 |
-        //  | d4 ... | d5 ... | d6 ... | d7 ... |         |vo1 |
-        //  | d8 ... | d9 ... | d10... | d11... |         |vo2 |
-        //  | d12... | d13... | d14... | d15... |         |vo3 |
-        //  +--------+--------+--------+--------+ - - - - +----+
-        //
-        //                                      Accumulator
-        //
-
-#if not defined(__aarch64__)
-        asm volatile(
-            "cmp %[r_w_block_count], #0\n"
-            "beq 0f\n"
-
-            "mov r0, %[rhs_ptr]\n"
-            "mov r1, %[lhs_ptr]\n"
-            "add r2, r1, %[lhs_width]\n"
-            "add r3, r2, %[lhs_width]\n"
-            "add r4, r3, %[lhs_width]\n"
-
-            "vdup.u8 d20, %[rhs_zero_point]\n"
-            "vdup.u8 d21, %[lhs_zero_point]\n"
-
-            // prelogue
-            "vld1.8 d16, [r0]!\n"
-            "vld1.8 d18, [r0]!\n"
-
-            "vld1.8 d0, [r1]!\n"
-            "vld1.8 d2, [r1]!\n"
-            "vld1.8 d4, [r2]!\n"
-            "vld1.8 d6, [r2]!\n"
-            "vld1.8 d8, [r3]!\n"
-            "vld1.8 d10, [r3]!\n"
-            "vld1.8 d12, [r4]!\n"
-            "vld1.8 d14, [r4]!\n"
-
-            "subs %[r_w_block_count], #1\n"
-            "beq 1f\n"
-
-            "2: \n"
-            "vsubl.u8 q8, d16, d20\n"
-            "vsubl.u8 q9, d18, d20\n"
-
-            "vsubl.u8 q0, d0, d21\n"
-            "vsubl.u8 q1, d2, d21\n"
-            "vsubl.u8 q2, d4, d21\n"
-            "vsubl.u8 q3, d6, d21\n"
-            "vsubl.u8 q4, d8, d21\n"
-            "vsubl.u8 q5, d10, d21\n"
-            "vsubl.u8 q6, d12, d21\n"
-            "vsubl.u8 q7, d14, d21\n"
-
-            "vmlal.s16 %q[vo0], d0, d16\n"
-            "vmlal.s16 %q[vo1], d4, d16\n"
-            "vmlal.s16 %q[vo2], d8, d16\n"
-            "vmlal.s16 %q[vo3], d12, d16\n"
-
-            "vld1.8 d0, [r1]!\n"
-            "vld1.8 d4, [r2]!\n"
-            "vld1.8 d8, [r3]!\n"
-            "vld1.8 d12, [r4]!\n"
-            "vld1.8 d16, [r0]!\n"
-
-            "vmlal.s16 %q[vo0], d2, d18\n"
-            "vmlal.s16 %q[vo1], d6, d18\n"
-            "vmlal.s16 %q[vo2], d10, d18\n"
-            "vmlal.s16 %q[vo3], d14, d18\n"
-
-            "vld1.8 d2, [r1]!\n"
-            "vld1.8 d6, [r2]!\n"
-            "vld1.8 d10, [r3]!\n"
-            "vld1.8 d14, [r4]!\n"
-            "vld1.8 d18, [r0]!\n"
-
-            "vmlal.s16 %q[vo0], d1, d17\n"
-            "vmlal.s16 %q[vo1], d5, d17\n"
-            "vmlal.s16 %q[vo2], d9, d17\n"
-            "vmlal.s16 %q[vo3], d13, d17\n"
-
-            "subs %[r_w_block_count], #1\n"
-            "vmlal.s16 %q[vo0], d3, d19\n"
-            "vmlal.s16 %q[vo1], d7, d19\n"
-            "vmlal.s16 %q[vo2], d11, d19\n"
-            "vmlal.s16 %q[vo3], d15, d19\n"
-
-            "bne 2b\n"
-
-            // prologue
-            "1:\n"
-            "vsubl.u8 q8, d16, d20\n"
-            "vsubl.u8 q9, d18, d20\n"
-
-            "vsubl.u8 q0, d0, d21\n"
-            "vsubl.u8 q1, d2, d21\n"
-            "vsubl.u8 q2, d4, d21\n"
-            "vsubl.u8 q3, d6, d21\n"
-            "vsubl.u8 q4, d8, d21\n"
-            "vsubl.u8 q5, d10, d21\n"
-            "vsubl.u8 q6, d12, d21\n"
-            "vsubl.u8 q7, d14, d21\n"
-
-            "vmlal.s16 %q[vo0], d0, d16\n"
-            "vmlal.s16 %q[vo1], d4, d16\n"
-            "vmlal.s16 %q[vo2], d8, d16\n"
-            "vmlal.s16 %q[vo3], d12, d16\n"
-
-            "vmlal.s16 %q[vo0], d1, d17\n"
-            "vmlal.s16 %q[vo1], d5, d17\n"
-            "vmlal.s16 %q[vo2], d9, d17\n"
-            "vmlal.s16 %q[vo3], d13, d17\n"
-
-            "vmlal.s16 %q[vo0], d2, d18\n"
-            "vmlal.s16 %q[vo1], d6, d18\n"
-            "vmlal.s16 %q[vo2], d10, d18\n"
-            "vmlal.s16 %q[vo3], d14, d18\n"
-
-            "vmlal.s16 %q[vo0], d3, d19\n"
-            "vmlal.s16 %q[vo1], d7, d19\n"
-            "vmlal.s16 %q[vo2], d11, d19\n"
-            "vmlal.s16 %q[vo3], d15, d19\n"
-
-            "0:\n"
-            :  // outputs
-            [vo0] "+w"(vo0),
-            [vo1] "+w"(vo1),
-            [vo2] "+w"(vo2),
-            [vo3] "+w"(vo3),
-            [r_w_block_count] "+r"(r_w_block_count)
-            :  // inputs
-            [lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr),
-            [lhs_width] "r"(lhs_width),
-            [lhs_zero_point] "r"(lhs_zero_point),
-            [rhs_zero_point] "r"(rhs_zero_point)
-            :  // clobbers
-            "cc", "memory", "r0", "r1", "r2", "r3", "r4",
-            "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
-            "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
-            "d21");
-
-        lhs_ptr += w_block_count * w_block_size;
-        rhs_ptr += w_block_count * w_block_size;
-#else
-        for (index_t w_block_index = 0; w_block_index < w_block_count;
-             ++w_block_index) {
-          uint8x8_t vr0 = vld1_u8(rhs_ptr);
-          int16x8_t
-              vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
-          uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
-          int16x8_t
-              vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
-
-          uint8x8_t vl0 = vld1_u8(lhs_ptr);
-          int16x8_t
-              vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
-          uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
-          int16x8_t
-              vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
-
-          vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
-          vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
-          vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
-          vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
-
-          const uint8_t *lhs_ptr1 = lhs_ptr + lhs_width;
-
-          uint8x8_t vl1 = vld1_u8(lhs_ptr1);
-          int16x8_t
-              vxl1 = vreinterpretq_s16_u16(vsubl_u8(vl1, vlhs_zero_point));
-          uint8x8_t vl1n = vld1_u8(lhs_ptr1 + 8);
-          int16x8_t
-              vxl1n = vreinterpretq_s16_u16(vsubl_u8(vl1n, vlhs_zero_point));
-
-          vo1 = vmlal_s16(vo1, vget_low_s16(vxl1), vget_low_s16(vxr0));
-          vo1 = vmlal_high_s16(vo1, vxl1, vxr0);
-          vo1 = vmlal_s16(vo1, vget_low_s16(vxl1n), vget_low_s16(vxr0n));
-          vo1 = vmlal_high_s16(vo1, vxl1n, vxr0n);
-
-          const uint8_t *lhs_ptr2 = lhs_ptr1 + lhs_width;
-
-          uint8x8_t vl2 = vld1_u8(lhs_ptr2);
-          int16x8_t
-              vxl2 = vreinterpretq_s16_u16(vsubl_u8(vl2, vlhs_zero_point));
-          uint8x8_t vl2n = vld1_u8(lhs_ptr2 + 8);
-          int16x8_t
-              vxl2n = vreinterpretq_s16_u16(vsubl_u8(vl2n, vlhs_zero_point));
-
-          vo2 = vmlal_s16(vo2, vget_low_s16(vxl2), vget_low_s16(vxr0));
-          vo2 = vmlal_high_s16(vo2, vxl2, vxr0);
-          vo2 = vmlal_s16(vo2, vget_low_s16(vxl2n), vget_low_s16(vxr0n));
-          vo2 = vmlal_high_s16(vo2, vxl2n, vxr0n);
-
-          const uint8_t *lhs_ptr3 = lhs_ptr2 + lhs_width;
-
-          uint8x8_t vl3 = vld1_u8(lhs_ptr3);
-          int16x8_t
-              vxl3 = vreinterpretq_s16_u16(vsubl_u8(vl3, vlhs_zero_point));
-          uint8x8_t vl3n = vld1_u8(lhs_ptr3 + 8);
-          int16x8_t
-              vxl3n = vreinterpretq_s16_u16(vsubl_u8(vl3n, vlhs_zero_point));
-
-          vo3 = vmlal_s16(vo3, vget_low_s16(vxl3), vget_low_s16(vxr0));
-          vo3 = vmlal_high_s16(vo3, vxl3, vxr0);
-          vo3 = vmlal_s16(vo3, vget_low_s16(vxl3n), vget_low_s16(vxr0n));
-          vo3 = vmlal_high_s16(vo3, vxl3n, vxr0n);
-
-          lhs_ptr += 16;
-          rhs_ptr += 16;
-        }
-#endif  // __aarch64__
-        int32x4_t vo = {vaddvq_s32(vo0),
-                        vaddvq_s32(vo1),
-                        vaddvq_s32(vo2),
-                        vaddvq_s32(vo3)};
-
-        for (index_t w = 0; w < w_remain; ++w) {
-          vo[0] +=
-              (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
-          vo[1] += (lhs_ptr[lhs_width] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          vo[2] += (lhs_ptr[lhs_width * 2] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          vo[3] += (lhs_ptr[lhs_width * 3] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          ++lhs_ptr;
-          ++rhs_ptr;
-        }
-
-        if (bias) {
-          int32x4_t vbias = vdupq_n_s32(0);
-          vbias = vld1q_s32(bias_data + h_offset);
-          vo = vaddq_s32(vo, vbias);
-        }
-
-        if (is_output_type_uint8) {
-          int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
-          int32x4_t
-              fixup = vshrq_n_s32(vandq_s32(vo_mul, voutput_shift_left), 31);
-          int32x4_t fixed_up_x = vqaddq_s32(vo_mul, fixup);
-          int32x4_t
-              vo_rescale_int32 = vrshlq_s32(fixed_up_x, voutput_shift_left);
-
-          int16x4_t vo_rescale_int16 = vqmovn_s32(vo_rescale_int32);
-          uint8x8_t vo_rescale_uint8 =
-              vqmovun_s16(vcombine_s16(vo_rescale_int16, vo_rescale_int16));
-
-          ret_ptr[0] = vo_rescale_uint8[0];
-          ret_ptr[1] = vo_rescale_uint8[1];
-          ret_ptr[2] = vo_rescale_uint8[2];
-          ret_ptr[3] = vo_rescale_uint8[3];
-        } else {
-          ret_ptr[0] = vo[0];
-          ret_ptr[1] = vo[1];
-          ret_ptr[2] = vo[2];
-          ret_ptr[3] = vo[3];
-        }
-      } else {  // h_block_len < 4
-        // TODO(liyin): handle here case by case (1,2,3) to accelerate
-        const uint8_t *tmp_lhs_ptr = lhs_ptr;
-        const uint8_t *tmp_rhs_ptr = rhs_ptr;
-        for (index_t h = 0; h < h_block_len; ++h) {
-          lhs_ptr = tmp_lhs_ptr + h * lhs_width;
-          rhs_ptr = tmp_rhs_ptr;
-          int32x4_t vo0 = vdupq_n_s32(0);
-          for (index_t w = 0; w < w_block_count; ++w) {
-            uint8x8_t vr0 = vld1_u8(rhs_ptr);
-            int16x8_t
-                vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
-            uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
-            int16x8_t
-                vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
-
-            uint8x8_t vl0 = vld1_u8(lhs_ptr);
-            int16x8_t
-                vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
-            uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
-            int16x8_t
-                vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
-
-            vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
-            vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
-            vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
-            vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
-
-            lhs_ptr += 16;
-            rhs_ptr += 16;
-          }  // w
-          int32_t s0 = vaddvq_s32(vo0) + (bias ? bias_data[h_offset + h] : 0);
-          for (index_t w = 0; w < w_remain; ++w) {
-            s0 += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
-            ++lhs_ptr;
-            ++rhs_ptr;
-          }  // w
-
-          if (is_output_type_uint8) {
-            ret_ptr[h] =
-                Saturate<uint8_t>(std::roundf(s0 * output_multiplier_float));
-          } else {
-            ret_ptr[h] = s0;
-          }
-        }  // h
-      }  // if
-    }  // h_block_idx
+      for (index_t w = 0; w < w_block_remain; ++w) {
+        dot += (*lhs_ptr) * (*rhs_ptr);
+        sum_lhs += (*lhs_ptr);
+        ++lhs_ptr;
+        ++rhs_ptr;
+      }
+
+      const auto zero_point_dot =
+          static_cast<int32_t>(lhs_zero_point * rhs_zero_point * lhs_width);
+      int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point
+          + zero_point_dot;
+      if (bias) {
+        ret += bias->data<int32_t>()[h];
+      }
+
+      if (is_output_type_uint8_) {
+        *output_ptr =
+            Saturate<uint8_t>(std::roundf(ret * output_multiplier_float));
+      } else {
+        *output_ptr = ret;
+      }
+    }  // h
   }  // b
+
   return MaceStatus::MACE_SUCCESS;
 }
 
@@ -466,7 +183,6 @@ class Gemv;
 }  // namespace ops
 }  // namespace mace
 
-#if defined(vmlal_high_s16)
-#undef vmlal_high_s16
-#undef vaddvq_s32
-#endif
+#ifdef vaddvq_u32
+#undef vaddvq_u32
+#endif  // vaddvq_u32
diff --git a/mace/ops/arm/q8/gemv.h b/mace/ops/arm/q8/gemv.h
index adcb9590..1734a956 100644
--- a/mace/ops/arm/q8/gemv.h
+++ b/mace/ops/arm/q8/gemv.h
@@ -30,7 +30,9 @@ namespace q8 {
 template<typename OUTPUT_TYPE>
 class Gemv {
  public:
-  Gemv() {}
+  Gemv() : is_output_type_uint8_(
+      DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8) {
+  }
   ~Gemv() {}
   // Always row-major after transpose
   MaceStatus Compute(
@@ -44,6 +46,9 @@ class Gemv {
       const bool lhs_batched,
       const bool rhs_batched,
       Tensor *output);
+
+ private:
+  bool is_output_type_uint8_;
 };
 
 }  // namespace q8
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index 03b1e7c3..000d7910 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -280,7 +280,7 @@ class TransformerRule(Enum):
     FOLD_FC_RESHAPE = 37
     TRANSFORM_CHANNEL_SHUFFLE = 38
     UPDATE_DATA_FORMAT = 39
-    QUANTIZE_MATMUL_ONLY = 40
+    QUANTIZE_SPECIFIC_OPS_ONLY = 40
 
 
 class ConverterInterface(object):
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 4a170646..fe9be9ee 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -103,8 +103,8 @@ class Transformer(base_converter.ConverterInterface):
                 self.transform_caffe_reshape_and_flatten,
             TransformerRule.TRANSFORM_CHANNEL_SHUFFLE:
                 self.transform_channel_shuffle,
-            TransformerRule.QUANTIZE_MATMUL_ONLY:
-                self.quantize_matmul_only,
+            TransformerRule.QUANTIZE_SPECIFIC_OPS_ONLY:
+                self.quantize_specific_ops_only,
         }
 
         self._option = option
@@ -1118,7 +1118,7 @@ class Transformer(base_converter.ConverterInterface):
             rhs = op.input[1]
             if rhs in self._consts and len(self._consts[rhs].dims) == 2:
                 arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str)  # noqa
-                six.print_('transpose matmul weight')
+                six.print_("Transpose matmul weight %s" % rhs)
                 if arg is None:
                     arg = op.arg.add()
                     arg.name = MaceKeyword.mace_transpose_b_str
@@ -1927,35 +1927,46 @@ class Transformer(base_converter.ConverterInterface):
 
         return True
 
-    def quantize_matmul_only(self):
+    def quantize_specific_ops_only(self):
         """
         This transform rule is only used internally, we are not gonna make
         things too complex for users
         """
-        to_quantize_ops = [MaceOp.MatMul.name]
+        to_quantize_ops_output_type = {
+            MaceOp.MatMul.name: mace_pb2.DT_INT32,
+            MaceOp.Gather.name: mace_pb2.DT_UINT8,
+        }
+
         for op in self._model.op:
-            if (op.type not in to_quantize_ops or len(op.output) > 1
+            if (op.type not in to_quantize_ops_output_type
+                    or len(op.output) > 1
                     or ConverterUtil.get_arg(op,
                                              MaceKeyword.mace_op_data_type_str).i != mace_pb2.DT_FLOAT):  # noqa
                 # only support single output
                 continue
 
             quantized_inputs_names = []
 
-            should_quantize = True
+            should_quantize = False
             for idx, input_tensor in enumerate(op.input):
                 if self.get_tensor_data_type(input_tensor) \
-                        != mace_pb2.DT_FLOAT:
-                    should_quantize = False
+                        == mace_pb2.DT_FLOAT:
+                    should_quantize = True
                     break
             if not should_quantize:
                 continue
+            else:
+                print("Quantize op %s (%s)" % (op.name, op.type))
 
             non_zero = self._option.device == DeviceType.CPU.value
             for idx, input_tensor in enumerate(op.input):
                 quantized_inputs_names.append(input_tensor)
+
+                if self.get_tensor_data_type(input_tensor) \
+                        != mace_pb2.DT_FLOAT:
+                    continue
+
                 if input_tensor in self._consts:
                     const_tensor = self._consts[input_tensor]
                     quantized_tensor = quantize_util.quantize(
@@ -2005,7 +2016,7 @@ class Transformer(base_converter.ConverterInterface):
 
             orginal_output_name = op.output[0]
             op.output[0] = orginal_output_name + "_quant"
-            op.output_type.extend([mace_pb2.DT_INT32])
+            op.output_type.extend([to_quantize_ops_output_type[op.type]])
             data_type_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_op_data_type_str)  # noqa
             if data_type_arg is None:
@@ -2022,7 +2033,7 @@ class Transformer(base_converter.ConverterInterface):
             dequantize_op.output_type.extend([mace_pb2.DT_FLOAT])
             data_type_arg = dequantize_op.arg.add()
             data_type_arg.name = MaceKeyword.mace_op_data_type_str
-            data_type_arg.i = mace_pb2.DT_INT32
+            data_type_arg.i = to_quantize_ops_output_type[op.type]
 
             quantize_flag_arg = ConverterUtil.get_arg(self._model,
                                                       MaceKeyword.mace_quantize_flag_arg_str)  # noqa
-- 
GitLab
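Note on the converter change: quantize_specific_ops_only generalizes the old
MatMul-only rule into a per-op output-type table. MatMul keeps a DT_INT32
output because it accumulates products of two quantized tensors, while Gather
keeps DT_UINT8 because it only reorders already-quantized values, so scale and
zero point pass through unchanged. A simplified sketch of the selection rule
(the type values are illustrative stand-ins, not the real mace_pb2 constants):

    DT_FLOAT, DT_UINT8, DT_INT32 = "float32", "uint8", "int32"  # stand-ins

    TO_QUANTIZE_OPS_OUTPUT_TYPE = {
        "MatMul": DT_INT32,  # accumulated products need a wide output
        "Gather": DT_UINT8,  # pure reordering stays in uint8
    }

    def output_type_after_quantization(op_type):
        # Ops outside the table are skipped by the pass and stay float.
        return TO_QUANTIZE_OPS_OUTPUT_TYPE.get(op_type, DT_FLOAT)

    assert output_type_after_quantization("MatMul") == DT_INT32
    assert output_type_after_quantization("Gather") == DT_UINT8
    assert output_type_after_quantization("Softmax") == DT_FLOAT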