diff --git a/demo/distillation/distillation_demo.py b/demo/distillation/distillation_demo.py
index 79142026359b525e3cf3754d9dd52808081b3a57..3f47553e541ff86ae0a6f4d86c046a1dee66a03f 100644
--- a/demo/distillation/distillation_demo.py
+++ b/demo/distillation/distillation_demo.py
@@ -13,9 +13,8 @@ import numpy as np
 import paddle.fluid as fluid
 sys.path.append(sys.path[0] + "/../")
 import models
-import imagenet_reader as reader
 from utility import add_arguments, print_arguments, _download, _decompress
-from single_distiller import merge, l2_loss, soft_label_loss, fsp_loss
+from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
 
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
@@ -33,7 +32,7 @@ add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay
 add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 120, "The number of total epochs.")
-add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
+add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'")
 add_arg('log_period', int, 20, "Log period in batches.")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
@@ -76,7 +75,7 @@ def create_optimizer(args):
 
 
 def compress(args):
-    if args.data == "mnist":
+    if args.data == "cifar10":
         import paddle.dataset.cifar as reader
         train_reader = reader.train10()
         val_reader = reader.test10()
@@ -146,9 +145,9 @@
                 name='image', shape=image_shape, dtype='float32')
             predict = teacher_model.net(image, class_dim=class_dim)
 
-            #print("="*50+"teacher_model_params"+"="*50)
-            #for v in teacher_program.list_vars():
-            #    print(v.name, v.shape)
+    #print("="*50+"teacher_model_params"+"="*50)
+    #for v in teacher_program.list_vars():
+    #    print(v.name, v.shape)
 
     exe.run(t_startup)
     _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
@@ -176,8 +175,8 @@
           place)
 
     with fluid.program_guard(main, s_startup):
-        l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
-        loss = avg_cost + l2_loss_v
+        l2_loss = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
+        loss = avg_cost + l2_loss
         opt = create_optimizer(args)
         opt.minimize(loss)
     exe.run(s_startup)
@@ -192,7 +191,7 @@
                 parallel_main,
                 feed=data,
                 fetch_list=[
-                    loss.name, avg_cost.name, l2_loss_v.name
+                    loss.name, avg_cost.name, l2_loss.name
                 ])
             if step_id % args.log_period == 0:
                 _logger.info(
diff --git a/doc/demo_guide.md b/doc/demo_guide.md
index 9722dcc8de94a265491dced85e11e305c00073bc..ca7514c51d006c23d36764424c457f95fa6ca912 100644
--- a/doc/demo_guide.md
+++ b/doc/demo_guide.md
@@ -1,3 +1,9 @@
 ## [Distillation](../demo/distillation/distillation_demo.py)
+The distillation demo uses ResNet50 as the teacher network and MobileNet as the student network by default. The teacher and student can also be replaced with any model supported in the [models directory](../demo/models).
+
+The demo applies an l2_loss distillation loss to one feature map of the teacher and student models; fsp_loss, soft_label_loss, or a custom loss function can be chosen instead as needed.
+
+By default, training distills for 120 epochs on the cifar10 dataset with the piecewise_decay learning-rate schedule and the momentum optimizer. Users can easily switch to other training configurations via the args, such as the ImageNet dataset or the cosine_decay learning-rate schedule.
+
 
 
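For readers of this patch, a minimal sketch of how the `paddleslim.dist` imports above fit together is shown below. This is an illustration under assumptions, not the demo itself: it assumes PaddleSlim 1.x with paddle.fluid, the tiny fc networks stand in for ResNet50/MobileNet, and the `merge()` / `l2_loss()` call shapes are inferred from the calls visible in this diff.

```python
# Sketch of the merge-then-loss distillation flow (assumes PaddleSlim 1.x
# with paddle.fluid; tiny fc networks stand in for ResNet50/MobileNet).
import paddle.fluid as fluid
from paddleslim.dist import merge, l2_loss

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Student graph: fluid names the fc outputs fc_0.tmp_0 (mul result),
# fc_0.tmp_1 (bias add), and so on, matching the names used in the demo.
student_program = fluid.Program()
s_startup = fluid.Program()
with fluid.program_guard(student_program, s_startup):
    image = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    out = fluid.layers.fc(input=image, size=10, act='softmax')
    avg_cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=out, label=label))

# Teacher graph in its own program; unique_name.guard() restarts the layer
# name counter so the teacher's fc layer is also named fc_0.
teacher_program = fluid.Program()
t_startup = fluid.Program()
with fluid.program_guard(teacher_program, t_startup):
    with fluid.unique_name.guard():
        t_image = fluid.layers.data(
            name='image', shape=[1, 28, 28], dtype='float32')
        t_out = fluid.layers.fc(input=t_image, size=10)

exe.run(t_startup)  # teacher parameters must exist before merge() copies them

# merge() folds the teacher graph into the student program, prefixing the
# teacher's variables with 'teacher_' so a loss can connect the two.
data_name_map = {'image': 'image'}
merge(teacher_program, student_program, data_name_map, place)

with fluid.program_guard(student_program, s_startup):
    distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0', student_program)
    loss = avg_cost + distill_loss
    fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9).minimize(loss)
exe.run(s_startup)
# A training loop would then fetch [loss, avg_cost, distill_loss], as the
# demo's fetch_list does.
```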
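The demo_guide text added above also mentions fsp_loss and soft_label_loss as alternatives. A hypothetical variant of the loss block follows, assuming soft_label_loss in PaddleSlim 1.x takes the same teacher/student tensor names plus optional temperature keywords; the temperature values are illustrative, not from this PR.

```python
# Hypothetical swap of l2_loss for soft_label_loss in the block above
# (teacher_temperature / student_temperature keywords are assumed; adjust
# to the installed PaddleSlim version).
from paddleslim.dist import soft_label_loss

with fluid.program_guard(student_program, s_startup):
    distill_loss = soft_label_loss(
        'teacher_fc_0.tmp_0', 'fc_0.tmp_0', student_program,
        teacher_temperature=4.0, student_temperature=4.0)
    loss = avg_cost + distill_loss
# fsp_loss instead takes two tensor names per network (the two ends of a
# block of 4-D feature maps), so it targets conv features rather than fc
# outputs.
```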