Commit 1ee5fd20 authored by liyin

Regress gemv; support quantize gather only

Parent 202ea3a6
@@ -23,9 +23,7 @@
#if !defined(__aarch64__)
#define vmlal_high_s16(c, a, b) vmlal_s16(c, vget_high_s16(a), vget_high_s16(b))
#define vaddvq_s32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
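// A usage sketch of the fallbacks above (illustrative, not from this diff):
// on armv7 these macros emulate the AArch64-only across-vector add and
// high-half multiply-accumulate intrinsics, e.g.
//   uint32x4_t v = vdupq_n_u32(2);  // {2, 2, 2, 2}
//   uint32_t s = vaddvq_u32(v);     // 8 on both armv7 (macro) and aarch64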
@@ -47,17 +45,19 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
Tensor *output) {
MACE_UNUSED(context);
bool is_output_type_uint8 =
DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8;
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const auto *lhs_data = lhs->data<uint8_t>();
const auto *rhs_data = rhs->data<uint8_t>();
OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
float output_multiplier_float = 0.0;
int32_t output_multiplier = 0;
int32_t output_shift = 0;
if (is_output_type_uint8) {
if (is_output_type_uint8_) {
MACE_CHECK(output->scale() > 0, "output scale must be positive");
output_multiplier_float = lhs->scale() * rhs->scale() / output->scale();
GetOutputMultiplierAndShift(lhs->scale(),
@@ -66,393 +66,110 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
&output_multiplier,
&output_shift);
}
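// Requantization sketch (illustrative values, not from this diff):
// GetOutputMultiplierAndShift encodes the real-valued multiplier
//   lhs_scale * rhs_scale / output_scale
// as a Q31 fixed-point multiplier plus a right shift, e.g.
//   0.35f ~= (1503238554 / 2^31) * 2^-1  // multiplier = 1503238554, shift = 1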
const index_t h_block_size = 4;
const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
#pragma omp parallel for collapse(2) schedule(runtime)
const int32_t lhs_zero_point = lhs->zero_point();
const int32_t rhs_zero_point = rhs->zero_point();
const index_t w_block_size = 16;
const index_t w_block_count = lhs_width / w_block_size;
const index_t w_block_remain = lhs_width - w_block_size * w_block_count;
for (index_t b = 0; b < batch; ++b) {
for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
// TODO(liyin): this could be moved outside the loop,
// but OpenMP limits the parameter count
const index_t w_block_size = 16;
const index_t w_block_count = lhs_width / w_block_size;
const index_t w_remain = lhs_width - w_block_size * w_block_count;
uint8_t lhs_zero_point = static_cast<uint8_t>(lhs->zero_point());
uint8_t rhs_zero_point = static_cast<uint8_t>(rhs->zero_point());
const uint8_t *lhs_data = lhs->data<uint8_t>();
const uint8_t *rhs_data = rhs->data<uint8_t>();
const int32_t *bias_data = nullptr;
if (bias) {
bias_data = bias->data<int32_t>();
const uint8_t *rhs_base =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
uint32_t sum_rhs = 0;
for (index_t i = 0; i < lhs_width; ++i) {
sum_rhs += static_cast<uint32_t>(rhs_base[i]);
}
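// Zero-point expansion, the identity this kernel relies on:
//   sum_i (l_i - zl) * (r_i - zr)
//     = sum_i l_i*r_i - zr*sum_i l_i - zl*sum_i r_i + N*zl*zr
// The inner loop can therefore accumulate raw uint8 products ("dot") and
// fold the zero-point terms in once per output element; sum_rhs is shared
// by every row of this batch, so it is hoisted here.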
#pragma omp parallel for schedule(runtime)
for (index_t h = 0; h < lhs_height; ++h) {
const uint8_t *lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width;
const uint8_t *rhs_ptr = rhs_base;
OUTPUT_TYPE *output_ptr = output_data + b * lhs_height + h;
uint32_t dot = 0;
uint32_t sum_lhs = 0;
uint32x4_t vo0_high_u32 = vdupq_n_u32(0);
uint32x4_t vo0_low_u32 = vdupq_n_u32(0);
uint32x4_t vo1_high_u32 = vdupq_n_u32(0);
uint32x4_t vo1_low_u32 = vdupq_n_u32(0);
uint32x4_t sum_lhs_low_u32 = vdupq_n_u32(0);
uint32x4_t sum_lhs_high_u32 = vdupq_n_u32(0);
for (index_t w_block_idx = 0; w_block_idx < w_block_count;
++w_block_idx) {
uint8x8_t vl0_u8 = vld1_u8(lhs_ptr);
uint8x8_t vl1_u8 = vld1_u8(lhs_ptr + 8);
uint8x8_t vr0_u8 = vld1_u8(rhs_ptr);
uint8x8_t vr1_u8 = vld1_u8(rhs_ptr + 8);
uint16x8_t vl0_u16 = vmovl_u8(vl0_u8);
uint16x8_t vl1_u16 = vmovl_u8(vl1_u8);
uint16x8_t vr0_u16 = vmovl_u8(vr0_u8);
uint16x8_t vr1_u16 = vmovl_u8(vr1_u8);
vo0_high_u32 = vmlal_u16(vo0_high_u32,
vget_high_u16(vl0_u16),
vget_high_u16(vr0_u16));
vo0_low_u32 = vmlal_u16(vo0_low_u32,
vget_low_u16(vl0_u16),
vget_low_u16(vr0_u16));
vo1_high_u32 = vmlal_u16(vo1_high_u32,
vget_high_u16(vl1_u16),
vget_high_u16(vr1_u16));
vo1_low_u32 = vmlal_u16(vo1_low_u32,
vget_low_u16(vl1_u16),
vget_low_u16(vr1_u16));
// It could be precalculated if lhs is const, but in this case
// computation is not the bottleneck
sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
vget_high_u16(vl1_u16));
sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
vget_low_u16(vl1_u16));
lhs_ptr += 16;
rhs_ptr += 16;
}
OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
int32x4_t voutput_multiplier = vdupq_n_s32(output_multiplier);
int32x4_t voutput_shift_left = vdupq_n_s32(-output_shift);
vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
dot += vaddvq_u32(vo0_low_u32);
uint8x8_t
vlhs_zero_point = vdup_n_u8(lhs_zero_point);
uint8x8_t
vrhs_zero_point = vdup_n_u8(rhs_zero_point);
sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
sum_lhs = vaddvq_u32(sum_lhs_low_u32);
const uint8_t
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
const uint8_t *rhs_ptr =
rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
OUTPUT_TYPE
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
const index_t h_block_len =
std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
const index_t h_offset = h_block_idx * h_block_size;
if (h_block_len == 4) {
int32x4_t vo0 = vdupq_n_s32(0);
int32x4_t vo1 = vdupq_n_s32(0);
int32x4_t vo2 = vdupq_n_s32(0);
int32x4_t vo3 = vdupq_n_s32(0);
index_t r_w_block_count = w_block_count;
// just to keep the compiler happy
MACE_UNUSED(r_w_block_count);
// Register layout: (4x16) x (16x1)
//
// +----+
// |d16 |
// | . |
// | . |
// | . |
// Rhs +----+
// |d17 |
// | . |
// | . |
// | . |
// +----+
// |d18 |
// | . |
// | . |
// | . |
// +----+
// |d19 |
// | . |
// | . |
// | . |
// +----+
//
// | |
//
// Lhs | |
//
// +--------+--------+--------+--------+ - - - - +----+
// | d0 ... | d1 ... | d2 ... | d3 ... | |vo0 |
// | d4 ... | d5 ... | d6 ... | d7 ... | |vo1 |
// | d8 ... | d9 ... | d10... | d11... | |vo2 |
// | d12... | d13... | d14... | d15... | |vo3 |
// +--------+--------+--------+--------+ - - - - +----+
//
// Accumulator
//
#if not defined(__aarch64__)
asm volatile(
"cmp %[r_w_block_count], #0\n"
"beq 0f\n"
"mov r0, %[rhs_ptr]\n"
"mov r1, %[lhs_ptr]\n"
"add r2, r1, %[lhs_width]\n"
"add r3, r2, %[lhs_width]\n"
"add r4, r3, %[lhs_width]\n"
"vdup.u8 d20, %[rhs_zero_point]\n"
"vdup.u8 d21, %[lhs_zero_point]\n"
// prologue
"vld1.8 d16, [r0]!\n"
"vld1.8 d18, [r0]!\n"
"vld1.8 d0, [r1]!\n"
"vld1.8 d2, [r1]!\n"
"vld1.8 d4, [r2]!\n"
"vld1.8 d6, [r2]!\n"
"vld1.8 d8, [r3]!\n"
"vld1.8 d10, [r3]!\n"
"vld1.8 d12, [r4]!\n"
"vld1.8 d14, [r4]!\n"
"subs %[r_w_block_count], #1\n"
"beq 1f\n"
"2: \n"
"vsubl.u8 q8, d16, d20\n"
"vsubl.u8 q9, d18, d20\n"
"vsubl.u8 q0, d0, d21\n"
"vsubl.u8 q1, d2, d21\n"
"vsubl.u8 q2, d4, d21\n"
"vsubl.u8 q3, d6, d21\n"
"vsubl.u8 q4, d8, d21\n"
"vsubl.u8 q5, d10, d21\n"
"vsubl.u8 q6, d12, d21\n"
"vsubl.u8 q7, d14, d21\n"
"vmlal.s16 %q[vo0], d0, d16\n"
"vmlal.s16 %q[vo1], d4, d16\n"
"vmlal.s16 %q[vo2], d8, d16\n"
"vmlal.s16 %q[vo3], d12, d16\n"
"vld1.8 d0, [r1]!\n"
"vld1.8 d4, [r2]!\n"
"vld1.8 d8, [r3]!\n"
"vld1.8 d12, [r4]!\n"
"vld1.8 d16, [r0]!\n"
"vmlal.s16 %q[vo0], d2, d18\n"
"vmlal.s16 %q[vo1], d6, d18\n"
"vmlal.s16 %q[vo2], d10, d18\n"
"vmlal.s16 %q[vo3], d14, d18\n"
"vld1.8 d2, [r1]!\n"
"vld1.8 d6, [r2]!\n"
"vld1.8 d10, [r3]!\n"
"vld1.8 d14, [r4]!\n"
"vld1.8 d18, [r0]!\n"
"vmlal.s16 %q[vo0], d1, d17\n"
"vmlal.s16 %q[vo1], d5, d17\n"
"vmlal.s16 %q[vo2], d9, d17\n"
"vmlal.s16 %q[vo3], d13, d17\n"
"subs %[r_w_block_count], #1\n"
"vmlal.s16 %q[vo0], d3, d19\n"
"vmlal.s16 %q[vo1], d7, d19\n"
"vmlal.s16 %q[vo2], d11, d19\n"
"vmlal.s16 %q[vo3], d15, d19\n"
"bne 2b\n"
// epilogue
"1:\n"
"vsubl.u8 q8, d16, d20\n"
"vsubl.u8 q9, d18, d20\n"
"vsubl.u8 q0, d0, d21\n"
"vsubl.u8 q1, d2, d21\n"
"vsubl.u8 q2, d4, d21\n"
"vsubl.u8 q3, d6, d21\n"
"vsubl.u8 q4, d8, d21\n"
"vsubl.u8 q5, d10, d21\n"
"vsubl.u8 q6, d12, d21\n"
"vsubl.u8 q7, d14, d21\n"
"vmlal.s16 %q[vo0], d0, d16\n"
"vmlal.s16 %q[vo1], d4, d16\n"
"vmlal.s16 %q[vo2], d8, d16\n"
"vmlal.s16 %q[vo3], d12, d16\n"
"vmlal.s16 %q[vo0], d1, d17\n"
"vmlal.s16 %q[vo1], d5, d17\n"
"vmlal.s16 %q[vo2], d9, d17\n"
"vmlal.s16 %q[vo3], d13, d17\n"
"vmlal.s16 %q[vo0], d2, d18\n"
"vmlal.s16 %q[vo1], d6, d18\n"
"vmlal.s16 %q[vo2], d10, d18\n"
"vmlal.s16 %q[vo3], d14, d18\n"
"vmlal.s16 %q[vo0], d3, d19\n"
"vmlal.s16 %q[vo1], d7, d19\n"
"vmlal.s16 %q[vo2], d11, d19\n"
"vmlal.s16 %q[vo3], d15, d19\n"
"0:\n"
: // outputs
[vo0] "+w"(vo0),
[vo1] "+w"(vo1),
[vo2] "+w"(vo2),
[vo3] "+w"(vo3),
[r_w_block_count] "+r"(r_w_block_count)
: // inputs
[lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr),
[lhs_width] "r"(lhs_width),
[lhs_zero_point] "r"(lhs_zero_point),
[rhs_zero_point] "r"(rhs_zero_point)
: // clobbers
"cc", "memory", "r0", "r1", "r2", "r3", "r4",
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21");
lhs_ptr += w_block_count * w_block_size;
rhs_ptr += w_block_count * w_block_size;
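// Note: the asm block above is software-pipelined -- the vld1 loads for
// block i+1 are interleaved with the vmlal accumulates of block i, and
// the "1:" tail drains the last block loaded by the prologue.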
#else
for (index_t w_block_index = 0; w_block_index < w_block_count;
++w_block_index) {
uint8x8_t vr0 = vld1_u8(rhs_ptr);
int16x8_t
vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
int16x8_t
vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
uint8x8_t vl0 = vld1_u8(lhs_ptr);
int16x8_t
vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
int16x8_t
vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
const uint8_t *lhs_ptr1 = lhs_ptr + lhs_width;
uint8x8_t vl1 = vld1_u8(lhs_ptr1);
int16x8_t
vxl1 = vreinterpretq_s16_u16(vsubl_u8(vl1, vlhs_zero_point));
uint8x8_t vl1n = vld1_u8(lhs_ptr1 + 8);
int16x8_t
vxl1n = vreinterpretq_s16_u16(vsubl_u8(vl1n, vlhs_zero_point));
vo1 = vmlal_s16(vo1, vget_low_s16(vxl1), vget_low_s16(vxr0));
vo1 = vmlal_high_s16(vo1, vxl1, vxr0);
vo1 = vmlal_s16(vo1, vget_low_s16(vxl1n), vget_low_s16(vxr0n));
vo1 = vmlal_high_s16(vo1, vxl1n, vxr0n);
const uint8_t *lhs_ptr2 = lhs_ptr1 + lhs_width;
uint8x8_t vl2 = vld1_u8(lhs_ptr2);
int16x8_t
vxl2 = vreinterpretq_s16_u16(vsubl_u8(vl2, vlhs_zero_point));
uint8x8_t vl2n = vld1_u8(lhs_ptr2 + 8);
int16x8_t
vxl2n = vreinterpretq_s16_u16(vsubl_u8(vl2n, vlhs_zero_point));
vo2 = vmlal_s16(vo2, vget_low_s16(vxl2), vget_low_s16(vxr0));
vo2 = vmlal_high_s16(vo2, vxl2, vxr0);
vo2 = vmlal_s16(vo2, vget_low_s16(vxl2n), vget_low_s16(vxr0n));
vo2 = vmlal_high_s16(vo2, vxl2n, vxr0n);
const uint8_t *lhs_ptr3 = lhs_ptr2 + lhs_width;
uint8x8_t vl3 = vld1_u8(lhs_ptr3);
int16x8_t
vxl3 = vreinterpretq_s16_u16(vsubl_u8(vl3, vlhs_zero_point));
uint8x8_t vl3n = vld1_u8(lhs_ptr3 + 8);
int16x8_t
vxl3n = vreinterpretq_s16_u16(vsubl_u8(vl3n, vlhs_zero_point));
vo3 = vmlal_s16(vo3, vget_low_s16(vxl3), vget_low_s16(vxr0));
vo3 = vmlal_high_s16(vo3, vxl3, vxr0);
vo3 = vmlal_s16(vo3, vget_low_s16(vxl3n), vget_low_s16(vxr0n));
vo3 = vmlal_high_s16(vo3, vxl3n, vxr0n);
lhs_ptr += 16;
rhs_ptr += 16;
}
#endif // __aarch64__
int32x4_t vo = {vaddvq_s32(vo0),
vaddvq_s32(vo1),
vaddvq_s32(vo2),
vaddvq_s32(vo3)};
for (index_t w = 0; w < w_remain; ++w) {
vo[0] +=
(lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
vo[1] += (lhs_ptr[lhs_width] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
vo[2] += (lhs_ptr[lhs_width * 2] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
vo[3] += (lhs_ptr[lhs_width * 3] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
++lhs_ptr;
++rhs_ptr;
}
if (bias) {
int32x4_t vbias = vdupq_n_s32(0);
vbias = vld1q_s32(bias_data + h_offset);
vo = vaddq_s32(vo, vbias);
}
if (is_output_type_uint8) {
int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
int32x4_t
fixup = vshrq_n_s32(vandq_s32(vo_mul, voutput_shift_left), 31);
int32x4_t fixed_up_x = vqaddq_s32(vo_mul, fixup);
int32x4_t
vo_rescale_int32 = vrshlq_s32(fixed_up_x, voutput_shift_left);
int16x4_t vo_rescale_int16 = vqmovn_s32(vo_rescale_int32);
uint8x8_t vo_rescale_uint8 =
vqmovun_s16(vcombine_s16(vo_rescale_int16, vo_rescale_int16));
ret_ptr[0] = vo_rescale_uint8[0];
ret_ptr[1] = vo_rescale_uint8[1];
ret_ptr[2] = vo_rescale_uint8[2];
ret_ptr[3] = vo_rescale_uint8[3];
} else {
ret_ptr[0] = vo[0];
ret_ptr[1] = vo[1];
ret_ptr[2] = vo[2];
ret_ptr[3] = vo[3];
}
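// Rescale sketch: vqrdmulhq_s32 is a saturating rounding doubling high
// multiply, approximately round(vo * multiplier / 2^31); the fixup plus
// vrshlq_s32 pair matches the gemmlowp-style rounding divide by a power
// of two. Illustrative values: multiplier = 1 << 30 (i.e. 0.5) and
// shift = 1 map an accumulator of 10 to round(round(10 * 0.5) / 2) = 3.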
} else { // h_block_len < 4
// TODO(liyin): handle each case (1, 2, 3) separately to accelerate
const uint8_t *tmp_lhs_ptr = lhs_ptr;
const uint8_t *tmp_rhs_ptr = rhs_ptr;
for (index_t h = 0; h < h_block_len; ++h) {
lhs_ptr = tmp_lhs_ptr + h * lhs_width;
rhs_ptr = tmp_rhs_ptr;
int32x4_t vo0 = vdupq_n_s32(0);
for (index_t w = 0; w < w_block_count; ++w) {
uint8x8_t vr0 = vld1_u8(rhs_ptr);
int16x8_t
vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
int16x8_t
vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
uint8x8_t vl0 = vld1_u8(lhs_ptr);
int16x8_t
vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
int16x8_t
vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
lhs_ptr += 16;
rhs_ptr += 16;
} // w
int32_t s0 = vaddvq_s32(vo0) + (bias ? bias_data[h_offset + h] : 0);
for (index_t w = 0; w < w_remain; ++w) {
s0 += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
++lhs_ptr;
++rhs_ptr;
} // w
if (is_output_type_uint8) {
ret_ptr[h] =
Saturate<uint8_t>(std::roundf(s0 * output_multiplier_float));
} else {
ret_ptr[h] = s0;
}
} // h
} // if
} // h_block_idx
for (index_t w = 0; w < w_block_remain; ++w) {
dot += (*lhs_ptr) * (*rhs_ptr);
sum_lhs += (*lhs_ptr);
++lhs_ptr;
++rhs_ptr;
}
const auto zero_point_dot =
static_cast<int32_t>(lhs_zero_point * rhs_zero_point * lhs_width);
int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point
+ zero_point_dot;
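// Worked example with hypothetical values: N = 2, lhs = {3, 5},
// rhs = {4, 6}, zl = 1, zr = 2:
//   dot = 3*4 + 5*6 = 42, sum_lhs = 8, sum_rhs = 10
//   ret = 42 - 8*2 - 10*1 + 2*1*2 = 20
//       = (3-1)*(4-2) + (5-1)*(6-2)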
if (bias) {
ret += bias->data<int32_t>()[h];
}
if (is_output_type_uint8_) {
*output_ptr =
Saturate<uint8_t>(std::roundf(ret * output_multiplier_float));
} else {
*output_ptr = ret;
}
} // h
} // b
return MaceStatus::MACE_SUCCESS;
}
@@ -466,7 +183,6 @@ class Gemv<int32_t>;
} // namespace ops
} // namespace mace
#if defined(vmlal_high_s16)
#undef vmlal_high_s16
#undef vaddvq_s32
#endif
#ifdef vaddvq_u32
#undef vaddvq_u32
#endif // vaddvq_u32
@@ -30,7 +30,9 @@ namespace q8 {
template<typename OUTPUT_TYPE>
class Gemv {
public:
Gemv() {}
Gemv() : is_output_type_uint8_(
DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8) {
}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
@@ -44,6 +46,9 @@ class Gemv {
const bool lhs_batched,
const bool rhs_batched,
Tensor *output);
private:
bool is_output_type_uint8_;
};
} // namespace q8
......
@@ -280,7 +280,7 @@ class TransformerRule(Enum):
FOLD_FC_RESHAPE = 37
TRANSFORM_CHANNEL_SHUFFLE = 38
UPDATE_DATA_FORMAT = 39
QUANTIZE_MATMUL_ONLY = 40
QUANTIZE_SPECIFIC_OPS_ONLY = 40
class ConverterInterface(object):
......
@@ -103,8 +103,8 @@ class Transformer(base_converter.ConverterInterface):
self.transform_caffe_reshape_and_flatten,
TransformerRule.TRANSFORM_CHANNEL_SHUFFLE:
self.transform_channel_shuffle,
TransformerRule.QUANTIZE_MATMUL_ONLY:
self.quantize_matmul_only,
TransformerRule.QUANTIZE_SPECIFIC_OPS_ONLY:
self.quantize_specific_ops_only,
}
self._option = option
@@ -1118,7 +1118,7 @@ class Transformer(base_converter.ConverterInterface):
rhs = op.input[1]
if rhs in self._consts and len(self._consts[rhs].dims) == 2:
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str) # noqa
six.print_('transpose matmul weight')
six.print_("Transpose matmul weight %s" % rhs)
if arg is None:
arg = op.arg.add()
arg.name = MaceKeyword.mace_transpose_b_str
@@ -1927,35 +1927,46 @@ class Transformer(base_converter.ConverterInterface):
return True
def quantize_matmul_only(self):
def quantize_specific_ops_only(self):
"""
This transform rule is only used internally; we do not want to make
things too complex for users.
"""
to_quantize_ops = [MaceOp.MatMul.name]
to_quantize_ops_output_type = {
MaceOp.MatMul.name: mace_pb2.DT_INT32,
MaceOp.Gather.name: mace_pb2.DT_UINT8,
}
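# Why the output types differ (explanatory note, not in the original):
# MatMul accumulates uint8 x uint8 products into 32-bit sums, so its
# quantized output must be DT_INT32, while Gather only reorders already
# quantized uint8 entries, so its output stays DT_UINT8.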
for op in self._model.op:
if (op.type not in to_quantize_ops or len(op.output) > 1
if (op.type not in to_quantize_ops_output_type
or len(op.output) > 1
or ConverterUtil.get_arg(op,
MaceKeyword.mace_op_data_type_str).i != mace_pb2.DT_FLOAT): # noqa
# only support single output
continue
quantized_inputs_names = []
should_quantize = True
should_quantize = False
for idx, input_tensor in enumerate(op.input):
if self.get_tensor_data_type(input_tensor) \
!= mace_pb2.DT_FLOAT:
should_quantize = False
== mace_pb2.DT_FLOAT:
should_quantize = True
break
if not should_quantize:
continue
else:
print("Quantize op %s (%s)" % (op.name, op.type))
non_zero = self._option.device == DeviceType.CPU.value
for idx, input_tensor in enumerate(op.input):
quantized_inputs_names.append(input_tensor)
if self.get_tensor_data_type(input_tensor) \
!= mace_pb2.DT_FLOAT:
continue
if input_tensor in self._consts:
const_tensor = self._consts[input_tensor]
quantized_tensor = quantize_util.quantize(
@@ -2005,7 +2016,7 @@ class Transformer(base_converter.ConverterInterface):
orginal_output_name = op.output[0]
op.output[0] = orginal_output_name + "_quant"
op.output_type.extend([mace_pb2.DT_INT32])
op.output_type.extend([to_quantize_ops_output_type[op.type]])
data_type_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_op_data_type_str) # noqa
if data_type_arg is None:
@@ -2022,7 +2033,7 @@ class Transformer(base_converter.ConverterInterface):
dequantize_op.output_type.extend([mace_pb2.DT_FLOAT])
data_type_arg = dequantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_INT32
data_type_arg.i = to_quantize_ops_output_type[op.type]
quantize_flag_arg = ConverterUtil.get_arg(self._model,
MaceKeyword.mace_quantize_flag_arg_str) # noqa
......