Commit 56806d22 authored by chenfei

Standardization of the MobileNetV2 and ResNet50 quant networks

Parent c7f461ac
@@ -43,7 +43,6 @@ run_ascend()
         --training_script=${BASEPATH}/../train.py \
         --dataset_path=$5 \
         --pre_trained=$6 \
-        --quantization_aware=True \
         --device_target=$1 &> train.log &  # dataset train folder
 }
@@ -75,8 +74,7 @@ run_gpu()
     python ${BASEPATH}/../train.py \
         --dataset_path=$4 \
         --device_target=$1 \
-        --pre_trained=$5 \
-        --quantization_aware=True &> ../train.log &  # dataset train folder
+        --pre_trained=$5 &> ../train.log &  # dataset train folder
 }
 if [ $# -gt 6 ] || [ $# -lt 5 ]
...
@@ -16,34 +16,12 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed
-config_ascend = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 256,
-    "data_load_mode": "mindrecord",
-    "epoch_size": 200,
-    "start_epoch": 0,
-    "warmup_epochs": 4,
-    "lr": 0.4,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-    "quantization_aware": False,
-})
 config_ascend_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
     "image_width": 224,
     "batch_size": 192,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "epoch_size": 60,
     "start_epoch": 200,
     "warmup_epochs": 1,
@@ -59,24 +37,6 @@ config_ascend_quant = ed({
     "quantization_aware": True,
 })
-config_gpu = ed({
-    "num_classes": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "batch_size": 150,
-    "epoch_size": 200,
-    "warmup_epochs": 4,
-    "lr": 0.8,
-    "momentum": 0.9,
-    "weight_decay": 4e-5,
-    "label_smooth": 0.1,
-    "loss_scale": 1024,
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 300,
-    "save_checkpoint_path": "./checkpoint",
-})
 config_gpu_quant = ed({
     "num_classes": 1000,
     "image_height": 224,
...
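The remaining configs are plain EasyDict objects, so train.py reads their fields via attribute access and a user can override a field before launching training. A minimal sketch of that access pattern, trimmed to a few fields, with a purely illustrative batch-size override that is not part of the commit:

from easydict import EasyDict as ed

# same shape as config_ascend_quant above, trimmed for illustration
cfg = ed({
    "num_classes": 1000,
    "batch_size": 192,
    "data_load_mode": "mindata",
    "quantization_aware": True,
})

print(cfg.batch_size)        # 192 -- EasyDict exposes keys as attributes
cfg.batch_size = 96          # hypothetical override, e.g. for a smaller device
print(cfg["batch_size"])     # 96 -- dict-style access stays in sync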
@@ -26,6 +26,38 @@ from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
+def _load_param_into_net(model, params_dict):
+    """
+    Load fp32 model parameters into a quantization-aware model.
+
+    Args:
+        model: quantization-aware model
+        params_dict: fp32 parameter dict
+
+    Returns:
+        None
+    """
+    iterable_dict = {
+        'weight': iter([item for item in params_dict.items() if item[0].endswith('weight')]),
+        'bias': iter([item for item in params_dict.items() if item[0].endswith('bias')]),
+        'gamma': iter([item for item in params_dict.items() if item[0].endswith('gamma')]),
+        'beta': iter([item for item in params_dict.items() if item[0].endswith('beta')]),
+        'moving_mean': iter([item for item in params_dict.items() if item[0].endswith('moving_mean')]),
+        'moving_variance': iter(
+            [item for item in params_dict.items() if item[0].endswith('moving_variance')]),
+        'minq': iter([item for item in params_dict.items() if item[0].endswith('minq')]),
+        'maxq': iter([item for item in params_dict.items() if item[0].endswith('maxq')])
+    }
+    for name, param in model.parameters_and_names():
+        key_name = name.split(".")[-1]
+        if key_name not in iterable_dict.keys():
+            raise ValueError(f"Can't find a matching parameter in ckpt, param name = {name}")
+        value_param = next(iterable_dict[key_name], None)
+        if value_param is not None:
+            param.set_parameter_data(value_param[1].data)
+            print(f'init model param {name} with checkpoint param {value_param[0]}')
 class Monitor(Callback):
     """
     Monitor loss and time.
...
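The _load_param_into_net helper is needed because the quant model's full dotted parameter names generally do not match the names in the fp32 checkpoint; it therefore buckets checkpoint entries by their name suffix (weight, bias, gamma, beta, moving_mean, moving_variance, minq, maxq) and hands them out in order to the model parameters with the same suffix. A framework-free sketch of that matching idea, using made-up parameter names purely for illustration:

# hypothetical fp32 checkpoint: name -> value
params_dict = {
    "features.0.conv.weight": 1.0,
    "features.0.bn.gamma": 2.0,
    "features.0.bn.beta": 3.0,
    "features.1.conv.weight": 4.0,
}

# group checkpoint entries by suffix; each group is an iterator that preserves order
suffixes = ("weight", "bias", "gamma", "beta")
iterable_dict = {
    s: iter([kv for kv in params_dict.items() if kv[0].endswith(s)]) for s in suffixes
}

# hypothetical quant-model parameter names (different prefixes, same suffix order)
quant_param_names = [
    "features.0.conv_quant.weight",
    "features.0.bn_quant.gamma",
    "features.0.bn_quant.beta",
    "features.1.conv_quant.weight",
]

for name in quant_param_names:
    suffix = name.split(".")[-1]
    src = next(iterable_dict[suffix], None)   # next unclaimed checkpoint entry with this suffix
    if src is not None:
        print(f"init {name} from checkpoint param {src[0]} (value {src[1]})")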
@@ -25,7 +25,7 @@ from mindspore import nn
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.serialization import load_checkpoint
 from mindspore.communication.management import init, get_group_size, get_rank
 from mindspore.train.quant import quant
 import mindspore.dataset.engine as de
@@ -33,8 +33,9 @@ import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.utils import Monitor, CrossEntropyWithLabelSmooth
-from src.config import config_ascend_quant, config_ascend, config_gpu_quant, config_gpu
+from src.config import config_ascend_quant, config_gpu_quant
 from src.mobilenetV2 import mobilenetV2
+from src.utils import _load_param_into_net
 random.seed(1)
 np.random.seed(1)
@@ -44,7 +45,6 @@ parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
 parser.add_argument('--device_target', type=str, default=None, help='Run device target')
-parser.add_argument('--quantization_aware', type=bool, default=False, help='Use quantization aware training')
 args_opt = parser.parse_args()
 if args_opt.device_target == "Ascend":
@@ -69,7 +69,7 @@ else:
 def train_on_ascend():
-    config = config_ascend_quant if args_opt.quantization_aware else config_ascend
+    config = config_ascend_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
     print("parallel args: rank_id {}, device_id {}, rank_size {}".format(rank_id, device_id, rank_size))
@@ -101,10 +101,8 @@ def train_on_ascend():
     # load pre trained ckpt
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
     network = quant.convert_quant_network(network,
                                           bn_fold=True,
                                           per_channel=[True, False],
@@ -141,7 +139,7 @@ def train_on_ascend():
 def train_on_gpu():
-    config = config_gpu_quant if args_opt.quantization_aware else config_gpu
+    config = config_gpu_quant
     print("training args: {}".format(args_opt))
     print("training configure: {}".format(config))
@@ -165,14 +163,15 @@ def train_on_gpu():
     # resume
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
-        load_param_into_net(network, param_dict)
+        _load_param_into_net(network, param_dict)
     # convert fusion network to quantization aware network
-    if config.quantization_aware:
-    network = quant.convert_quant_network(network,
-                                          bn_fold=True,
-                                          per_channel=[True, False],
-                                          symmetric=[True, True])
+    network = quant.convert_quant_network(network,
+                                          bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, True],
+                                          freeze_bn=1000000,
+                                          quant_delay=step_size * 2)
     # get learning rate
     loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
...
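In the GPU path the conversion now also passes freeze_bn and quant_delay, and quant_delay is counted in optimizer steps, so step_size * 2 keeps the fake-quant ops pass-through for roughly the first two epochs after the fp32 checkpoint is loaded. A small sketch of that arithmetic with assumed numbers (the dataset size and batch size here are illustrative, not taken from the commit):

# illustrative values only
dataset_size = 1_281_167                  # e.g. ImageNet-1k training images (assumption)
batch_size = 128                          # hypothetical per-device batch size
step_size = dataset_size // batch_size    # steps per epoch for a single-device run

quant_delay = step_size * 2               # fake quantization starts after ~2 epochs of fine-tuning
freeze_bn = 1000000                       # BN-fold freeze point, effectively beyond a short fine-tune run
print(step_size, quant_delay)             # 10009 20018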
@@ -16,33 +16,6 @@
 network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as ed
-quant_set = ed({
-    "quantization_aware": True,
-})
-config_noquant = ed({
-    "class_num": 1001,
-    "batch_size": 32,
-    "loss_scale": 1024,
-    "momentum": 0.9,
-    "weight_decay": 1e-4,
-    "epoch_size": 90,
-    "pretrained_epoch_size": 1,
-    "buffer_size": 1000,
-    "image_height": 224,
-    "image_width": 224,
-    "data_load_mode": "mindrecord",
-    "save_checkpoint": True,
-    "save_checkpoint_epochs": 1,
-    "keep_checkpoint_max": 50,
-    "save_checkpoint_path": "./",
-    "warmup_epochs": 0,
-    "lr_decay_mode": "cosine",
-    "use_label_smooth": True,
-    "label_smooth_factor": 0.1,
-    "lr_init": 0,
-    "lr_max": 0.1,
-})
 config_quant = ed({
     "class_num": 1001,
     "batch_size": 32,
@@ -54,7 +27,7 @@ config_quant = ed({
     "buffer_size": 1000,
     "image_height": 224,
     "image_width": 224,
-    "data_load_mode": "mindrecord",
+    "data_load_mode": "mindata",
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 50,
...
@@ -33,7 +33,7 @@ import mindspore.common.initializer as weight_init
 from models.resnet_quant import resnet50_quant
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
-from src.config import quant_set, config_quant, config_noquant
+from src.config import config_quant
 from src.crossentropy import CrossEntropy
 from src.utils import _load_param_into_net
@@ -44,7 +44,7 @@ parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
 parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
 parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
 args_opt = parser.parse_args()
-config = config_quant if quant_set.quantization_aware else config_noquant
+config = config_quant
 if args_opt.device_target == "Ascend":
     device_id = int(os.getenv('DEVICE_ID'))
@@ -110,7 +110,6 @@ if __name__ == '__main__':
                              target=args_opt.device_target)
     step_size = dataset.get_dataset_size()
-    if quant_set.quantization_aware:
     # convert fusion network to quantization aware network
     net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
@@ -131,11 +130,7 @@ if __name__ == '__main__':
                  config.weight_decay, config.loss_scale)
     # define model
-    if quant_set.quantization_aware:
     model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})
-    else:
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
-                      amp_level="O2")
     print("============== Starting Training ==============")
     time_callback = TimeMonitor(data_size=step_size)
...