Commit c6efac34 authored by liyin

Optimize quantize and dequantize ops

Parent f0c7717e
@@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0
 load(
     "//mace:mace.bzl",
     "if_android",
+    "if_android_armv7",
     "if_hexagon_enabled",
-    "if_not_hexagon_enabled",
-    "if_openmp_enabled",
     "if_neon_enabled",
+    "if_not_hexagon_enabled",
     "if_opencl_enabled",
+    "if_openmp_enabled",
     "if_quantize_enabled",
 )
@@ -58,6 +59,9 @@ cc_library(
         "-DMACE_ENABLE_HEXAGON",
     ]) + if_neon_enabled([
         "-DMACE_ENABLE_NEON",
+    ]) + if_android_armv7([
+        "-mfpu=neon",
+        "-mfloat-abi=softfp",
     ]),
     linkopts = ["-ldl"],
     deps = [
......
@@ -40,19 +40,33 @@ struct CPUFreq {
   float freq;
 };
 
+enum SchedulePolicy {
+  SCHED_STATIC,
+  SCHED_GUIDED,
+};
+
 namespace {
 
 MaceStatus SetOpenMPThreadsAndAffinityCPUs(int omp_num_threads,
-                                           const std::vector<size_t> &cpu_ids) {
+                                           const std::vector<size_t> &cpu_ids,
+                                           SchedulePolicy schedule_policy) {
   MaceOpenMPThreadCount = omp_num_threads;
 
 #ifdef MACE_ENABLE_OPENMP
   VLOG(1) << "Set OpenMP threads number: " << omp_num_threads
           << ", CPU core IDs: " << MakeString(cpu_ids);
-  omp_set_schedule(omp_sched_guided, 1);
+  if (schedule_policy == SCHED_GUIDED) {
+    omp_set_schedule(omp_sched_guided, 1);
+  } else if (schedule_policy == SCHED_STATIC) {
+    omp_set_schedule(omp_sched_static, 0);
+  } else {
+    LOG(WARNING) << "Unknown schedule policy: " << schedule_policy;
+  }
   omp_set_num_threads(omp_num_threads);
 #else
   MACE_UNUSED(omp_num_threads);
+  MACE_UNUSED(schedule_policy);
   LOG(WARNING) << "Set OpenMP threads number failed: OpenMP not enabled.";
 #endif
@@ -148,6 +162,7 @@ MaceStatus CPURuntime::SetOpenMPThreadsAndAffinityPolicy(
   } else {
     cores_to_use = num_threads_hint;
   }
+  MACE_CHECK(cores_to_use > 0, "number of cores to use should > 0");
 
   VLOG(2) << "Use " << num_threads_hint << " threads";
   std::vector<size_t> cpu_ids(cores_to_use);
@@ -156,6 +171,10 @@ MaceStatus CPURuntime::SetOpenMPThreadsAndAffinityPolicy(
             << cpu_freq[i].freq;
     cpu_ids[i] = cpu_freq[i].core_id;
   }
+  SchedulePolicy sched_policy = SCHED_GUIDED;
+  if (std::abs(cpu_freq[0].freq - cpu_freq[cores_to_use - 1].freq) < 1e-6) {
+    sched_policy = SCHED_STATIC;
+  }
 
 #ifdef MACE_ENABLE_QUANTIZE
   if (gemm_context) {
@@ -164,7 +183,9 @@ MaceStatus CPURuntime::SetOpenMPThreadsAndAffinityPolicy(
   }
 #endif // MACE_ENABLE_QUANTIZE
 
-  return SetOpenMPThreadsAndAffinityCPUs(num_threads_hint, cpu_ids);
+  return SetOpenMPThreadsAndAffinityCPUs(num_threads_hint,
+                                         cpu_ids,
+                                         sched_policy);
 }
 
 }  // namespace mace
......
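The runtime change above picks the OpenMP loop schedule from the CPU affinity settings: when every selected core reports the same maximum frequency, a static schedule splits iterations evenly, while heterogeneous (big.LITTLE) core sets keep the guided schedule so faster cores can pick up extra chunks. Below is a minimal standalone sketch of that heuristic; `ConfigureOpenMPSchedule` and `core_max_freqs` are illustrative names, not MACE API.

```cpp
// Sketch only: mirrors the SCHED_STATIC / SCHED_GUIDED selection in the diff above.
// Assumes compilation with -fopenmp; names here are hypothetical.
#include <omp.h>

#include <cmath>
#include <vector>

enum SchedulePolicy { SCHED_STATIC, SCHED_GUIDED };

void ConfigureOpenMPSchedule(const std::vector<float> &core_max_freqs,
                             int num_threads) {
  // Identical max frequencies -> homogeneous cores -> equal static chunks.
  // Otherwise keep guided scheduling so idle fast cores steal more work.
  SchedulePolicy policy = SCHED_GUIDED;
  if (std::abs(core_max_freqs.front() - core_max_freqs.back()) < 1e-6) {
    policy = SCHED_STATIC;
  }
  if (policy == SCHED_STATIC) {
    omp_set_schedule(omp_sched_static, 0);  // chunk size 0 = implementation default
  } else {
    omp_set_schedule(omp_sched_guided, 1);
  }
  omp_set_num_threads(num_threads);
}
```

Loops declared with `#pragma omp parallel for schedule(runtime)`, like the new quantize and dequantize kernels later in this commit, then inherit whichever schedule was installed here.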
@@ -1942,13 +1942,21 @@ class Transformer(base_converter.ConverterInterface):
                 continue
             quantized_inputs_names = []
 
             should_quantize = False
+            has_const = False
+            for idx, input_tensor in enumerate(op.input):
+                if input_tensor in self._consts:
+                    has_const = True
+                    break
+            if not has_const:
+                continue
+
             for idx, input_tensor in enumerate(op.input):
                 if self.get_tensor_data_type(input_tensor) \
                         == mace_pb2.DT_FLOAT:
                     should_quantize = True
                     break
             if not should_quantize:
                 continue
             else:
......
@@ -7,6 +7,14 @@ package(
 licenses(["notice"]) # Apache 2.0
 
+load(
+    "//mace:mace.bzl",
+    "if_android",
+    "if_android_armv7",
+    "if_neon_enabled",
+    "if_openmp_enabled",
+)
+
 cc_library(
     name = "utils_hdrs",
     hdrs = glob([
@@ -37,7 +45,17 @@ cc_library(
         "-Werror",
         "-Wextra",
         "-Wno-missing-field-initializers",
-    ],
+    ] + if_openmp_enabled([
+        "-fopenmp",
+    ]) + if_neon_enabled([
+        "-DMACE_ENABLE_NEON",
+    ]) + if_android_armv7([
+        "-mfpu=neon",
+        "-mfloat-abi=softfp",
+    ]),
+    linkopts = if_android([
+        "-llog",
+    ]),
     deps = [
         ":utils_hdrs",
     ],
......
@@ -19,6 +19,10 @@
 #include <cmath>
 #include <limits>
 
+#if defined(MACE_ENABLE_NEON)
+#include <arm_neon.h>
+#endif // MACE_ENABLE_NEON
+
 #include "mace/utils/logging.h"
 
 namespace mace {
@@ -156,6 +160,106 @@ inline void Dequantize(const T *input,
   }
 }
 
+#if defined(MACE_ENABLE_NEON)
+template<>
+inline void QuantizeWithScaleAndZeropoint<uint8_t>(const float *input,
+                                                   const index_t size,
+                                                   float scale,
+                                                   int32_t zero_point,
+                                                   uint8_t *output) {
+  const float32x4_t vround = vdupq_n_f32(0.5);
+  const float32x4_t
+      vzero = vaddq_f32(vround, vcvtq_f32_s32(vdupq_n_s32(zero_point)));
+  const float recip_scale = 1.f / scale;
+  const float32x4_t vrecip_scale = vdupq_n_f32(recip_scale);
+  const index_t block_count = size / 16;
+
+#pragma omp parallel for schedule(runtime)
+  for (index_t i = 0; i < block_count; ++i) {
+    float32x4_t vi0 = vld1q_f32(input + i * 16);
+    float32x4_t vi1 = vld1q_f32(input + i * 16 + 4);
+    float32x4_t vi2 = vld1q_f32(input + i * 16 + 8);
+    float32x4_t vi3 = vld1q_f32(input + i * 16 + 12);
+
+    int32x4_t vo0_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi0, vrecip_scale));
+    int32x4_t vo1_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi1, vrecip_scale));
+    int32x4_t vo2_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi2, vrecip_scale));
+    int32x4_t vo3_s32 = vcvtq_s32_f32(vmlaq_f32(vzero, vi3, vrecip_scale));
+
+    uint8x8_t vo0_u8 =
+        vqmovun_s16(vcombine_s16(vqmovn_s32(vo0_s32), vqmovn_s32(vo1_s32)));
+    uint8x8_t vo1_u8 =
+        vqmovun_s16(vcombine_s16(vqmovn_s32(vo2_s32), vqmovn_s32(vo3_s32)));
+    uint8x16_t vo = vcombine_u8(vo0_u8, vo1_u8);
+    vst1q_u8(output + i * 16, vo);
+  }
+
+#pragma omp parallel for schedule(runtime)
+  for (index_t i = block_count * 16; i < size; ++i) {
+    output[i] = Saturate<uint8_t>(roundf(zero_point + recip_scale * input[i]));
+  }
+}
+
+template<>
+inline void Dequantize<int32_t>(const int32_t *input,
+                                const index_t size,
+                                const float scale,
+                                const int32_t zero_point,
+                                float *output) {
+  const index_t block_count = size / 4;
+  const int32x4_t vzero = vdupq_n_s32(zero_point);
+  const float32x4_t vscale = vdupq_n_f32(scale);
+
+#pragma omp parallel for schedule(runtime)
+  for (index_t i = 0; i < block_count; ++i) {
+    int32x4_t vi = vld1q_s32(input + i * 4);
+    float32x4_t vo = vmulq_f32(vscale, vcvtq_f32_s32(vsubq_s32(vi, vzero)));
+    vst1q_f32(output + i * 4, vo);
+  }
+
+  for (index_t i = block_count * 4; i < size; ++i) {
+    output[i] = scale * (input[i] - zero_point);
+  }
+}
+
+template<>
+inline void Dequantize<uint8_t>(const uint8_t *input,
+                                const index_t size,
+                                const float scale,
+                                const int32_t zero_point,
+                                float *output) {
+  const index_t block_count = size / 16;
+  const int32x4_t vzero = vdupq_n_s32(zero_point);
+  const float32x4_t vscale = vdupq_n_f32(scale);
+
+#pragma omp parallel for schedule(runtime)
+  for (index_t i = 0; i < block_count; ++i) {
+    uint8x16_t vi = vld1q_u8(input + i * 16);
+    float32x4x4_t vo = {
+        vmulq_f32(vscale,
+                  vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
+                      vget_low_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
+        vmulq_f32(vscale,
+                  vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
+                      vget_high_u16(vmovl_u8(vget_low_u8(vi))))), vzero))),
+        vmulq_f32(vscale,
+                  vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
+                      vget_low_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
+        vmulq_f32(vscale,
+                  vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(
+                      vget_high_u16(vmovl_u8(vget_high_u8(vi))))), vzero))),
+    };
+    vst1q_f32(output + i * 16, vo.val[0]);
+    vst1q_f32(output + i * 16 + 4, vo.val[1]);
+    vst1q_f32(output + i * 16 + 8, vo.val[2]);
+    vst1q_f32(output + i * 16 + 12, vo.val[3]);
+  }
+
+  for (index_t i = block_count * 16; i < size; ++i) {
+    output[i] = scale * (input[i] - zero_point);
+  }
+}
+#endif // MACE_ENABLE_NEON
+
 template<typename T>
 inline void DeQuantize(const Tensor &input,
                        const float min_in,
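The NEON specializations above vectorize the same affine mapping used by the generic templates and visible in the scalar tail loops: quantization computes `round(zero_point + x / scale)` saturated to `uint8_t`, and dequantization computes `scale * (q - zero_point)`. One detail worth noting is that the quantize path folds the `+0.5` rounding offset into `vzero`, so the truncating `vcvtq_s32_f32` conversion behaves like round-to-nearest for the non-negative values that survive saturation. A scalar sketch of the mapping follows; the helper names are illustrative and not part of this header.

```cpp
// Scalar reference for the affine quantization scheme vectorized above.
// Illustrative helpers only; the real header uses templated Quantize/Dequantize.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t QuantizeOne(float x, float scale, int32_t zero_point) {
  // Matches the scalar tail loop: round(zero_point + x / scale), clamped to uint8.
  float q = std::round(static_cast<float>(zero_point) + x / scale);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}

inline float DequantizeOne(uint8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}
```

Processing 16 elements per iteration lets the quantize kernel emit one full `uint8x16_t` store, which is why the vectorized loops step in blocks of 16 (or 4 for the 32-bit dequantize) and finish with a scalar tail.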
@@ -175,8 +279,8 @@ inline void DeQuantize(const Tensor &input,
 }
 
 inline void QuantizeMultiplier(double multiplier,
-                               int32_t* output_multiplier,
-                               int32_t* shift) {
+                               int32_t *output_multiplier,
+                               int32_t *shift) {
   const double q = std::frexp(multiplier, shift);
   auto qint = static_cast<int64_t>(roundl(q * (1ll << 31)));
   if (qint == (1ll << 31)) {
......
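For context on the last hunk: `QuantizeMultiplier` encodes a real-valued requantization multiplier as a 32-bit fixed-point significand plus a power-of-two exponent via `std::frexp`, so integer kernels can replace a floating-point multiply with a fixed-point multiply and shift. The sketch below is a self-contained illustration, not the MACE implementation; the branch for a rounded significand of exactly 2^31 is truncated in the diff above, so the usual normalization is shown here as an assumption.

```cpp
// Sketch: multiplier ≈ (significand / 2^31) * 2^exponent, significand in [2^30, 2^31).
#include <cmath>
#include <cstdint>

inline void EncodeMultiplier(double multiplier,
                             int32_t *significand, int32_t *exponent) {
  // frexp: multiplier = q * 2^exponent with q in [0.5, 1).
  const double q = std::frexp(multiplier, exponent);
  auto qint = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (qint == (1ll << 31)) {  // rounding carried q up to 1.0
    qint /= 2;                // renormalize into [2^30, 2^31)
    ++(*exponent);
  }
  *significand = static_cast<int32_t>(qint);
}
```

A downstream integer kernel can then apply the multiplier as a 64-bit multiply by the significand followed by an arithmetic shift derived from the exponent, the standard fixed-point requantization pattern.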