Commit fff7c57b authored by baiyfbupt

refine distillation demo & add a link to demo_guide

Parent 14d09781
@@ -14,8 +14,8 @@ import paddle.fluid as fluid
 sys.path.append(sys.path[0] + "/../")
 import models
 import imagenet_reader as reader
-from utility import add_arguments, print_arguments
-from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+from utility import add_arguments, print_arguments, _download, _decompress
+from single_distiller import merge, l2_loss, soft_label_loss, fsp_loss
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ add_arg('log_period', int, 20, "Log period in batches.")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
-add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "Whether to use pretrained model.")
+add_arg('teacher_pretrained_model', str, "./ResNet50_pretrained", "Whether to use pretrained model.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 # yapf: enable
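For readers outside the Paddle demo tree: `add_arg` is conventionally a `functools.partial` over `add_arguments` from utility.py. A sketch of that pattern, assumed from the common Paddle models layout (utility.py's `import distutils.util` in the hunks below is consistent with it) rather than shown verbatim in this diff:

import argparse
import distutils.util
import functools

def add_arguments(argname, type, default, help, argparser, **kwargs):
    # bool flags parse via strtobool so "--use_gpu False" behaves as expected
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)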
@@ -77,11 +77,11 @@ def create_optimizer(args):
 def compress(args):
     if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        train_reader = reader.train()
-        val_reader = reader.test()
+        import paddle.dataset.cifar as reader
+        train_reader = reader.train10()
+        val_reader = reader.test10()
         class_dim = 10
-        image_shape = "1,28,28"
+        image_shape = "3,32,32"
     elif args.data == "imagenet":
         import imagenet_reader as reader
         train_reader = reader.train()
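Note the `--data mnist` branch now actually loads CIFAR-10. A minimal sketch of that reader under Paddle 1.x's `paddle.dataset` API (removed in Paddle 2.x), showing the sample layout the new `image_shape` of 3,32,32 corresponds to:

import paddle
import paddle.dataset.cifar as cifar

# train10()/test10() yield (image, label) tuples; each image is a
# flattened float32 vector of 3*32*32 = 3072 values scaled to [0, 1].
train_reader = paddle.batch(cifar.train10(), batch_size=64)
for batch in train_reader():
    image, label = batch[0]
    print(len(image), label)  # 3072, and a class id in [0, 10)
    break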
@@ -132,7 +132,7 @@ def compress(args):
         val_reader, batch_size=args.batch_size, drop_last=True)
     val_program = student_program.clone(for_test=True)
-    places = fluid.cuda_places()
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
     train_loader.set_sample_list_generator(train_reader, places)
     valid_loader.set_sample_list_generator(val_reader, place)
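A minimal sketch of the device switch this hunk introduces; the `use_gpu` flag stands in for an `args.use_gpu` argument that is defined outside this diff, so treat that name as an assumption:

import paddle.fluid as fluid

use_gpu = False  # stands in for args.use_gpu
# cuda_places() enumerates the visible GPUs; cpu_places() returns CPU places.
places = fluid.cuda_places() if use_gpu else fluid.cpu_places()
exe = fluid.Executor(places[0])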
@@ -140,52 +140,47 @@ def compress(args):
     # define teacher program
     teacher_program = fluid.Program()
     t_startup = fluid.Program()
-    with fluid.program_guard(teacher_program, t_startup):
-        with fluid.unique_name.guard():
-            image = fluid.layers.data(
-                name='image', shape=image_shape, dtype='float32')
-            predict = teacher_model.net(image, class_dim=class_dim)
-
-    #print("="*50+"teacher_model_params"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    print(v.name, v.shape)
-
-    exe.run(t_startup)
-    assert args.teacher_pretrained_model and os.path.exists(
-        args.teacher_pretrained_model
-    ), "teacher_pretrained_model should be set when teacher_model is not None."
-
-    def if_exist(var):
-        return os.path.exists(
-            os.path.join(args.teacher_pretrained_model, var.name)
-        ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
-
-    fluid.io.load_vars(
-        exe,
-        args.teacher_pretrained_model,
-        main_program=teacher_program,
-        predicate=if_exist)
+    teacher_scope = fluid.Scope()
+    with fluid.scope_guard(teacher_scope):
+        with fluid.program_guard(teacher_program, t_startup):
+            with fluid.unique_name.guard():
+                image = fluid.layers.data(
+                    name='image', shape=image_shape, dtype='float32')
+                predict = teacher_model.net(image, class_dim=class_dim)
+
+        #print("="*50+"teacher_model_params"+"="*50)
+        #for v in teacher_program.list_vars():
+        #    print(v.name, v.shape)
+
+        exe.run(t_startup)
+        _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
+        _decompress('./ResNet50_pretrained.tar')
+        assert args.teacher_pretrained_model and os.path.exists(
+            args.teacher_pretrained_model
+        ), "teacher_pretrained_model should be set when teacher_model is not None."
+
+        def if_exist(var):
+            return os.path.exists(
+                os.path.join(args.teacher_pretrained_model, var.name)
+            ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
+
+        fluid.io.load_vars(
+            exe,
+            args.teacher_pretrained_model,
+            main_program=teacher_program,
+            predicate=if_exist)
     data_name_map = {'image': 'image'}
     main = merge(
         teacher_program,
         student_program,
         data_name_map,
-        place)
-
-    #print("="*50+"teacher_vars"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    if '_generated_var' not in v.name and 'fetch' not in v.name and 'feed' not in v.name:
-    #        print(v.name, v.shape)
-    #return
+        place,
+        teacher_scope=teacher_scope)
     with fluid.program_guard(main, s_startup):
         l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
-        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
-                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
-                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
-                              main)
-        loss = avg_cost + l2_loss_v + fsp_loss_v
+        loss = avg_cost + l2_loss_v
         opt = create_optimizer(args)
         opt.minimize(loss)
     exe.run(s_startup)
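For context on the calls above: judging from its call site, `merge` fuses the teacher program into the student's under `data_name_map`, exposing teacher tensors with a `teacher_` prefix (visible in the names passed to `l2_loss`), and the new `teacher_scope` argument keeps the teacher's weights in their own scope. The following is an illustrative sketch of what `l2_loss` plausibly computes, not `single_distiller`'s actual code:

import paddle.fluid as fluid

def l2_loss_sketch(teacher_var_name, student_var_name, program):
    # Look up the two named activations inside the merged program and
    # penalize their mean squared difference.
    block = program.global_block()
    teacher_var = block.var(teacher_var_name)
    student_var = block.var(student_var_name)
    with fluid.program_guard(program):
        return fluid.layers.reduce_mean(
            fluid.layers.square(student_var - teacher_var))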
@@ -196,17 +191,16 @@ def compress(args):
     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
-            loss_1, loss_2, loss_3, loss_4 = exe.run(
+            loss_1, loss_2, loss_3 = exe.run(
                 parallel_main,
                 feed=data,
                 fetch_list=[
-                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
+                    loss.name, avg_cost.name, l2_loss_v.name
                 ])
             if step_id % args.log_period == 0:
                 _logger.info(
-                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
-                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
-                           loss_4[0]))
+                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}".
+                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0]))
         val_acc1s = []
         val_acc5s = []
         for step_id, data in enumerate(valid_loader):
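`parallel_main` is built outside this hunk; in these demos it is typically the merged program compiled for data parallelism. A sketch under that assumption, using the fluid 1.x API:

import paddle.fluid as fluid

# Compile the merged program so exe.run() can scale across the places
# selected earlier; loss_name tells the runtime which gradients to aggregate.
parallel_main = fluid.CompiledProgram(main).with_data_parallel(
    loss_name=loss.name)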
@@ -20,6 +20,12 @@ import distutils.util
 import os
 import numpy as np
 import six
+import requests
+import shutil
+import tqdm
+import hashlib
+import tarfile
+import zipfile
 import logging
 import paddle.fluid as fluid
 import paddle.compat as cpt
@@ -30,6 +36,7 @@ logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
+DOWNLOAD_RETRY_LIMIT = 3

 def print_arguments(args):
     """Print argparse's arguments.
@@ -154,3 +161,122 @@ def load_persistable_nodes(executor, dirname, graph):
         else:
             _logger.info("Cannot find the var %s!!!" % (node.name()))
     fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+    url (str): download url
+    path (str): download to given path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    fname = os.path.split(url)[-1]
+    fullname = os.path.join(path, fname)
+    retry_cnt = 0
+
+    while not (os.path.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
+                               "Retry limit reached".format(url))
+
+        _logger.info("Downloading {} from {}".format(fname, url))
+
+        req = requests.get(url, stream=True)
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # To protect against interrupted downloads, write to
+        # tmp_fullname first and move it to fullname once the
+        # download finishes.
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                for chunk in tqdm.tqdm(
+                        req.iter_content(chunk_size=1024),
+                        total=(int(total_size) + 1023) // 1024,
+                        unit='KB'):
+                    f.write(chunk)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+def _md5check(fullname, md5sum=None):
+    if md5sum is None:
+        return True
+
+    _logger.info("File {} md5 checking...".format(fullname))
+    md5 = hashlib.md5()
+    with open(fullname, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+    if calc_md5sum != md5sum:
+        _logger.info("File {} md5 check failed, {}(calc) != "
+                     "{}(base)".format(fullname, calc_md5sum, md5sum))
+        return False
+
+    return True
+
+
+def _decompress(fname):
+    """
+    Decompress zip and tar files.
+    """
+    _logger.info("Decompressing {}...".format(fname))
+
+    # To protect against interrupted decompression, extract into a
+    # temporary directory first; on success, move the extracted files
+    # into place, then delete the temporary directory and the archive.
+    fpath = os.path.split(fname)[0]
+    fpath_tmp = os.path.join(fpath, 'tmp')
+    if os.path.isdir(fpath_tmp):
+        shutil.rmtree(fpath_tmp)
+    os.makedirs(fpath_tmp)
+
+    if fname.find('tar') >= 0:
+        with tarfile.open(fname) as tf:
+            tf.extractall(path=fpath_tmp)
+    elif fname.find('zip') >= 0:
+        with zipfile.ZipFile(fname) as zf:
+            zf.extractall(path=fpath_tmp)
+    else:
+        raise TypeError("Unsupported compressed file type {}".format(fname))
+
+    for f in os.listdir(fpath_tmp):
+        src_dir = os.path.join(fpath_tmp, f)
+        dst_dir = os.path.join(fpath, f)
+        _move_and_merge_tree(src_dir, dst_dir)
+
+    shutil.rmtree(fpath_tmp)
+    os.remove(fname)
+
+
+def _move_and_merge_tree(src, dst):
+    """
+    Move the src directory to dst; if dst already exists,
+    merge src into dst.
+    """
+    if not os.path.exists(dst):
+        shutil.move(src, dst)
+    else:
+        for fp in os.listdir(src):
+            src_fp = os.path.join(src, fp)
+            dst_fp = os.path.join(dst, fp)
+            if os.path.isdir(src_fp):
+                if os.path.isdir(dst_fp):
+                    _move_and_merge_tree(src_fp, dst_fp)
+                else:
+                    shutil.move(src_fp, dst_fp)
+            elif os.path.isfile(src_fp) and \
+                    not os.path.isfile(dst_fp):
+                shutil.move(src_fp, dst_fp)
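A usage sketch combining the three helpers above, mirroring the demo's call sequence; note that `md5sum=None` skips verification, and any real checksum you pass would have to come from the model zoo (not shown in this diff):

url = 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar'
# Download is skipped if the archive already exists and passes the md5 check.
tar_path = _download(url, '.', md5sum=None)
# Unpacks next to the archive (here: ./ResNet50_pretrained/, matching the
# demo's teacher_pretrained_model default) and deletes the tar afterwards.
_decompress(tar_path)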
-## [蒸馏]()
+## [蒸馏](../demo/distillation/distillation_demo.py)