From 475b6b9f8cdecd8c91c22a2b1d816123b1476a86 Mon Sep 17 00:00:00 2001 From: Yizhuang Zhou <62599194+zhouyizhuang-megvii@users.noreply.github.com> Date: Fri, 5 Jun 2020 15:48:04 +0800 Subject: [PATCH] feat(quantization): add pretrained weights and update accuracy of quantized models (#30) --- official/quantization/README.md | 17 +++++++++++++++-- official/quantization/config.py | 4 ++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/official/quantization/README.md b/official/quantization/README.md index 66582ad..3b49339 100644 --- a/official/quantization/README.md +++ b/official/quantization/README.md @@ -6,13 +6,26 @@ | Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) | | --- | --- | --- | --- | --- | | ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 | -| ShufflenetV1 (1.5x) | 71.954 | 17.3 | | 25.3 | -| MobilenetV2 | 72.820 | 13.1 | | 17.4 | +| ShufflenetV1 (1.5x) | 71.954 | 17.3 | 70.656 | 25.3 | +| MobilenetV2 | 72.820 | 13.1 | 71.378 | 17.4 | **: FPS is measured on Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz, single 224x224 image* +*We finetune mobile models with QAT for 30 epochs, training longer may yield better accuracy* + 量化模型使用时,统一读取0-255的uint8图片,减去128的均值,转化为int8,输入网络。 + +#### (Optional) Download Pretrained Models +``` +wget https://data.megengine.org.cn/models/weights/mobilenet_v2_normal_72820.pkl +wget https://data.megengine.org.cn/models/weights/mobilenet_v2_qat_71378.pkl +wget https://data.megengine.org.cn/models/weights/resnet18_normal_69824.pkl +wget https://data.megengine.org.cn/models/weights/resnet18_qat_69754.pkl +wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_normal_71954.pkl +wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_qat_70656.pkl +``` + ## Quantization Aware Training (QAT) ```python diff --git a/official/quantization/config.py b/official/quantization/config.py index 967b8a6..ef3ea23 100644 --- a/official/quantization/config.py +++ b/official/quantization/config.py @@ -48,8 +48,8 @@ def get_config(arch: str): class ShufflenetFinetuneConfig(ShufflenetConfig): BATCH_SIZE = 128 // 2 - LEARNING_RATE = 0.03125 - EPOCHS = 120 + LEARNING_RATE = 0.003125 // 2 + EPOCHS = 30 class ResnetFinetuneConfig(ResnetConfig): -- GitLab