From b1e8fd8a980cbe9fa941f822b336aaaa201b77b7 Mon Sep 17 00:00:00 2001
From: Yizhuang Zhou <62599194+zhouyizhuang-megvii@users.noreply.github.com>
Date: Fri, 29 May 2020 14:50:28 +0800
Subject: [PATCH] feat(quantization): Model Quantization (#17)

* add quantization codebase for resnet, shufflenetv1, and mobilenetv2

Co-authored-by: Kevin.W
Co-authored-by: wangshupeng
Co-authored-by: LiZhiyuan <848796515@qq.com>
Co-authored-by: wangjianfeng
---
 official/quantization/README.md              |  85 +++++
 official/quantization/config.py              |  71 ++++
 official/quantization/finetune.py            | 317 +++++++++++++++++
 official/quantization/inference.py           | 110 ++++++
 official/quantization/models/__init__.py     |  11 +
 official/quantization/models/mobilenet_v2.py | 187 ++++++++++
 official/quantization/models/resnet.py       | 349 +++++++++++++++++++
 official/quantization/models/shufflenet.py   | 221 ++++++++++++
 official/quantization/test.py                | 194 +++++++++++
 official/quantization/train.py               | 309 ++++++++++++++++
 10 files changed, 1854 insertions(+)
 create mode 100644 official/quantization/README.md
 create mode 100644 official/quantization/config.py
 create mode 100644 official/quantization/finetune.py
 create mode 100644 official/quantization/inference.py
 create mode 100644 official/quantization/models/__init__.py
 create mode 100644 official/quantization/models/mobilenet_v2.py
 create mode 100644 official/quantization/models/resnet.py
 create mode 100644 official/quantization/models/shufflenet.py
 create mode 100644 official/quantization/test.py
 create mode 100644 official/quantization/train.py

diff --git a/official/quantization/README.md b/official/quantization/README.md
new file mode 100644
index 0000000..66582ad
--- /dev/null
+++ b/official/quantization/README.md
@@ -0,0 +1,85 @@
Model Quantization
---

This directory contains MegEngine implementations of quantization-aware training and deployment for the commonly used ResNet, ShuffleNet and MobileNet families. ImageNet top-1 accuracies of the quantized models are listed below:

| Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) |
| --- | --- | --- | --- | --- |
| ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 |
| ShufflenetV1 (1.5x) | 71.954 | 17.3 | | 25.3 |
| MobilenetV2 | 72.820 | 13.1 | | 17.4 |

**: FPS is measured on Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz, single 224x224 image*

At inference time, the quantized models uniformly read uint8 images in [0, 255], subtract a mean of 128, convert the result to int8, and feed it to the network.

## Quantization Aware Training (QAT)

```python
import megengine.quantization as Q

model = ...

# Quantization Aware Training
Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

for _ in range(...):
    train(model)
```

## Deploying Quantized Model

```python
import megengine.quantization as Q
import megengine.jit as jit

model = ...

Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)

# real quant
Q.quantize(model)

@jit.trace(symbolic=True)
def inference_func(x):
    return model(x)

inference_func.dump(...)
```

# HOWTO use this codebase

## Step 1. Train an fp32 model

```
python3 train.py -a resnet18 -d /path/to/imagenet --mode normal
```

## Step 2. Finetune the fp32 model with quantization aware training (QAT)

```
python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl --mode qat
```

## Step 3. Test the QAT model on the ImageNet test set

```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode qat
```

or test in quantized mode, which runs on CPU only and takes considerably longer:

```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized -n 1
```

## Step 4. Inference and dump

```
python3 inference.py -a resnet18 --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized --dump
```

This feeds a cat image to the quantized network and prints the classification probabilities.

Setting `--dump` also writes the quantized network to the binary file `resnet18.quantized.megengine`.

diff --git a/official/quantization/config.py b/official/quantization/config.py
new file mode 100644
index 0000000..967b8a6
--- /dev/null
+++ b/official/quantization/config.py
@@ -0,0 +1,71 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
Configurations to train/finetune quantized classification models
"""
import megengine.data.transform as T


class ShufflenetConfig:
    BATCH_SIZE = 128
    LEARNING_RATE = 0.0625
    MOMENTUM = 0.9
    WEIGHT_DECAY = lambda self, n, p: \
        4e-5 if n.find("weight") >= 0 and len(p.shape) > 1 else 0
    EPOCHS = 240

    SCHEDULER = "Linear"
    COLOR_JITTOR = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)


class ResnetConfig:
    BATCH_SIZE = 32
    LEARNING_RATE = 0.0125
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4
    EPOCHS = 90

    SCHEDULER = "Multistep"
    SCHEDULER_STEPS = [30, 60, 80]
    SCHEDULER_GAMMA = 0.1
    COLOR_JITTOR = T.PseudoTransform()  # disable color jitter


def get_config(arch: str):
    if "resne" in arch:  # both resnet and resnext
        return ResnetConfig()
    elif "shufflenet" in arch or "mobilenet" in arch:
        return ShufflenetConfig()
    else:
        raise ValueError("config for {} does not exist".format(arch))


class ShufflenetFinetuneConfig(ShufflenetConfig):
    BATCH_SIZE = 128 // 2
    LEARNING_RATE = 0.03125
    EPOCHS = 120


class ResnetFinetuneConfig(ResnetConfig):
    BATCH_SIZE = 32
    LEARNING_RATE = 0.000125
    EPOCHS = 12

    SCHEDULER = "Multistep"
    SCHEDULER_STEPS = [6,]
    SCHEDULER_GAMMA = 0.1


def get_finetune_config(arch: str):
    if "resne" in arch:  # both resnet and resnext
        return ResnetFinetuneConfig()
    elif "shufflenet" in arch or "mobilenet" in arch:
        return ShufflenetFinetuneConfig()
    else:
        raise ValueError("config for {} does not exist".format(arch))

diff --git a/official/quantization/finetune.py b/official/quantization/finetune.py
new file mode 100644
index 0000000..2397554
--- /dev/null
+++ b/official/quantization/finetune.py
@@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
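# Finetuning hyper-parameters (batch size, learning rate, schedule) are taken
# from config.get_finetune_config(); see config.py for the per-family values.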
+"""Finetune a pretrained fp32 with int8 quantization aware training(QAT)""" +import argparse +import collections +import multiprocessing as mp +import numbers +import os +import bisect +import time + +import megengine as mge +import megengine.data as data +import megengine.data.transform as T +import megengine.distributed as dist +import megengine.functional as F +import megengine.jit as jit +import megengine.optimizer as optim +import megengine.quantization as Q + +import config +import models + +logger = mge.get_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-a", "--arch", default="resnet18", type=str) + parser.add_argument("-d", "--data", default=None, type=str) + parser.add_argument("-s", "--save", default="/data/models", type=str) + parser.add_argument("-c", "--checkpoint", default=None, type=str, + help="pretrained model to finetune") + + parser.add_argument("-m", "--mode", default="qat", type=str, + choices=["normal", "qat", "quantized"], + help="Quantization Mode\n" + "normal: no quantization, using float32\n" + "qat: quantization aware training, simulate int8\n" + "quantized: convert mode to int8 quantized, inference only") + + parser.add_argument("-n", "--ngpus", default=None, type=int) + parser.add_argument("-w", "--workers", default=4, type=int) + parser.add_argument("--report-freq", default=50, type=int) + args = parser.parse_args() + + world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus + + if world_size > 1: + # start distributed training, dispatch sub-processes + mp.set_start_method("spawn") + processes = [] + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, args)) + p.start() + processes.append(p) + + for p in processes: + p.join() + else: + worker(0, 1, args) + + +def get_parameters(model, cfg): + if isinstance(cfg.WEIGHT_DECAY, numbers.Number): + return {"params": model.parameters(requires_grad=True), + "weight_decay": cfg.WEIGHT_DECAY} + + groups = collections.defaultdict(list) # weight_decay -> List[param] + for pname, p in model.named_parameters(requires_grad=True): + wd = cfg.WEIGHT_DECAY(pname, p) + groups[wd].append(p) + groups = [ + {"params": params, "weight_decay": wd} + for wd, params in groups.items() + ] # List[{param, weight_decay}] + return groups + + +def worker(rank, world_size, args): + # pylint: disable=too-many-statements + + if world_size > 1: + # Initialize distributed process group + logger.info("init distributed process group {} / {}".format(rank, world_size)) + dist.init_process_group( + master_ip="localhost", + master_port=23456, + world_size=world_size, + rank=rank, + dev=rank, + ) + + save_dir = os.path.join(args.save, args.arch + "." 
+ args.mode) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + mge.set_log_file(os.path.join(save_dir, "log.txt")) + + model = models.__dict__[args.arch]() + cfg = config.get_finetune_config(args.arch) + + cfg.LEARNING_RATE *= world_size # scale learning rate in distributed training + total_batch_size = cfg.BATCH_SIZE * world_size + steps_per_epoch = 1280000 // total_batch_size + total_steps = steps_per_epoch * cfg.EPOCHS + + if args.mode != "normal": + Q.quantize_qat(model, Q.ema_fakequant_qconfig) + + if args.checkpoint: + logger.info("Load pretrained weights from %s", args.checkpoint) + ckpt = mge.load(args.checkpoint) + ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + model.load_state_dict(ckpt, strict=False) + + if args.mode == "quantized": + raise ValueError("mode = quantized only used during inference") + Q.quantize(model) + + optimizer = optim.SGD( + get_parameters(model, cfg), + lr=cfg.LEARNING_RATE, + momentum=cfg.MOMENTUM, + ) + + # Define train and valid graph + @jit.trace(symbolic=True) + def train_func(image, label): + model.train() + logits = model(image) + loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) + acc1, acc5 = F.accuracy(logits, label, (1, 5)) + optimizer.backward(loss) # compute gradients + if dist.is_distributed(): # all_reduce_mean + loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size() + acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size() + acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size() + return loss, acc1, acc5 + + @jit.trace(symbolic=True) + def valid_func(image, label): + model.eval() + logits = model(image) + loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) + acc1, acc5 = F.accuracy(logits, label, (1, 5)) + if dist.is_distributed(): # all_reduce_mean + loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size() + acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size() + acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size() + return loss, acc1, acc5 + + # Build train and valid datasets + logger.info("preparing dataset..") + train_dataset = data.dataset.ImageNet(args.data, train=True) + train_sampler = data.Infinite(data.RandomSampler( + train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True + )) + train_queue = data.DataLoader( + train_dataset, + sampler=train_sampler, + transform=T.Compose( + [ + T.RandomResizedCrop(224), + T.RandomHorizontalFlip(), + cfg.COLOR_JITTOR, + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ), + num_workers=args.workers, + ) + train_queue = iter(train_queue) + valid_dataset = data.dataset.ImageNet(args.data, train=False) + valid_sampler = data.SequentialSampler( + valid_dataset, batch_size=100, drop_last=False + ) + valid_queue = data.DataLoader( + valid_dataset, + sampler=valid_sampler, + transform=T.Compose( + [ + T.Resize(256), + T.CenterCrop(224), + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ), + num_workers=args.workers, + ) + + def adjust_learning_rate(step, epoch): + learning_rate = cfg.LEARNING_RATE + if cfg.SCHEDULER == "Linear": + learning_rate *= 1 - float(step) / total_steps + elif cfg.SCHEDULER == "Multistep": + learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(cfg.SCHEDULER_STEPS, epoch) + else: + raise ValueError(cfg.SCHEDULER) + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + return learning_rate + + # Start training + objs = AverageMeter("Loss") + top1 = AverageMeter("Acc@1") + top5 = 
AverageMeter("Acc@5") + total_time = AverageMeter("Time") + + t = time.time() + for step in range(0, total_steps): + # Linear learning rate decay + epoch = step // steps_per_epoch + learning_rate = adjust_learning_rate(step, epoch) + + image, label = next(train_queue) + image = image.astype("float32") + label = label.astype("int32") + + n = image.shape[0] + + optimizer.zero_grad() + loss, acc1, acc5 = train_func(image, label) + optimizer.step() + + top1.update(100 * acc1.numpy()[0], n) + top5.update(100 * acc5.numpy()[0], n) + objs.update(loss.numpy()[0], n) + total_time.update(time.time() - t) + t = time.time() + if step % args.report_freq == 0 and rank == 0: + logger.info( + "TRAIN e%d %06d %f %s %s %s %s", + epoch, step, learning_rate, + objs, top1, top5, total_time + ) + objs.reset() + top1.reset() + top5.reset() + total_time.reset() + if step % 10000 == 0 and rank == 0: + logger.info("SAVING %06d", step) + mge.save( + {"step": step, "state_dict": model.state_dict()}, + os.path.join(save_dir, "checkpoint.pkl"), + ) + if step % 10000 == 0 and step != 0: + _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) + logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5) + + mge.save( + {"step": step, "state_dict": model.state_dict()}, + os.path.join(save_dir, "checkpoint-final.pkl") + ) + _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) + logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5) + + +def infer(model, data_queue, args): + objs = AverageMeter("Loss") + top1 = AverageMeter("Acc@1") + top5 = AverageMeter("Acc@5") + total_time = AverageMeter("Time") + + t = time.time() + for step, (image, label) in enumerate(data_queue): + n = image.shape[0] + image = image.astype("float32") # convert np.uint8 to float32 + label = label.astype("int32") + + loss, acc1, acc5 = model(image, label) + + objs.update(loss.numpy()[0], n) + top1.update(100 * acc1.numpy()[0], n) + top5.update(100 * acc5.numpy()[0], n) + total_time.update(time.time() - t) + t = time.time() + + if step % args.report_freq == 0 and dist.get_rank() == 0: + logger.info("Step %d, %s %s %s %s", + step, objs, top1, top5, total_time) + + return objs.avg, top1.avg, top5.avg + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":.3f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +if __name__ == "__main__": + main() diff --git a/official/quantization/inference.py b/official/quantization/inference.py new file mode 100644 index 0000000..ce0703e --- /dev/null +++ b/official/quantization/inference.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+"""Finetune a pretrained fp32 with int8 quantization aware training(QAT)""" +import argparse +import json + +import cv2 +import megengine as mge +import megengine.data.transform as T +import megengine.functional as F +import megengine.jit as jit +import megengine.quantization as Q +import numpy as np + +import models + +logger = mge.get_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-a", "--arch", default="resnet18", type=str) + parser.add_argument("-c", "--checkpoint", default=None, type=str) + parser.add_argument("-i", "--image", default=None, type=str) + + parser.add_argument("-m", "--mode", default="quantized", type=str, + choices=["normal", "qat", "quantized"], + help="Quantization Mode\n" + "normal: no quantization, using float32\n" + "qat: quantization aware training, simulate int8\n" + "quantized: convert mode to int8 quantized, inference only") + parser.add_argument("--dump", action="store_true", + help="Dump quantized model") + args = parser.parse_args() + + if args.mode == "quantized": + mge.set_default_device("cpux") + + model = models.__dict__[args.arch]() + + if args.mode != "normal": + Q.quantize_qat(model, Q.ema_fakequant_qconfig) + + if args.checkpoint: + logger.info("Load pretrained weights from %s", args.checkpoint) + ckpt = mge.load(args.checkpoint) + ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + model.load_state_dict(ckpt, strict=False) + + if args.mode == "quantized": + Q.quantize(model) + + if args.image is None: + path = "../assets/cat.jpg" + else: + path = args.image + image = cv2.imread(path, cv2.IMREAD_COLOR) + + transform = T.Compose( + [ + T.Resize(256), + T.CenterCrop(224), + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ) + + @jit.trace(symbolic=True) + def infer_func(processed_img): + model.eval() + logits = model(processed_img) + probs = F.softmax(logits) + return probs + + processed_img = transform.apply(image)[np.newaxis, :] + + if args.mode == "normal": + processed_img = processed_img.astype("float32") + elif args.mode == "quantized": + processed_img = processed_img.astype("int8") + + probs = infer_func(processed_img) + + top_probs, classes = F.top_k(probs, k=5, descending=True) + + if args.dump: + output_file = ".".join([args.arch, args.mode, "megengine"]) + logger.info("Dump to {}".format(output_file)) + infer_func.dump(output_file, arg_names=["data"]) + mge.save(model.state_dict(), output_file.replace("megengine", "pkl")) + + with open("../assets/imagenet_class_info.json") as fp: + imagenet_class_index = json.load(fp) + + for rank, (prob, classid) in enumerate( + zip(top_probs.numpy().reshape(-1), classes.numpy().reshape(-1)) + ): + print( + "{}: class = {:20s} with probability = {:4.1f} %".format( + rank, imagenet_class_index[str(classid)][1], 100 * prob + ) + ) +if __name__ == "__main__": + main() diff --git a/official/quantization/models/__init__.py b/official/quantization/models/__init__.py new file mode 100644 index 0000000..6f07d91 --- /dev/null +++ b/official/quantization/models/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+from .resnet import * +from .shufflenet import * +from .mobilenet_v2 import * diff --git a/official/quantization/models/mobilenet_v2.py b/official/quantization/models/mobilenet_v2.py new file mode 100644 index 0000000..76f6100 --- /dev/null +++ b/official/quantization/models/mobilenet_v2.py @@ -0,0 +1,187 @@ +# BSD 3-Clause License + +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# ------------------------------------------------------------------------------ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# +# This file has been modified by Megvii ("Megvii Modifications"). +# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. +# ------------------------------------------------------------------------------ +import megengine.functional as F +import megengine.module as M + +__all__ = ['MobileNetV2', 'mobilenet_v2'] + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
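    # Illustrative examples (not from the original source):
    #   _make_divisible(32, 8)  -> 32, _make_divisible(37, 8) -> 40
    #   _make_divisible(20, 16) -> 32: rounding gives 16, but 16 < 0.9 * 20,
    #   so the check below bumps the result up by one divisor.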
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class InvertedResidual(M.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(M.ConvBnRelu2d(inp, hidden_dim, kernel_size=1, bias=False)) + layers.extend([ + # dw + M.ConvBnRelu2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, + stride=stride, groups=hidden_dim, bias=False), + # pw-linear + M.ConvBn2d(hidden_dim, oup, kernel_size=1, bias=False) + ]) + self.conv = M.Sequential(*layers) + self.add = M.Elemwise("ADD") + + def forward(self, x): + if self.use_res_connect: + return self.add(x, self.conv(x)) + else: + return self.conv(x) + + +class MobileNetV2(M.Module): + def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): + """ + MobileNet V2 main class + + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + """ + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [M.ConvBnRelu2d(3, input_channel, kernel_size=3, padding=1, stride=2, bias=False)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(M.ConvBnRelu2d(input_channel, self.last_channel, kernel_size=1, bias=False)) + # make it M.Sequential + self.features = M.Sequential(*features) + + # building classifier + self.classifier = M.Sequential( + M.Dropout(0.2), + M.Linear(self.last_channel, num_classes), + ) + + self.quant = M.QuantStub() + self.dequant = M.DequantStub() + + # weight initialization + for m in self.modules(): + if isinstance(m, M.Conv2d): + M.init.msra_normal_(m.weight, mode='fan_out') + if m.bias is not None: + M.init.zeros_(m.bias) + elif isinstance(m, M.BatchNorm2d): + M.init.ones_(m.weight) + M.init.zeros_(m.bias) + elif isinstance(m, M.Linear): + M.init.normal_(m.weight, 0, 0.01) + M.init.zeros_(m.bias) + + def forward(self, x): + x = self.quant(x) + x = self.features(x) + x = F.avg_pool2d(x, 7) + x = F.flatten(x, 1) + x = self.dequant(x) + x = self.classifier(x) + return x + + +def mobilenet_v2(**kwargs): + """ + 
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
    """
    model = MobileNetV2(**kwargs)
    return model

diff --git a/official/quantization/models/resnet.py b/official/quantization/models/resnet.py
new file mode 100644
index 0000000..c29e0f3
--- /dev/null
+++ b/official/quantization/models/resnet.py
@@ -0,0 +1,349 @@
# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016,
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# ------------------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ------------------------------------------------------------------------------
"""ResNet optimized for quantization, identical to the original model after modification."""
import math

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        assert norm is M.BatchNorm2d, 'Quant mode only supports BatchNorm2d currently.'
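        # conv + BN (+ ReLU) are built as fused modules (ConvBnRelu2d / ConvBn2d)
        # so that Q.quantize_qat can replace them with quantization-aware versions.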
+ super(BasicBlock, self).__init__() + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.conv_bn_relu1 = M.ConvBnRelu2d( + in_channels, channels, 3, stride, padding=dilation, bias=False + ) + self.conv_bn2 = M.ConvBn2d( + channels, channels, 3, 1, padding=1, bias=False + ) + self.downsample = ( + M.Identity() + if in_channels == channels and stride == 1 + else M.ConvBn2d(in_channels, channels, 1, stride, bias=False) + ) + self.add = M.Elemwise("ADD") + + def forward(self, x): + identity = x + x = self.conv_bn_relu1(x) + x = self.conv_bn2(x) + identity = self.downsample(identity) + x = self.add(x, identity) + x = F.relu(x) + return x + + +class Bottleneck(M.Module): + expansion = 4 + + def __init__( + self, + in_channels, + channels, + stride=1, + groups=1, + base_width=64, + dilation=1, + norm=M.BatchNorm2d, + ): + super(Bottleneck, self).__init__() + width = int(channels * (base_width / 64.0)) * groups + self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, width, 1, 1, bias=False) + self.conv_bn_relu2 = M.ConvBnRelu2d( + width, + width, + 3, + stride, + padding=dilation, + groups=groups, + dilation=dilation, + bias=False, + ) + self.conv_bn3 = M.ConvBn2d( + width, channels * self.expansion, 1, 1, bias=False + ) + self.downsample = ( + M.Identity() + if in_channels == channels * self.expansion and stride == 1 + else M.ConvBn2d(in_channels, channels * self.expansion, 1, stride, bias=False) + ) + self.add = M.Elemwise("ADD") + + def forward(self, x): + identity = x + x = self.conv_bn_relu1(x) + x = self.conv_bn_relu2(x) + x = self.conv_bn3(x) + identity = self.downsample(identity) + x = self.add(x, identity) + x = F.relu(x) + return x + + +class ResNet(M.Module): + def __init__( + self, + block, + layers, + num_classes=1000, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + norm=M.BatchNorm2d, + ): + super(ResNet, self).__init__() + self.in_channels = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.quant = M.QuantStub() + self.dequant = M.DequantStub() + self.conv_bn_relu1 = M.ConvBnRelu2d( + 3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False + ) + self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], norm=norm) + self.layer2 = self._make_layer( + block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0], + norm=norm, + ) + self.layer3 = self._make_layer( + block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1], + norm=norm, + ) + self.layer4 = self._make_layer( + block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2], + norm=norm, + ) + self.fc = M.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, M.Conv2d): + M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight) + 
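                    # uniform bias init in [-1/sqrt(fan_in), 1/sqrt(fan_in)]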
bound = 1 / math.sqrt(fan_in) + M.init.uniform_(m.bias, -bound, bound) + elif isinstance(m, M.BatchNorm2d): + M.init.ones_(m.weight) + M.init.zeros_(m.bias) + elif isinstance(m, M.Linear): + M.init.msra_uniform_(m.weight, a=math.sqrt(5)) + if m.bias is not None: + fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) + M.init.uniform_(m.bias, -bound, bound) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + M.init.zeros_(m.bn3.weight) + elif isinstance(m, BasicBlock): + M.init.zeros_(m.bn2.weight) + + def _make_layer( + self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d + ): + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + + layers = [] + layers.append( + block( + self.in_channels, + channels, + stride, + groups=self.groups, + base_width=self.base_width, + dilation=previous_dilation, + norm=norm, + ) + ) + self.in_channels = channels * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.in_channels, + channels, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm=norm, + ) + ) + + return M.Sequential(*layers) + + def extract_features(self, x): + outputs = {} + x = self.conv_bn_relu1(x) + x = self.maxpool(x) + outputs["stem"] = x + + x = self.layer1(x) + outputs["res2"] = x + x = self.layer2(x) + outputs["res3"] = x + x = self.layer3(x) + outputs["res4"] = x + x = self.layer4(x) + outputs["res5"] = x + + return outputs + + def forward(self, x): + # FIXME whenever finding elegant solution + x = self.quant(x) + x = self.extract_features(x)["res5"] + + x = F.avg_pool2d(x, 7) + x = F.flatten(x, 1) + x = self.dequant(x) + x = self.fc(x) + + return x + + +def resnet18(**kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + """ + return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + + +def resnet34(**kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + """ + return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + + +def resnet50(**kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + """ + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnet101(**kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + """ + return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + +def resnet152(**kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + """ + return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + + +def resnext50_32x4d(**kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs["groups"] = 32 + kwargs["width_per_group"] = 4 + return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + +def resnext101_32x8d(**kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to 
stderr + """ + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) diff --git a/official/quantization/models/shufflenet.py b/official/quantization/models/shufflenet.py new file mode 100644 index 0000000..ed82642 --- /dev/null +++ b/official/quantization/models/shufflenet.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +# MIT License +# +# Copyright (c) 2019 Megvii Technology +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# ------------------------------------------------------------------------------ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# +# This file has been modified by Megvii ("Megvii Modifications"). +# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. 
+# ------------------------------------------------------------------------------ +import megengine.functional as F +import megengine.hub as hub +import megengine.module as M +from megengine.module import ( + BatchNorm2d, + Conv2d, + ConvBn2d, + ConvBnRelu2d, + AvgPool2d, + MaxPool2d, + DequantStub, + Linear, + Module, + QuantStub, + Sequential, + MaxPool2d, + Sequential, + Elemwise, +) +from megengine.quantization import * + + +class ShuffleV1Block(Module): + def __init__(self, inp, oup, *, group, first_group, mid_channels, ksize, stride): + super(ShuffleV1Block, self).__init__() + self.stride = stride + assert stride in [1, 2] + + self.mid_channels = mid_channels + self.ksize = ksize + pad = ksize // 2 + self.pad = pad + self.inp = inp + self.group = group + + branch_main_1 = [ + # pw + ConvBnRelu2d(inp, mid_channels, 1, 1, 0, groups=1 if first_group else group, bias=False), + # dw + ConvBn2d(mid_channels, mid_channels, ksize, stride, pad, groups=mid_channels, bias=False) + ] + branch_main_2 = [ + # pw-linear + ConvBn2d(mid_channels, oup, 1, 1, 0, groups=group, bias=False) + ] + self.branch_main_1 = Sequential(*branch_main_1) + self.branch_main_2 = Sequential(*branch_main_2) + self.add = Elemwise('FUSE_ADD_RELU') + + if stride == 2: + self.branch_proj = ConvBn2d(inp, oup, 1, 2, 0, bias=False) + + def forward(self, old_x): + x = old_x + x_proj = old_x + x = self.branch_main_1(x) + if self.group > 1: + x = self.channel_shuffle(x) + x = self.branch_main_2(x) + if self.stride == 1: + return self.add(x, x_proj) + elif self.stride == 2: + return self.add(self.branch_proj(x_proj), x) + + def channel_shuffle(self, x): + batchsize, num_channels, height, width = x.shape + # assert num_channels.numpy() % self.group == 0 + group_channels = num_channels // self.group + + x = x.reshape(batchsize, group_channels, self.group, height, width) + x = x.dimshuffle(0, 2, 1, 3, 4) + x = x.reshape(batchsize, num_channels, height, width) + return x + + +class ShuffleNetV1(Module): + def __init__(self, num_classes=1000, model_size='2.0x', group=None): + super(ShuffleNetV1, self).__init__() + print('model size is ', model_size) + + assert group is not None + + self.stage_repeats = [4, 8, 4] + self.model_size = model_size + if group == 3: + if model_size == '0.5x': + self.stage_out_channels = [-1, 12, 120, 240, 480] + elif model_size == '1.0x': + self.stage_out_channels = [-1, 24, 240, 480, 960] + elif model_size == '1.5x': + self.stage_out_channels = [-1, 24, 360, 720, 1440] + elif model_size == '2.0x': + self.stage_out_channels = [-1, 48, 480, 960, 1920] + else: + raise NotImplementedError + elif group == 8: + if model_size == '0.5x': + self.stage_out_channels = [-1, 16, 192, 384, 768] + elif model_size == '1.0x': + self.stage_out_channels = [-1, 24, 384, 768, 1536] + elif model_size == '1.5x': + self.stage_out_channels = [-1, 24, 576, 1152, 2304] + elif model_size == '2.0x': + self.stage_out_channels = [-1, 48, 768, 1536, 3072] + else: + raise NotImplementedError + + # building first layer + input_channel = self.stage_out_channels[1] + self.first_conv = Sequential( + ConvBnRelu2d(3, input_channel, 3, 2, 1, bias=False) + ) + self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.features = [] + for idxstage in range(len(self.stage_repeats)): + numrepeat = self.stage_repeats[idxstage] + output_channel = self.stage_out_channels[idxstage + 2] + + for i in range(numrepeat): + stride = 2 if i == 0 else 1 + first_group = idxstage == 0 and i == 0 + self.features.append(ShuffleV1Block(input_channel, 
output_channel,
                                                    group=group, first_group=first_group,
                                                    mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel

        self.features = Sequential(*self.features)
        self.quant = QuantStub()
        self.dequant = DequantStub()
        self.classifier = Sequential(Linear(self.stage_out_channels[-1], num_classes, bias=False))
        self._initialize_weights()

    def forward(self, x):
        x = self.quant(x)
        x = self.first_conv(x)
        x = self.maxpool(x)

        x = self.features(x)

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.dequant(x)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, M.Conv2d):
                if "first" in name:
                    M.init.normal_(m.weight, 0, 0.01)
                else:
                    M.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)
            elif isinstance(m, M.BatchNorm2d):
                M.init.fill_(m.weight, 1)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0.0001)
                M.init.fill_(m.running_mean, 0)
            elif isinstance(m, M.BatchNorm1d):
                M.init.fill_(m.weight, 1)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0.0001)
                M.init.fill_(m.running_mean, 0)
            elif isinstance(m, M.Linear):
                M.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    M.init.fill_(m.bias, 0)


def shufflenet_v1_x0_5_g3(num_classes=1000):
    net = ShuffleNetV1(num_classes=num_classes, model_size="0.5x", group=3)
    return net


def shufflenet_v1_x1_0_g3(num_classes=1000):
    net = ShuffleNetV1(num_classes=num_classes, model_size="1.0x", group=3)
    return net


def shufflenet_v1_x1_5_g3(num_classes=1000):
    net = ShuffleNetV1(num_classes=num_classes, model_size="1.5x", group=3)
    return net


def shufflenet_v1_x2_0_g3(num_classes=1000):
    net = ShuffleNetV1(num_classes=num_classes, model_size="2.0x", group=3)
    return net

diff --git a/official/quantization/test.py b/official/quantization/test.py
new file mode 100644
index 0000000..071da66
--- /dev/null
+++ b/official/quantization/test.py
@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Test an int8 quantized model on ImageNet.

Note:
    * QAT simulates int8 with fp32 and runs on GPU only.
    * Quantized mode uses real int8 on CPU only and is a bit slow.
    * Results may differ slightly between qat and quantized modes.
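    * In quantized mode, main() forces a single CPU process, so "-n 1" is implied.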
+""" +import argparse +import multiprocessing as mp +import time + +import megengine as mge +import megengine.data as data +import megengine.data.transform as T +import megengine.distributed as dist +import megengine.functional as F +import megengine.jit as jit +import megengine.quantization as Q + +import models + +logger = mge.get_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-a", "--arch", default="resnet18", type=str) + parser.add_argument("-d", "--data", default=None, type=str) + parser.add_argument("-s", "--save", default="/data/models", type=str) + parser.add_argument("-c", "--checkpoint", default=None, type=str, + help="pretrained model to finetune") + + parser.add_argument("-m", "--mode", default="qat", type=str, + choices=["normal", "qat", "quantized"], + help="Quantization Mode\n" + "normal: no quantization, using float32\n" + "qat: quantization aware training, simulate int8\n" + "quantized: convert mode to int8 quantized, inference only") + + parser.add_argument("-n", "--ngpus", default=None, type=int) + parser.add_argument("-w", "--workers", default=4, type=int) + parser.add_argument("--report-freq", default=50, type=int) + args = parser.parse_args() + + world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus + + if args.mode == "quantized": + world_size = 1 + args.report_freq = 1 # test is slow on cpu + mge.set_default_device("cpux") + logger.warning("quantized mode use cpu only") + + if world_size > 1: + # start distributed training, dispatch sub-processes + mp.set_start_method("spawn") + processes = [] + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, args)) + p.start() + processes.append(p) + + for p in processes: + p.join() + else: + worker(0, 1, args) + + +def worker(rank, world_size, args): + # pylint: disable=too-many-statements + + if world_size > 1: + # Initialize distributed process group + logger.info("init distributed process group {} / {}".format(rank, world_size)) + dist.init_process_group( + master_ip="localhost", + master_port=23456, + world_size=world_size, + rank=rank, + dev=rank, + ) + + model = models.__dict__[args.arch]() + + if args.mode != "normal": + Q.quantize_qat(model, Q.ema_fakequant_qconfig) + + if args.checkpoint: + logger.info("Load pretrained weights from %s", args.checkpoint) + ckpt = mge.load(args.checkpoint) + ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt + model.load_state_dict(ckpt, strict=False) + + if args.mode == "quantized": + Q.quantize(model) + + # Define valid graph + @jit.trace(symbolic=True) + def valid_func(image, label): + model.eval() + logits = model(image) + loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) + acc1, acc5 = F.accuracy(logits, label, (1, 5)) + if dist.is_distributed(): # all_reduce_mean + loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size() + acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size() + acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size() + return loss, acc1, acc5 + + # Build valid datasets + logger.info("preparing dataset..") + valid_dataset = data.dataset.ImageNet(args.data, train=False) + valid_sampler = data.SequentialSampler( + valid_dataset, batch_size=100, drop_last=False + ) + valid_queue = data.DataLoader( + valid_dataset, + sampler=valid_sampler, + transform=T.Compose( + [ + T.Resize(256), + T.CenterCrop(224), + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ), + num_workers=args.workers, + ) + + 
_, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) + logger.info("TEST %f, %f", valid_acc, valid_acc5) + + +def infer(model, data_queue, args): + objs = AverageMeter("Loss") + top1 = AverageMeter("Acc@1") + top5 = AverageMeter("Acc@5") + total_time = AverageMeter("Time") + + t = time.time() + for step, (image, label) in enumerate(data_queue): + n = image.shape[0] + image = image.astype("float32") # convert np.uint8 to float32 + label = label.astype("int32") + + loss, acc1, acc5 = model(image, label) + + objs.update(loss.numpy()[0], n) + top1.update(100 * acc1.numpy()[0], n) + top5.update(100 * acc5.numpy()[0], n) + total_time.update(time.time() - t) + t = time.time() + + if step % args.report_freq == 0 and dist.get_rank() == 0: + logger.info("Step %d, %s %s %s %s", + step, objs, top1, top5, total_time) + + return objs.avg, top1.avg, top5.avg + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":.3f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +if __name__ == "__main__": + main() diff --git a/official/quantization/train.py b/official/quantization/train.py new file mode 100644 index 0000000..3a694f7 --- /dev/null +++ b/official/quantization/train.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+"""Train a model in fp32""" +import argparse +import collections +import multiprocessing as mp +import numbers +import os +import bisect +import time + +import megengine as mge +import megengine.data as data +import megengine.data.transform as T +import megengine.distributed as dist +import megengine.functional as F +import megengine.jit as jit +import megengine.optimizer as optim +import megengine.quantization as Q + +import config +import models + +logger = mge.get_logger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-a", "--arch", default="resnet18", type=str) + parser.add_argument("-d", "--data", default=None, type=str) + parser.add_argument("-s", "--save", default="/data/models", type=str) + + parser.add_argument("-m", "--mode", default="normal", type=str, + choices=["normal", "qat", "quantized"], + help="Quantization Mode\n" + "normal: no quantization, using float32\n" + "qat: quantization aware training, simulate int8\n" + "quantized: convert mode to int8 quantized, inference only") + + parser.add_argument("-n", "--ngpus", default=None, type=int) + parser.add_argument("-w", "--workers", default=4, type=int) + parser.add_argument("--report-freq", default=50, type=int) + args = parser.parse_args() + + world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus + + if world_size > 1: + # start distributed training, dispatch sub-processes + mp.set_start_method("spawn") + processes = [] + for rank in range(world_size): + p = mp.Process(target=worker, args=(rank, world_size, args)) + p.start() + processes.append(p) + + for p in processes: + p.join() + else: + worker(0, 1, args) + + +def get_parameters(model, cfg): + if isinstance(cfg.WEIGHT_DECAY, numbers.Number): + return {"params": model.parameters(requires_grad=True), + "weight_decay": cfg.WEIGHT_DECAY} + + groups = collections.defaultdict(list) # weight_decay -> List[param] + for pname, p in model.named_parameters(requires_grad=True): + wd = cfg.WEIGHT_DECAY(pname, p) + groups[wd].append(p) + groups = [ + {"params": params, "weight_decay": wd} + for wd, params in groups.items() + ] # List[{param, weight_decay}] + return groups + + +def worker(rank, world_size, args): + # pylint: disable=too-many-statements + + if world_size > 1: + # Initialize distributed process group + logger.info("init distributed process group {} / {}".format(rank, world_size)) + dist.init_process_group( + master_ip="localhost", + master_port=23456, + world_size=world_size, + rank=rank, + dev=rank, + ) + + save_dir = os.path.join(args.save, args.arch + "." 
+ args.mode) + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + mge.set_log_file(os.path.join(save_dir, "log.txt")) + + model = models.__dict__[args.arch]() + cfg = config.get_config(args.arch) + + cfg.LEARNING_RATE *= world_size # scale learning rate in distributed training + total_batch_size = cfg.BATCH_SIZE * world_size + steps_per_epoch = 1280000 // total_batch_size + total_steps = steps_per_epoch * cfg.EPOCHS + + if args.mode != "normal": + Q.quantize_qat(model, Q.ema_fakequant_qconfig) + + if args.mode == "quantized": + raise ValueError("mode = quantized only used during inference") + Q.quantize(model) + + optimizer = optim.SGD( + get_parameters(model, cfg), + lr=cfg.LEARNING_RATE, + momentum=cfg.MOMENTUM, + ) + + # Define train and valid graph + @jit.trace(symbolic=True) + def train_func(image, label): + model.train() + logits = model(image) + loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) + acc1, acc5 = F.accuracy(logits, label, (1, 5)) + optimizer.backward(loss) # compute gradients + if dist.is_distributed(): # all_reduce_mean + loss = dist.all_reduce_sum(loss, "train_loss") / dist.get_world_size() + acc1 = dist.all_reduce_sum(acc1, "train_acc1") / dist.get_world_size() + acc5 = dist.all_reduce_sum(acc5, "train_acc5") / dist.get_world_size() + return loss, acc1, acc5 + + @jit.trace(symbolic=True) + def valid_func(image, label): + model.eval() + logits = model(image) + loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1) + acc1, acc5 = F.accuracy(logits, label, (1, 5)) + if dist.is_distributed(): # all_reduce_mean + loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size() + acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size() + acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size() + return loss, acc1, acc5 + + # Build train and valid datasets + logger.info("preparing dataset..") + train_dataset = data.dataset.ImageNet(args.data, train=True) + train_sampler = data.Infinite(data.RandomSampler( + train_dataset, batch_size=cfg.BATCH_SIZE, drop_last=True + )) + train_queue = data.DataLoader( + train_dataset, + sampler=train_sampler, + transform=T.Compose( + [ + T.RandomResizedCrop(224), + T.RandomHorizontalFlip(), + cfg.COLOR_JITTOR, + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ), + num_workers=args.workers, + ) + train_queue = iter(train_queue) + valid_dataset = data.dataset.ImageNet(args.data, train=False) + valid_sampler = data.SequentialSampler( + valid_dataset, batch_size=100, drop_last=False + ) + valid_queue = data.DataLoader( + valid_dataset, + sampler=valid_sampler, + transform=T.Compose( + [ + T.Resize(256), + T.CenterCrop(224), + T.Normalize(mean=128), + T.ToMode("CHW"), + ] + ), + num_workers=args.workers, + ) + + def adjust_learning_rate(step, epoch): + learning_rate = cfg.LEARNING_RATE + if cfg.SCHEDULER == "Linear": + learning_rate *= 1 - float(step) / total_steps + elif cfg.SCHEDULER == "Multistep": + learning_rate *= cfg.SCHEDULER_GAMMA ** bisect.bisect_right(cfg.SCHEDULER_STEPS, epoch) + else: + raise ValueError(cfg.SCHEDULER) + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate + return learning_rate + + # Start training + objs = AverageMeter("Loss") + top1 = AverageMeter("Acc@1") + top5 = AverageMeter("Acc@5") + total_time = AverageMeter("Time") + + t = time.time() + for step in range(0, total_steps): + # Linear learning rate decay + epoch = step // steps_per_epoch + learning_rate = adjust_learning_rate(step, epoch) + + 
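        # train_queue wraps data.Infinite, so next() below never raises
        # StopIteration; the loop is bounded by total_steps alone.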
image, label = next(train_queue) + image = image.astype("float32") + label = label.astype("int32") + + n = image.shape[0] + + optimizer.zero_grad() + loss, acc1, acc5 = train_func(image, label) + optimizer.step() + + top1.update(100 * acc1.numpy()[0], n) + top5.update(100 * acc5.numpy()[0], n) + objs.update(loss.numpy()[0], n) + total_time.update(time.time() - t) + t = time.time() + if step % args.report_freq == 0 and rank == 0: + logger.info( + "TRAIN e%d %06d %f %s %s %s %s", + epoch, step, learning_rate, + objs, top1, top5, total_time + ) + objs.reset() + top1.reset() + top5.reset() + total_time.reset() + if step % 10000 == 0 and rank == 0: + logger.info("SAVING %06d", step) + mge.save( + {"step": step, "state_dict": model.state_dict()}, + os.path.join(save_dir, "checkpoint.pkl"), + ) + if step % 10000 == 0 and step != 0: + _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) + logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5) + + mge.save( + {"step": step, "state_dict": model.state_dict()}, + os.path.join(save_dir, "checkpoint-final.pkl") + ) + _, valid_acc, valid_acc5 = infer(valid_func, valid_queue, args) + logger.info("TEST %06d %f, %f", step, valid_acc, valid_acc5) + + +def infer(model, data_queue, args): + objs = AverageMeter("Loss") + top1 = AverageMeter("Acc@1") + top5 = AverageMeter("Acc@5") + total_time = AverageMeter("Time") + + t = time.time() + for step, (image, label) in enumerate(data_queue): + n = image.shape[0] + image = image.astype("float32") # convert np.uint8 to float32 + label = label.astype("int32") + + loss, acc1, acc5 = model(image, label) + + objs.update(loss.numpy()[0], n) + top1.update(100 * acc1.numpy()[0], n) + top5.update(100 * acc5.numpy()[0], n) + total_time.update(time.time() - t) + t = time.time() + + if step % args.report_freq == 0 and dist.get_rank() == 0: + logger.info("Step %d, %s %s %s %s", + step, objs, top1, top5, total_time) + + return objs.avg, top1.avg, top5.avg + + +class AverageMeter: + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":.3f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +if __name__ == "__main__": + main() -- GitLab