You need to sign in or sign up before continuing.
在量化训练中,训练正常,评估输出结果不正确,保存的模型也不正确
Created by: yeyupiaoling
- PaddlePaddle 1.8.4
def main(args):
    """Train an image-classification model and, when ``--use_quant`` is set,
    finish with 5 epochs of quantization-aware training (QAT), then export
    the quantized inference model.

    Args:
        args: parsed CLI namespace; uses ``config``, ``override``,
            ``use_quant``, ``vdl_dir`` and ``output_path``.

    Side effects:
        Writes checkpoints under ``config.model_save_dir``, VisualDL logs
        under ``args.vdl_dir`` and (for QAT) an inference model under
        ``args.output_path``.  Exits the process with status 1 on an
        invalid configuration.
    """
    config = get_config(args.config, overrides=args.override, show=True)
    # QAT reuses the validation program, so validation must be enabled.
    if not config.validate and args.use_quant:
        logger.error("=====>Train quant model must use validate!")
        sys.exit(1)
    # The last 5 epochs are reserved for QAT (see the loops below), so a
    # quantized run needs at least 6 epochs in total.
    if config.epochs < 6 and args.use_quant:
        logger.error("=====>Train quant model epochs must be at least 6!")
        sys.exit(1)

    # Device selection: all visible GPUs, otherwise the CPU places.
    use_gpu = config.get("use_gpu", True)
    places = fluid.cuda_places() if use_gpu else fluid.cpu_places()

    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    best_top1_acc = 0.0  # best top-1 accuracy seen so far on the valid set

    # Build the training network; with EMA enabled program.build() also
    # returns the ExponentialMovingAverage handle.
    if not config.get('use_ema'):
        train_dataloader, train_fetchs, out = program.build(
            config, train_prog, startup_prog, is_train=True, is_distributed=False)
    else:
        train_dataloader, train_fetchs, ema, out = program.build(
            config, train_prog, startup_prog, is_train=True, is_distributed=False)

    # Build the evaluation network and strip training-only ops from it.
    if config.validate:
        valid_prog = fluid.Program()
        valid_dataloader, valid_fetchs, _ = program.build(
            config, valid_prog, startup_prog, is_train=False, is_distributed=False)
        valid_prog = valid_prog.clone(for_test=True)

    exe = fluid.Executor(places[0])
    exe.run(startup_prog)
    # Load pretrained weights, or resume from a checkpoint.
    init_model(config, train_prog, exe)

    train_reader = Reader(config, 'train')()
    train_dataloader.set_sample_list_generator(train_reader, places)
    compiled_train_prog = program.compile(config, train_prog,
                                          train_fetchs['loss'][0].name)

    if config.validate:
        valid_reader = Reader(config, 'valid')()
        valid_dataloader.set_sample_list_generator(valid_reader, places)
        # Share memory/variables with the compiled training program.
        compiled_valid_prog = program.compile(config, valid_prog,
                                              share_prog=compiled_train_prog)

    vdl_writer = LogWriter(args.vdl_dir)

    # Regular (float) training; the final 5 epochs are left for QAT.
    for epoch_id in range(config.epochs - 5):
        program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
                    epoch_id, 'train', config, vdl_writer)
        # Periodic evaluation.
        if config.validate and epoch_id % config.valid_interval == 0:
            if config.get('use_ema'):
                # Extra evaluation with EMA weights swapped in (logged only).
                logger.info(logger.coloring("EMA validate start..."))
                with ema.apply(exe):
                    _ = program.run(valid_dataloader, exe, compiled_valid_prog,
                                    valid_fetchs, epoch_id, 'valid', config)
                logger.info(logger.coloring("EMA validate over!"))
            top1_acc = program.run(valid_dataloader, exe, compiled_valid_prog,
                                   valid_fetchs, epoch_id, 'valid', config)
            # Keep a dedicated "best_model" checkpoint for the best top-1.
            if top1_acc > best_top1_acc:
                best_top1_acc = top1_acc
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, epoch_id)
                logger.info("{:s}".format(logger.coloring(message, "RED")))
                if epoch_id % config.save_interval == 0:
                    model_path = os.path.join(config.model_save_dir,
                                              config.ARCHITECTURE["name"])
                    save_model(train_prog, model_path, "best_model")
        # Periodic checkpoint; prune the checkpoint from 3 epochs ago so
        # only a small sliding window of epoch directories is kept on disk.
        if epoch_id % config.save_interval == 0:
            model_path = os.path.join(config.model_save_dir,
                                      config.ARCHITECTURE["name"])
            if epoch_id >= 3 and os.path.exists(os.path.join(model_path, str(epoch_id - 3))):
                shutil.rmtree(os.path.join(model_path, str(epoch_id - 3)),
                              ignore_errors=True)
            save_model(train_prog, model_path, epoch_id)

    # Quantization-aware training for the remaining 5 epochs.
    if args.use_quant and config.validate:

        def _run_quant_pass(prog, dataloader, fetchs, tag):
            # One full pass of `prog` over `dataloader`, updating (and every
            # 10 steps logging) the metric objects carried in `fetchs`.
            fetch_list = [f[0] for f in fetchs.values()]
            metric_list = [f[1] for f in fetchs.values()]
            for idx, batch in enumerate(dataloader()):
                metrics = exe.run(program=prog, feed=batch, fetch_list=fetch_list)
                for metric, value in zip(metric_list, metrics):
                    metric.update(np.mean(value), len(batch[0]))
                if idx % 10 == 0:
                    fetchs_str = ''.join(str(m.value) + ' ' for m in metric_list)
                    logger.info(tag + fetchs_str)

        # Insert fake-quant/dequant ops into the training program and
        # fine-tune for 5 epochs.
        quant_program = slim.quant.quant_aware(train_prog, exe.place, for_test=False)
        for _ in range(5):
            _run_quant_pass(quant_program, train_dataloader, train_fetchs,
                            "quant train : ")

        # Evaluate the quantized network.
        # NOTE(review): the PaddleSlim QAT examples derive the test program
        # with quant_aware(for_test=True) BEFORE the quant training loop;
        # building it only here may be why the quantized eval accuracy is so
        # low — verify against the PaddleSlim documentation.
        val_quant_program = slim.quant.quant_aware(valid_prog, exe.place, for_test=True)
        _run_quant_pass(val_quant_program, valid_dataloader, valid_fetchs,
                        "quant valid: ")

        # Freeze the quantized weights and export the inference model
        # (float_prog keeps float weights with quant info; int8_prog is int8).
        float_prog, int8_prog = slim.quant.convert(val_quant_program, exe.place,
                                                   save_int8=True)
        fluid.io.save_inference_model(dirname=args.output_path,
                                      feeded_var_names=['feed_image'],
                                      target_vars=out,
                                      executor=exe,
                                      main_program=float_prog,
                                      model_filename='__model__',
                                      params_filename='__params__')
输出结果如下,可以看出来量化评估的准确率非常低,相当于没有训练
2020-09-17 02:50:44,995-INFO: epoch:394 train step:190 loss: 1.2416 lr: 0.000028 elapse: 0.275s
2020-09-17 02:50:47,746-INFO: epoch:394 train step:200 loss: 1.0109 lr: 0.000028 elapse: 0.274s
2020-09-17 02:50:50,503-INFO: epoch:394 train step:210 loss: 0.9340 lr: 0.000028 elapse: 0.277s
2020-09-17 02:50:53,257-INFO: epoch:394 train step:220 loss: 0.9204 lr: 0.000028 elapse: 0.275s
2020-09-17 02:50:56,010-INFO: epoch:394 train step:230 loss: 2.4909 lr: 0.000028 elapse: 0.274s
2020-09-17 02:50:58,765-INFO: epoch:394 train step:240 loss: 0.8378 lr: 0.000028 elapse: 0.275s
2020-09-17 02:51:01,520-INFO: epoch:394 train step:250 loss: 0.8325 lr: 0.000028 elapse: 0.275s
2020-09-17 02:51:04,273-INFO: epoch:394 train step:260 loss: 0.8397 lr: 0.000028 elapse: 0.274s
2020-09-17 02:51:07,029-INFO: epoch:394 train step:270 loss: 0.8546 lr: 0.000028 elapse: 0.275s
2020-09-17 02:51:07,579-INFO: END epoch:394 train loss_avg: 1.2111 elapse_sum: 75.699s
2020-09-17 02:51:07,936-INFO: epoch:394 valid step:0 loss: 1.4642 top1: 0.7031 top5: 0.9062 elapse: 0.357s
2020-09-17 02:51:09,953-INFO: epoch:394 valid step:10 loss: 1.4104 top1: 0.7344 top5: 0.9062 elapse: 0.082s
2020-09-17 02:51:10,806-INFO: epoch:394 valid step:20 loss: 1.3292 top1: 0.7344 top5: 0.9531 elapse: 0.095s
2020-09-17 02:51:12,113-INFO: END epoch:394 valid loss_avg: 1.4435 top1_avg: 0.7120 top5_avg: 0.9125 elapse_sum: 4.530s
2020-09-17 02:51:13,980-INFO: Already save model in ./output/ResNet50_vd/394
2020-09-17 02:51:13,980-INFO: quant_aware config {'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', 'weight_bits': 8, 'activation_bits': 8, 'not_quant_pattern': ['skip_quant'], 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, 'for_tensorrt': False, 'is_full_quantize': False}
2020-09-17 02:51:17,966-INFO: quant train : loss: 0.8532 lr: 0.000019
2020-09-17 02:51:21,137-INFO: quant train : loss: 1.0666 lr: 0.000019
2020-09-17 02:51:24,206-INFO: quant train : loss: 0.9276 lr: 0.000019
2020-09-17 02:57:51,909-INFO: quant train : loss: 2.4494 lr: 0.000001
.......................................
2020-09-17 02:58:16,551-INFO: quant train : loss: 1.6833 lr: 0.000001
2020-09-17 02:58:19,624-INFO: quant train : loss: 2.1616 lr: 0.000001
2020-09-17 02:58:22,701-INFO: quant train : loss: 0.8692 lr: 0.000001
2020-09-17 02:58:23,315-INFO: quant_aware config {'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', 'weight_bits': 8, 'activation_bits': 8, 'not_quant_pattern': ['skip_quant'], 'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, 'for_tensorrt': False, 'is_full_quantize': False}
2020-09-17 02:58:24,427-INFO: quant valid: loss: 3.0968 top1: 0.0625 top5: 0.3125
2020-09-17 02:58:25,943-INFO: quant valid: loss: 3.0719 top1: 0.0625 top5: 0.3750
2020-09-17 02:58:27,448-INFO: quant valid: loss: 3.1283 top1: 0.0625 top5: 0.3750