From 5590daf9fe96f6d09a954c3b50a90fc65bf1d9b8 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Wed, 9 Dec 2020 20:51:23 +0800 Subject: [PATCH] Add dygraph qat demo (#537) --- demo/dygraph/quant/README.md | 95 ++++++++ demo/dygraph/quant/mobilenet_v3.py | 357 ++++++++++++++++++++++++++++ demo/dygraph/quant/optimizer.py | 55 +++++ demo/dygraph/quant/train.py | 309 ++++++++++++++++++++++++ paddleslim/dygraph/quant/quanter.py | 2 +- 5 files changed, 817 insertions(+), 1 deletion(-) create mode 100755 demo/dygraph/quant/README.md create mode 100644 demo/dygraph/quant/mobilenet_v3.py create mode 100644 demo/dygraph/quant/optimizer.py create mode 100644 demo/dygraph/quant/train.py diff --git a/demo/dygraph/quant/README.md b/demo/dygraph/quant/README.md new file mode 100755 index 00000000..c8a5337b --- /dev/null +++ b/demo/dygraph/quant/README.md @@ -0,0 +1,95 @@ +# 动态图量化训练 + +本示例介绍如何对动态图模型进行量化训练,示例以常用的MobileNetV1和MobileNetV3模型为例,介绍如何对其进行量化训练。 + + +## 分类模型的量化训练流程 + +### 准备数据 + +在当前目录下创建``data``文件夹,将``ImageNet``数据集解压在``data``文件夹下,解压后``data/ILSVRC2012``文件夹下应包含以下文件: +- ``'train'``文件夹,训练图片 +- ``'train_list.txt'``文件 +- ``'val'``文件夹,验证图片 +- ``'val_list.txt'``文件 + +### 准备需要量化的模型 + +- 对于paddle vision支持的[模型](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models):`[lenet, mobilenetv1, mobilenetv2, resnet, vgg]`可以直接使用vision内置的模型定义和ImageNet预训练权重 +- 对于paddle vision暂未支持的模型,例如mobilenetv3,需要自行定义好模型结构以及准备相应的预训练权重 + - 本示例使用的是经过蒸馏的mobilenetv3模型,在ImageNet数据集上Top1精度达到78.96: [预训练权重下载](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar) + + +### 配置量化参数 + +``` +quant_config = { + 'weight_preprocess_type': None, + 'activation_preprocess_type': None, + 'weight_quantize_type': 'channel_wise_abs_max', + 'activation_quantize_type': 'moving_average_abs_max', + 'weight_bits': 8, + 'activation_bits': 8, + 'dtype': 'int8', + 'window_size': 10000, + 'moving_rate': 0.9, + 'quantizable_layer_type': ['Conv2D', 'Linear'], +} +``` + +- 
`'weight_preprocess_type'`:代表对量化模型权重参数预处理的方法,目前支持PACT方法,如需使用可以改为'PACT';默认为None,代表不对权重进行任何预处理。 + +- `'activation_preprocess_type'`:代表对量化模型激活值预处理的方法,目前支持PACT方法,如需使用可以改为'PACT';默认为None,代表不对激活值进行任何预处理。 + +- `weight_quantize_type`:代表模型权重的量化方式,可选的有['abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'],默认为channel_wise_abs_max + +- `activation_quantize_type`:代表模型激活值的量化方式,可选的有['abs_max', 'moving_average_abs_max'],默认为moving_average_abs_max + +- `quantizable_layer_type`:代表量化OP的类型,目前支持Conv2D和Linear + + + +### 插入量化算子,得到量化训练模型 + +```python +quanter = QAT(config=quant_config) +quanter.quantize(net) +``` + +### 量化训练结束,保存量化模型 + +```python +quanter.save_quantized_model(net, 'save_dir', input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')]) +``` + +### 训练命令 + +- MobileNetV1 + + 我们使用普通的量化训练方法即可,启动命令如下: + + ```bash + # 单卡训练 + python train.py --model='mobilenet_v1' + # 多卡训练,以0到3号卡为例 + python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model='mobilenet_v1' + ``` +- MobileNetV3 + + 对于MobileNetV3,直接使用普通的量化损失较大,为降低量化损失,可以使用PACT的量化方法,启动命令如下: + + ```bash + # 单卡训练 + python train.py --lr=0.001 --use_pact=True --num_epochs=30 --l2_decay=2e-5 --ls_epsilon=0.1 + # 多卡训练,以0到3号卡为例 + python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --lr=0.001 --use_pact=True --num_epochs=60 --l2_decay=2e-5 --ls_epsilon=0.1 + ``` + + + +### 量化结果 + +| 模型 | FP32模型准确率(Top1/Top5) | 量化方法 | 量化模型准确率(Top1/Top5) | +| ----------- | --------------------------- | ------------ | --------------------------- | +| MobileNetV1 | 70.99/89.65 | 普通在线量化 | 70.63/89.65 | +| MobileNetV3 | 78.96/94.48 | PACT在线量化 | 77.52/93.77 | diff --git a/demo/dygraph/quant/mobilenet_v3.py b/demo/dygraph/quant/mobilenet_v3.py new file mode 100644 index 00000000..e56c8990 --- /dev/null +++ b/demo/dygraph/quant/mobilenet_v3.py @@ -0,0 +1,357 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn.functional.activation import hard_sigmoid, hard_swish +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D +from paddle.regularizer import L2Decay + +import math + +__all__ = [ + "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5", + "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0", + "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35", + "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75", + "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25" +] + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class MobileNetV3(nn.Layer): + def __init__(self, + scale=1.0, + model_name="small", + dropout_prob=0.2, + class_dim=1000): + super(MobileNetV3, self).__init__() + + inplanes = 16 + if model_name == "large": + self.cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, False, "relu", 1], + [3, 64, 24, False, "relu", 2], + [3, 72, 24, False, "relu", 1], + [5, 72, 40, True, "relu", 2], + [5, 120, 40, True, "relu", 1], + [5, 120, 40, True, 
"relu", 1], + [3, 240, 80, False, "hard_swish", 2], + [3, 200, 80, False, "hard_swish", 1], + [3, 184, 80, False, "hard_swish", 1], + [3, 184, 80, False, "hard_swish", 1], + [3, 480, 112, True, "hard_swish", 1], + [3, 672, 112, True, "hard_swish", 1], + [5, 672, 160, True, "hard_swish", 2], + [5, 960, 160, True, "hard_swish", 1], + [5, 960, 160, True, "hard_swish", 1], + ] + self.cls_ch_squeeze = 960 + self.cls_ch_expand = 1280 + elif model_name == "small": + self.cfg = [ + # k, exp, c, se, nl, s, + [3, 16, 16, True, "relu", 2], + [3, 72, 24, False, "relu", 2], + [3, 88, 24, False, "relu", 1], + [5, 96, 40, True, "hard_swish", 2], + [5, 240, 40, True, "hard_swish", 1], + [5, 240, 40, True, "hard_swish", 1], + [5, 120, 48, True, "hard_swish", 1], + [5, 144, 48, True, "hard_swish", 1], + [5, 288, 96, True, "hard_swish", 2], + [5, 576, 96, True, "hard_swish", 1], + [5, 576, 96, True, "hard_swish", 1], + ] + self.cls_ch_squeeze = 576 + self.cls_ch_expand = 1280 + else: + raise NotImplementedError( + "mode[{}_model] is not implemented!".format(model_name)) + + self.conv1 = ConvBNLayer( + in_c=3, + out_c=make_divisible(inplanes * scale), + filter_size=3, + stride=2, + padding=1, + num_groups=1, + if_act=True, + act="hard_swish", + name="conv1") + + self.block_list = [] + i = 0 + inplanes = make_divisible(inplanes * scale) + for (k, exp, c, se, nl, s) in self.cfg: + block = self.add_sublayer( + "conv" + str(i + 2), + ResidualUnit( + in_c=inplanes, + mid_c=make_divisible(scale * exp), + out_c=make_divisible(scale * c), + filter_size=k, + stride=s, + use_se=se, + act=nl, + name="conv" + str(i + 2))) + self.block_list.append(block) + inplanes = make_divisible(scale * c) + i += 1 + + self.last_second_conv = ConvBNLayer( + in_c=inplanes, + out_c=make_divisible(scale * self.cls_ch_squeeze), + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + act="hard_swish", + name="conv_last") + + self.pool = AdaptiveAvgPool2D(1) + + self.last_conv = Conv2D( + 
in_channels=make_divisible(scale * self.cls_ch_squeeze), + out_channels=self.cls_ch_expand, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name="last_1x1_conv_weights"), + bias_attr=False) + + self.out = Linear( + self.cls_ch_expand, + class_dim, + weight_attr=ParamAttr("fc_weights"), + bias_attr=ParamAttr(name="fc_offset")) + + def forward(self, inputs, label=None): + x = self.conv1(inputs) + + for block in self.block_list: + x = block(x) + + x = self.last_second_conv(x) + x = self.pool(x) + + x = self.last_conv(x) + x = hard_swish(x) + x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) + x = self.out(x) + + return x + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_c, + out_c, + filter_size, + stride, + padding, + num_groups=1, + if_act=True, + act=None, + use_cudnn=True, + name=""): + super(ConvBNLayer, self).__init__() + self.if_act = if_act + self.act = act + + self.conv = Conv2D( + in_channels=in_c, + out_channels=out_c, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + weight_attr=ParamAttr(name=name + "_weights"), + bias_attr=False) + self.bn = BatchNorm( + num_channels=out_c, + act=None, + param_attr=ParamAttr( + name=name + "_bn_scale", regularizer=L2Decay(0.0)), + bias_attr=ParamAttr( + name=name + "_bn_offset", regularizer=L2Decay(0.0)), + moving_mean_name=name + "_bn_mean", + moving_variance_name=name + "_bn_variance") + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.if_act: + if self.act == "relu": + x = F.relu(x) + elif self.act == "hard_swish": + x = hard_swish(x) + else: + print("The activation function is selected incorrectly.") + exit() + return x + + +class ResidualUnit(nn.Layer): + def __init__(self, + in_c, + mid_c, + out_c, + filter_size, + stride, + use_se, + act=None, + name=''): + super(ResidualUnit, self).__init__() + self.if_shortcut = stride == 1 and in_c == out_c + self.if_se = use_se + + self.expand_conv = ConvBNLayer( + in_c=in_c, + 
out_c=mid_c, + filter_size=1, + stride=1, + padding=0, + if_act=True, + act=act, + name=name + "_expand") + self.bottleneck_conv = ConvBNLayer( + in_c=mid_c, + out_c=mid_c, + filter_size=filter_size, + stride=stride, + padding=int((filter_size - 1) // 2), + num_groups=mid_c, + if_act=True, + act=act, + name=name + "_depthwise") + if self.if_se: + self.mid_se = SEModule(mid_c, name=name + "_se") + self.linear_conv = ConvBNLayer( + in_c=mid_c, + out_c=out_c, + filter_size=1, + stride=1, + padding=0, + if_act=False, + act=None, + name=name + "_linear") + + def forward(self, inputs): + x = self.expand_conv(inputs) + x = self.bottleneck_conv(x) + if self.if_se: + x = self.mid_se(x) + x = self.linear_conv(x) + if self.if_shortcut: + x = paddle.add(inputs, x) + return x + + +class SEModule(nn.Layer): + def __init__(self, channel, reduction=4, name=""): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name=name + "_1_weights"), + bias_attr=ParamAttr(name=name + "_1_offset")) + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(name + "_2_weights"), + bias_attr=ParamAttr(name=name + "_2_offset")) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hard_sigmoid(outputs) + return paddle.multiply(x=inputs, y=outputs) + + +def MobileNetV3_small_x0_35(**args): + model = MobileNetV3(model_name="small", scale=0.35, **args) + return model + + +def MobileNetV3_small_x0_5(**args): + model = MobileNetV3(model_name="small", scale=0.5, **args) + return model + + +def MobileNetV3_small_x0_75(**args): + model = MobileNetV3(model_name="small", scale=0.75, **args) + return model + + +def 
MobileNetV3_small_x1_0(**args): + model = MobileNetV3(model_name="small", scale=1.0, **args) + return model + + +def MobileNetV3_small_x1_25(**args): + model = MobileNetV3(model_name="small", scale=1.25, **args) + return model + + +def MobileNetV3_large_x0_35(**args): + model = MobileNetV3(model_name="large", scale=0.35, **args) + return model + + +def MobileNetV3_large_x0_5(**args): + model = MobileNetV3(model_name="large", scale=0.5, **args) + return model + + +def MobileNetV3_large_x0_75(**args): + model = MobileNetV3(model_name="large", scale=0.75, **args) + return model + + +def MobileNetV3_large_x1_0(**args): + model = MobileNetV3(model_name="large", scale=1.0, **args) + return model + + +def MobileNetV3_large_x1_25(**args): + model = MobileNetV3(model_name="large", scale=1.25, **args) + return model diff --git a/demo/dygraph/quant/optimizer.py b/demo/dygraph/quant/optimizer.py new file mode 100644 index 00000000..95c28d6e --- /dev/null +++ b/demo/dygraph/quant/optimizer.py @@ -0,0 +1,55 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle + + +def piecewise_decay(net, device_num, args): + step = int( + math.ceil(float(args.total_images) / (args.batch_size * device_num))) + bd = [step * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = paddle.optimizer.lr.PiecewiseDecay( + boundaries=bd, values=lr, verbose=False) + optimizer = paddle.optimizer.Momentum( + parameters=net.parameters(), + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay)) + return optimizer, learning_rate + + +def cosine_decay(net, device_num, args): + step = int( + math.ceil(float(args.total_images) / (args.batch_size * device_num))) + learning_rate = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False) + optimizer = paddle.optimizer.Momentum( + parameters=net.parameters(), + learning_rate=learning_rate, + momentum=args.momentum_rate, + weight_decay=paddle.regularizer.L2Decay(args.l2_decay)) + return optimizer, learning_rate + + +def create_optimizer(net, device_num, args): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(net, device_num, args) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(net, device_num, args) diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py new file mode 100644 index 00000000..8ab47a8c --- /dev/null +++ b/demo/dygraph/quant/train.py @@ -0,0 +1,309 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import logging +import paddle +import argparse +import functools +import math +import time +import numpy as np +from paddle.distributed import ParallelEnv +from paddle.static import load_program_state +from paddle.vision.models import mobilenet_v1 +from paddleslim.common import get_logger +from paddleslim.dygraph.quant import QAT + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)))) +from mobilenet_v3 import MobileNetV3_large_x1_0 +from optimizer import create_optimizer +sys.path.append( + os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir, os.path.pardir)) +from utility import add_arguments, print_arguments + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 256, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "mobilenet_v3", "The target model.") +add_arg('pretrained_model', str, "MobileNetV3_large_x1_0_ssld_pretrained", "Whether to use pretrained model.") +add_arg('lr', float, 0.0001, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.") +add_arg('use_pact', bool, False, "Whether to
use PACT method.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('num_epochs', int, 1, "The number of total epochs.") +add_arg('total_images', int, 1281167, "The number of total training images.") +add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'") +add_arg('log_period', int, 10, "Log period in batches.") +add_arg('model_save_dir', str, "./", "model save directory.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step") +# yapf: enable + + +def load_dygraph_pretrain(model, path=None, load_static_weights=False): + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exists.".format(path)) + if load_static_weights: + pre_state_dict = load_program_state(path) + param_state_dict = {} + model_dict = model.state_dict() + for key in model_dict.keys(): + weight_name = model_dict[key].name + if weight_name in pre_state_dict.keys(): + print('Load weight: {}, shape: {}'.format( + weight_name, pre_state_dict[weight_name].shape)) + param_state_dict[key] = pre_state_dict[weight_name] + else: + param_state_dict[key] = model_dict[key] + model.set_dict(param_state_dict) + return + + param_state_dict = paddle.load(path + ".pdparams") + model.set_dict(param_state_dict) + return + + +def compress(args): + if args.data == "mnist": + train_dataset = paddle.vision.datasets.MNIST(mode='train') + val_dataset = paddle.vision.datasets.MNIST(mode='test') + class_dim = 10 + image_shape = "1,28,28" + args.total_images = 60000 + elif args.data == "imagenet": + import imagenet_reader as reader + train_dataset = reader.ImageNetDataset(mode='train') + val_dataset = reader.ImageNetDataset(mode='val') + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + + trainer_num = paddle.distributed.get_world_size() + use_data_parallel = trainer_num != 1 + + place = 
paddle.set_device('gpu' if args.use_gpu else 'cpu') + # model definition + if use_data_parallel: + paddle.distributed.init_parallel_env() + + if args.model == "mobilenet_v1": + pretrain = True if args.data == "imagenet" else False + net = mobilenet_v1(pretrained=pretrain) + elif args.model == "mobilenet_v3": + net = MobileNetV3_large_x1_0() + if args.data == "imagenet": + load_dygraph_pretrain(net, args.pretrained_model, True) + else: + raise ValueError("{} is not supported.".format(args.model)) + _logger.info("Origin model summary:") + paddle.summary(net, (1, 3, 224, 224)) + + ############################################################################################################ + # 1. quantization configs + ############################################################################################################ + quant_config = { + # weight preprocess type, default is None and no preprocessing is performed. + 'weight_preprocess_type': None, + # activation preprocess type, default is None and no preprocessing is performed. + 'activation_preprocess_type': None, + # weight quantize type, default is 'channel_wise_abs_max' + 'weight_quantize_type': 'channel_wise_abs_max', + # activation quantize type, default is 'moving_average_abs_max' + 'activation_quantize_type': 'moving_average_abs_max', + # weight quantize bit num, default is 8 + 'weight_bits': 8, + # activation quantize bit num, default is 8 + 'activation_bits': 8, + # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' + 'dtype': 'int8', + # window size for 'range_abs_max' quantization.
default is 10000 + 'window_size': 10000, + # The decay coefficient of moving average, default is 0.9 + 'moving_rate': 0.9, + # for dygraph quantization, layers of type in quantizable_layer_type will be quantized + 'quantizable_layer_type': ['Conv2D', 'Linear'], + } + + if args.use_pact: + quant_config['activation_preprocess_type'] = 'PACT' + + ############################################################################################################ + # 2. Quantize the model with QAT (quant aware training) + ############################################################################################################ + + quanter = QAT(config=quant_config) + quanter.quantize(net) + + _logger.info("QAT model summary:") + paddle.summary(net, (1, 3, 224, 224)) + + opt, lr = create_optimizer(net, trainer_num, args) + + if use_data_parallel: + net = paddle.DataParallel(net) + + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) + train_loader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_batch_sampler, + places=place, + return_list=True, + num_workers=4) + + valid_loader = paddle.io.DataLoader( + val_dataset, + places=place, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + return_list=True, + num_workers=4) + + @paddle.no_grad() + def test(epoch, net): + net.eval() + batch_id = 0 + acc_top1_ns = [] + acc_top5_ns = [] + for data in valid_loader(): + image = data[0] + label = data[1] + start_time = time.time() + + out = net(image) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}". 
+ format(epoch, batch_id, + np.mean(acc_top1.numpy()), + np.mean(acc_top5.numpy()), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1.numpy())) + acc_top5_ns.append(np.mean(acc_top5.numpy())) + batch_id += 1 + + _logger.info( + "Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format( + epoch, + np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns)))) + return np.mean(np.array(acc_top1_ns)) + + def cross_entropy(input, target, ls_epsilon): + if ls_epsilon > 0: + if target.shape[-1] != class_dim: + target = paddle.nn.functional.one_hot(target, class_dim) + target = paddle.nn.functional.label_smooth( + target, epsilon=ls_epsilon) + target = paddle.reshape(target, shape=[-1, class_dim]) + input = -paddle.nn.functional.log_softmax(input, axis=-1) + cost = paddle.sum(target * input, axis=-1) + else: + cost = paddle.nn.functional.cross_entropy(input=input, label=target) + avg_cost = paddle.mean(cost) + return avg_cost + + def train(epoch, net): + + net.train() + batch_id = 0 + for data in train_loader(): + image = data[0] + label = data[1] + start_time = time.time() + + out = net(image) + avg_cost = cross_entropy(out, label, args.ls_epsilon) + + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + avg_cost.backward() + opt.step() + opt.clear_grad() + lr.step() + + loss_n = np.mean(avg_cost.numpy()) + acc_top1_n = np.mean(acc_top1.numpy()) + acc_top5_n = np.mean(acc_top5.numpy()) + + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}". 
+ format(epoch, batch_id, + lr.get_lr(), loss_n, acc_top1_n, acc_top5_n, end_time + - start_time)) + batch_id += 1 + + ############################################################################################################ + # train loop + ############################################################################################################ + best_acc1 = 0.0 + best_epoch = 0 + for i in range(args.num_epochs): + train(i, net) + acc1 = test(i, net) + if paddle.distributed.get_rank() == 0: + model_prefix = os.path.join(args.model_save_dir, "epoch_" + str(i)) + paddle.save(net.state_dict(), model_prefix + ".pdparams") + paddle.save(opt.state_dict(), model_prefix + ".pdopt") + + if acc1 > best_acc1: + best_acc1 = acc1 + best_epoch = i + if paddle.distributed.get_rank() == 0: + model_prefix = os.path.join(args.model_save_dir, "best_model") + paddle.save(net.state_dict(), model_prefix + ".pdparams") + paddle.save(opt.state_dict(), model_prefix + ".pdopt") + + # load best model + load_dygraph_pretrain(net, os.path.join(args.model_save_dir, "best_model")) + + ############################################################################################################ + # 3. 
Save quant aware model + ############################################################################################################ + path = os.path.join(args.model_save_dir, "inference_model", 'qat_model') + quanter.save_quantized_model( + net, + path, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32') + ]) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/paddleslim/dygraph/quant/quanter.py b/paddleslim/dygraph/quant/quanter.py index c7dfc3d2..1b800bfa 100644 --- a/paddleslim/dygraph/quant/quanter.py +++ b/paddleslim/dygraph/quant/quanter.py @@ -134,7 +134,7 @@ class PACT(paddle.nn.Layer): alpha_attr = paddle.ParamAttr( name=self.full_name() + ".pact", initializer=paddle.nn.initializer.Constant(value=20), - learning_rate=10.0) + learning_rate=1000.0) self.alpha = self.create_parameter( shape=[1], attr=alpha_attr, dtype='float32') -- GitLab