未验证 提交 03625d2f 编写于 作者: Y Yizhuang Zhou 提交者: GitHub

fix(quant): fix code and add quantized weights (#38)

上级 f64a4ccb
...@@ -47,3 +47,5 @@ from official.vision.keypoints.models import ( ...@@ -47,3 +47,5 @@ from official.vision.keypoints.models import (
) )
from official.vision.keypoints.inference import KeypointEvaluator from official.vision.keypoints.inference import KeypointEvaluator
from official.quantization.models import quantized_resnet18
...@@ -48,7 +48,7 @@ def get_config(arch: str): ...@@ -48,7 +48,7 @@ def get_config(arch: str):
class ShufflenetFinetuneConfig(ShufflenetConfig): class ShufflenetFinetuneConfig(ShufflenetConfig):
BATCH_SIZE = 128 // 2 BATCH_SIZE = 128 // 2
LEARNING_RATE = 0.003125 // 2 LEARNING_RATE = 0.003125 / 2
EPOCHS = 30 EPOCHS = 30
......
...@@ -17,6 +17,7 @@ import megengine.functional as F ...@@ -17,6 +17,7 @@ import megengine.functional as F
import megengine.jit as jit import megengine.jit as jit
import megengine.quantization as Q import megengine.quantization as Q
import numpy as np import numpy as np
from megengine.quantization.quantize import quantize, quantize_qat
import models import models
...@@ -45,7 +46,10 @@ def main(): ...@@ -45,7 +46,10 @@ def main():
model = models.__dict__[args.arch]() model = models.__dict__[args.arch]()
if args.mode != "normal": if args.mode != "normal":
Q.quantize_qat(model, Q.ema_fakequant_qconfig) quantize_qat(model, Q.ema_fakequant_qconfig)
if args.mode == "quantized":
quantize(model)
if args.checkpoint: if args.checkpoint:
logger.info("Load pretrained weights from %s", args.checkpoint) logger.info("Load pretrained weights from %s", args.checkpoint)
...@@ -53,9 +57,6 @@ def main(): ...@@ -53,9 +57,6 @@ def main():
ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
model.load_state_dict(ckpt, strict=False) model.load_state_dict(ckpt, strict=False)
if args.mode == "quantized":
Q.quantize(model)
if args.image is None: if args.image is None:
path = "../assets/cat.jpg" path = "../assets/cat.jpg"
else: else:
......
...@@ -46,6 +46,7 @@ import math ...@@ -46,6 +46,7 @@ import math
import megengine.functional as F import megengine.functional as F
import megengine.hub as hub import megengine.hub as hub
import megengine.module as M import megengine.module as M
from megengine.quantization.quantize import quantize_qat, quantize
class BasicBlock(M.Module): class BasicBlock(M.Module):
...@@ -292,58 +293,23 @@ def resnet18(**kwargs): ...@@ -292,58 +293,23 @@ def resnet18(**kwargs):
r"""ResNet-18 model from r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
""" """
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) m = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
m.fc.disable_quantize()
return m
def resnet34(**kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
def resnet50(**kwargs): def resnet50(**kwargs):
r"""ResNet-50 model from r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
""" """
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) m = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
m.fc.disable_quantize()
return m
def resnet101(**kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def resnet152(**kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
"""
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
def resnext50_32x4d(**kwargs): @hub.pretrained("https://data.megengine.org.cn/models/weights/resnet18.quantized.pkl")
r"""ResNeXt-50 32x4d model from def quantized_resnet18(**kwargs):
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ model = resnet18(**kwargs)
quantize_qat(model)
Args: quantize(model)
pretrained (bool): If True, returns a model pre-trained on ImageNet return model
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 4
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnext101_32x8d(**kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册