提交 29e7b33c 编写于 作者: B baiyfbupt

add brief introduction about distillation demo

上级 81506640
......@@ -13,9 +13,8 @@ import numpy as np
import paddle.fluid as fluid
sys.path.append(sys.path[0] + "/../")
import models
import imagenet_reader as reader
from utility import add_arguments, print_arguments, _download, _decompress
from single_distiller import merge, l2_loss, soft_label_loss, fsp_loss
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
......@@ -33,7 +32,7 @@ add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 20, "Log period in batches.")
add_arg('model', str, "MobileNet", "Set the network to use.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
......@@ -76,7 +75,7 @@ def create_optimizer(args):
def compress(args):
if args.data == "mnist":
if args.data == "cifar10":
import paddle.dataset.cifar as reader
train_reader = reader.train10()
val_reader = reader.test10()
......@@ -176,8 +175,8 @@ def compress(args):
place)
with fluid.program_guard(main, s_startup):
l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
loss = avg_cost + l2_loss_v
l2_loss = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
loss = avg_cost + l2_loss
opt = create_optimizer(args)
opt.minimize(loss)
exe.run(s_startup)
......@@ -192,7 +191,7 @@ def compress(args):
parallel_main,
feed=data,
fetch_list=[
loss.name, avg_cost.name, l2_loss_v.name
loss.name, avg_cost.name, l2_loss.name
])
if step_id % args.log_period == 0:
_logger.info(
......
## [蒸馏](../demo/distillation/distillation_demo.py)
蒸馏demo默认使用ResNet50作为teacher网络,MobileNet作为student网络,此外还支持将teacher和student换成[models目录](../demo/models)支持的任意模型
demo中对teahcer模型和student模型的一层特征图添加了l2_loss的蒸馏损失函数,使用时也可根据需要选择fsp_loss, soft_label_loss以及自定义的loss函数
训练默认使用的是cifar10数据集,piecewise_decay学习率衰减策略,momentum优化器进行120轮蒸馏训练。使用者也可以简单地用args参数切换为使用ImageNet数据集,cosine_decay学习率衰减策略等其他训练配置
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册