Commit 5590daf9 (parent 1e12c326), authored by Bai Yifan, committed via GitHub

Add dygraph qat demo (#537)
# Quantization-Aware Training for Dygraph Models
This demo shows how to apply quantization-aware training (QAT) to dygraph models, using the widely used MobileNetV1 and MobileNetV3 models as examples.
## QAT Workflow for Classification Models
### Prepare the Data
Create a ``data`` directory under the current directory and extract the ``ImageNet`` dataset into it. After extraction, ``data/ILSVRC2012`` should contain the following (layout sketched below):
- ``train`` directory with the training images
- ``train_list.txt`` file
- ``val`` directory with the validation images
- ``val_list.txt`` file
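That is, the expected layout looks like this:
```
data/ILSVRC2012/
├── train/
├── train_list.txt
├── val/
└── val_list.txt
```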
### Prepare the Model to Be Quantized
- For the [models](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models) supported by paddle.vision (`[lenet, mobilenetv1, mobilenetv2, resnet, vgg]`), you can use the built-in model definitions and ImageNet-pretrained weights directly (see the sketch after this list).
- For models not yet supported by paddle.vision, such as mobilenetv3, you need to define the network yourself and prepare the corresponding pretrained weights.
- This demo uses a distilled mobilenetv3 model that reaches 78.96 Top-1 accuracy on ImageNet: [download the pretrained weights](https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar)
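For the built-in case, a minimal sketch, using the same `mobilenet_v1` entry point that `train.py` imports below:
```python
import paddle
from paddle.vision.models import mobilenet_v1

# Built-in model definition with ImageNet-pretrained weights
net = mobilenet_v1(pretrained=True)
```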
### Configure Quantization Parameters
```python
quant_config = {
'weight_preprocess_type': None,
'activation_preprocess_type': None,
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8,
'activation_bits': 8,
'dtype': 'int8',
'window_size': 10000,
'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
```
- `'weight_preprocess_type'`: the preprocessing applied to model weights before quantization. PACT is currently supported; set this to `'PACT'` to enable it (see the snippet after this list). The default is `None`, meaning no preprocessing.
- `'activation_preprocess_type'`: the preprocessing applied to activations before quantization. PACT is currently supported; set this to `'PACT'` to enable it. The default is `None`, meaning no preprocessing.
- `'weight_quantize_type'`: the quantization scheme for weights; one of `['abs_max', 'moving_average_abs_max', 'channel_wise_abs_max']`. The default is `'channel_wise_abs_max'`.
- `'activation_quantize_type'`: the quantization scheme for activations; one of `['abs_max', 'moving_average_abs_max']`. The default is `'moving_average_abs_max'`.
- `'quantizable_layer_type'`: the layer types to quantize; `Conv2D` and `Linear` are currently supported.
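For example, to enable PACT, flip the activation preprocess type after building the config (this mirrors what `train.py` does when `--use_pact=True` is passed):
```python
# Enable PACT preprocessing of activations (mirrors the --use_pact branch in train.py)
quant_config['activation_preprocess_type'] = 'PACT'
```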
### Insert Quantization Ops to Obtain the QAT Model
```python
quanter = QAT(config=quant_config)
quanter.quantize(net)
```
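`quantize` rewrites `net` in place, inserting fake-quantization ops; the model is then trained with an ordinary dygraph loop, as in `train.py`. A minimal sketch of one training step, assuming `paddle` is imported and that `image`, `label`, and the optimizer `opt` are set up as usual:
```python
out = net(image)  # forward pass through the quant-aware model
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()   # gradients flow through the fake-quantization ops
opt.step()
opt.clear_grad()
```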
### Save the Quantized Model When Training Finishes
```python
quanter.save_quantized_model(
    net,
    'save_dir',
    input_spec=[
        paddle.static.InputSpec(
            shape=[None, 3, 224, 224], dtype='float32')
    ])
```
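`save_quantized_model` exports an inference model under the `'save_dir'` path prefix. As a rough sketch, assuming the exported files can be loaded back with `paddle.jit.load` (the usual way to load a saved inference model):
```python
import paddle

loaded = paddle.jit.load('save_dir')  # load the exported inference model
loaded.eval()
x = paddle.rand([1, 3, 224, 224], dtype='float32')
out = loaded(x)  # forward pass on the quantized inference graph
```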
### Training Commands
- MobileNetV1
Plain quantization-aware training is sufficient here. Launch it as follows:
```bash
# Single-GPU training
python train.py --model='mobilenet_v1'
# Multi-GPU training, using GPUs 0-3 as an example
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model='mobilenet_v1'
```
- MobileNetV3
For MobileNetV3, plain quantization-aware training causes a noticeable accuracy drop. To reduce the quantization loss, use the PACT method (explained briefly after the commands):
```bash
# Single-GPU training
python train.py --lr=0.001 --use_pact=True --num_epochs=30 --l2_decay=2e-5 --ls_epsilon=0.1
# Multi-GPU training, using GPUs 0-3 as an example
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --lr=0.001 --use_pact=True --num_epochs=60 --l2_decay=2e-5 --ls_epsilon=0.1
```
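PACT reduces the quantization loss by learning a clipping threshold `alpha` for the activations: each activation is replaced by `y = clip(x, 0, alpha)` before quantization, and `alpha` is trained jointly with the network weights (see the `PACT` layer diff at the end of this commit, which raises the learning rate of `alpha`).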
### Quantization Results
| Model | FP32 Accuracy (Top-1/Top-5) | Quantization Method | Quantized Accuracy (Top-1/Top-5) |
| ----------- | --------------------------- | ------------------- | -------------------------------- |
| MobileNetV1 | 70.99/89.65 | Plain QAT | 70.63/89.65 |
| MobileNetV3 | 78.96/94.48 | PACT QAT | 77.52/93.77 |
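# ==== mobilenet_v3.py: MobileNetV3 model definition (imported by train.py) ====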
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.functional.activation import hard_sigmoid, hard_swish
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.regularizer import L2Decay
import math
__all__ = [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
"MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
"MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
"MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
]
def make_divisible(v, divisor=8, min_value=None):
    # Round `v` to the nearest multiple of `divisor`, never going below
    # `min_value` and never dropping more than 10% below the original value,
    # e.g. make_divisible(33.6) == 32.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
class MobileNetV3(nn.Layer):
def __init__(self,
scale=1.0,
model_name="small",
dropout_prob=0.2,
class_dim=1000):
super(MobileNetV3, self).__init__()
inplanes = 16
if model_name == "large":
self.cfg = [
                # kernel size, expand channels, out channels, use SE, nonlinearity, stride
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1],
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
                # kernel size, expand channels, out channels, use SE, nonlinearity, stride
[3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1],
[5, 288, 96, True, "hard_swish", 2],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
        else:
            raise NotImplementedError(
                "model_name [{}] is not implemented!".format(model_name))
self.conv1 = ConvBNLayer(
in_c=3,
out_c=make_divisible(inplanes * scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv1")
self.block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in self.cfg:
block = self.add_sublayer(
"conv" + str(i + 2),
ResidualUnit(
in_c=inplanes,
mid_c=make_divisible(scale * exp),
out_c=make_divisible(scale * c),
filter_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
self.block_list.append(block)
inplanes = make_divisible(scale * c)
i += 1
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
out_c=make_divisible(scale * self.cls_ch_squeeze),
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv_last")
self.pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D(
in_channels=make_divisible(scale * self.cls_ch_squeeze),
out_channels=self.cls_ch_expand,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name="last_1x1_conv_weights"),
bias_attr=False)
self.out = Linear(
self.cls_ch_expand,
class_dim,
weight_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, inputs, label=None):
x = self.conv1(inputs)
for block in self.block_list:
x = block(x)
x = self.last_second_conv(x)
x = self.pool(x)
x = self.last_conv(x)
x = hard_swish(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.out(x)
return x
class ConvBNLayer(nn.Layer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
use_cudnn=True,
name=""):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = Conv2D(
in_channels=in_c,
out_channels=out_c,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(
name=name + "_bn_scale", regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(
name=name + "_bn_offset", regularizer=L2Decay(0.0)),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = F.relu(x)
elif self.act == "hard_swish":
x = hard_swish(x)
            else:
                raise ValueError(
                    "Unsupported activation: {}".format(self.act))
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding=int((filter_size - 1) // 2),
num_groups=mid_c,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_c, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
            weight_attr=ParamAttr(name=name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hard_sigmoid(outputs)
return paddle.multiply(x=inputs, y=outputs)
def MobileNetV3_small_x0_35(**args):
model = MobileNetV3(model_name="small", scale=0.35, **args)
return model
def MobileNetV3_small_x0_5(**args):
model = MobileNetV3(model_name="small", scale=0.5, **args)
return model
def MobileNetV3_small_x0_75(**args):
model = MobileNetV3(model_name="small", scale=0.75, **args)
return model
def MobileNetV3_small_x1_0(**args):
model = MobileNetV3(model_name="small", scale=1.0, **args)
return model
def MobileNetV3_small_x1_25(**args):
model = MobileNetV3(model_name="small", scale=1.25, **args)
return model
def MobileNetV3_large_x0_35(**args):
model = MobileNetV3(model_name="large", scale=0.35, **args)
return model
def MobileNetV3_large_x0_5(**args):
model = MobileNetV3(model_name="large", scale=0.5, **args)
return model
def MobileNetV3_large_x0_75(**args):
model = MobileNetV3(model_name="large", scale=0.75, **args)
return model
def MobileNetV3_large_x1_0(**args):
model = MobileNetV3(model_name="large", scale=1.0, **args)
return model
def MobileNetV3_large_x1_25(**args):
model = MobileNetV3(model_name="large", scale=1.25, **args)
return model
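# ==== optimizer.py: learning-rate schedules and optimizer construction ====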
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
def piecewise_decay(net, device_num, args):
step = int(
math.ceil(float(args.total_images) / (args.batch_size * device_num)))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(
boundaries=bd, values=lr, verbose=False)
optimizer = paddle.optimizer.Momentum(
parameters=net.parameters(),
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def cosine_decay(net, device_num, args):
step = int(
math.ceil(float(args.total_images) / (args.batch_size * device_num)))
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
optimizer = paddle.optimizer.Momentum(
parameters=net.parameters(),
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def create_optimizer(net, device_num, args):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(net, device_num, args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(net, device_num, args)
    else:
        raise ValueError(
            "Unsupported lr_strategy: {}".format(args.lr_strategy))
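# ==== train.py: quantization-aware training entry point ====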
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
from paddle.distributed import ParallelEnv
from paddle.static import load_program_state
from paddle.vision.models import mobilenet_v1
from paddleslim.common import get_logger
from paddleslim.dygraph.quant import QAT
sys.path.append(os.path.dirname(__file__))
from mobilenet_v3 import MobileNetV3_large_x1_0
from optimizer import create_optimizer
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "mobilenet_v3", "The target model.")
add_arg('pretrained_model', str, "MobileNetV3_large_x1_0_ssld_pretrained", "The pretrained model path (prefix).")
add_arg('lr', float, 0.0001, "The learning rate used to fine-tune the model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
add_arg('use_pact', bool, False, "Whether to use PACT method.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('model_save_dir', str, "./", "Model save directory.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step")
# yapf: enable
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
if load_static_weights:
pre_state_dict = load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys():
print('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
return
param_state_dict = paddle.load(path + ".pdparams")
model.set_dict(param_state_dict)
return
def compress(args):
if args.data == "mnist":
train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
args.total_images = 60000
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
# model definition
if use_data_parallel:
paddle.distributed.init_parallel_env()
if args.model == "mobilenet_v1":
        pretrained = args.data == "imagenet"
        net = mobilenet_v1(pretrained=pretrained)
elif args.model == "mobilenet_v3":
net = MobileNetV3_large_x1_0()
if args.data == "imagenet":
load_dygraph_pretrain(net, args.pretrained_model, True)
else:
raise ValueError("{} is not supported.".format(args.model))
_logger.info("Origin model summary:")
paddle.summary(net, (1, 3, 224, 224))
############################################################################################################
# 1. quantization configs
############################################################################################################
quant_config = {
# weight preprocess type, default is None and no preprocessing is performed.
'weight_preprocess_type': None,
# activation preprocess type, default is None and no preprocessing is performed.
'activation_preprocess_type': None,
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
if args.use_pact:
quant_config['activation_preprocess_type'] = 'PACT'
############################################################################################################
# 2. Quantize the model with QAT (quant aware training)
############################################################################################################
quanter = QAT(config=quant_config)
quanter.quantize(net)
_logger.info("QAT model summary:")
paddle.summary(net, (1, 3, 224, 224))
opt, lr = create_optimizer(net, trainer_num, args)
if use_data_parallel:
net = paddle.DataParallel(net)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_batch_sampler,
places=place,
return_list=True,
num_workers=4)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
return_list=True,
num_workers=4)
@paddle.no_grad()
def test(epoch, net):
net.eval()
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
for data in valid_loader():
image = data[0]
label = data[1]
start_time = time.time()
out = net(image)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
format(epoch, batch_id,
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1.numpy()))
acc_top5_ns.append(np.mean(acc_top5.numpy()))
batch_id += 1
_logger.info(
"Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def cross_entropy(input, target, ls_epsilon):
if ls_epsilon > 0:
if target.shape[-1] != class_dim:
target = paddle.nn.functional.one_hot(target, class_dim)
target = paddle.nn.functional.label_smooth(
target, epsilon=ls_epsilon)
target = paddle.reshape(target, shape=[-1, class_dim])
input = -paddle.nn.functional.log_softmax(input, axis=-1)
cost = paddle.sum(target * input, axis=-1)
else:
cost = paddle.nn.functional.cross_entropy(input=input, label=target)
avg_cost = paddle.mean(cost)
return avg_cost
def train(epoch, net):
net.train()
batch_id = 0
for data in train_loader():
image = data[0]
label = data[1]
start_time = time.time()
out = net(image)
avg_cost = cross_entropy(out, label, args.ls_epsilon)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_cost.backward()
opt.step()
opt.clear_grad()
lr.step()
loss_n = np.mean(avg_cost.numpy())
acc_top1_n = np.mean(acc_top1.numpy())
acc_top5_n = np.mean(acc_top5.numpy())
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
format(epoch, batch_id,
lr.get_lr(), loss_n, acc_top1_n, acc_top5_n, end_time
- start_time))
batch_id += 1
############################################################################################################
# train loop
############################################################################################################
best_acc1 = 0.0
best_epoch = 0
for i in range(args.num_epochs):
train(i, net)
acc1 = test(i, net)
if paddle.distributed.get_rank() == 0:
model_prefix = os.path.join(args.model_save_dir, "epoch_" + str(i))
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(opt.state_dict(), model_prefix + ".pdopt")
if acc1 > best_acc1:
best_acc1 = acc1
best_epoch = i
if paddle.distributed.get_rank() == 0:
model_prefix = os.path.join(args.model_save_dir, "best_model")
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(opt.state_dict(), model_prefix + ".pdopt")
# load best model
load_dygraph_pretrain(net, os.path.join(args.model_save_dir, "best_model"))
############################################################################################################
# 3. Save quant aware model
############################################################################################################
path = os.path.join(args.model_save_dir, "inference_model", 'qat_model')
quanter.save_quantized_model(
net,
path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
])
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
The commit also updates the `PACT` layer, raising the learning rate of the learnable clipping parameter `alpha` from 10.0 to 1000.0:

```diff
@@ -134,7 +134,7 @@ class PACT(paddle.nn.Layer):
         alpha_attr = paddle.ParamAttr(
             name=self.full_name() + ".pact",
             initializer=paddle.nn.initializer.Constant(value=20),
-            learning_rate=10.0)
+            learning_rate=1000.0)
         self.alpha = self.create_parameter(
             shape=[1], attr=alpha_attr, dtype='float32')
```