Commit fff7c57b authored by baiyfbupt

refine distillation demo & add a link to demo_guide

Parent 14d09781
@@ -14,8 +14,8 @@ import paddle.fluid as fluid
 sys.path.append(sys.path[0] + "/../")
 import models
 import imagenet_reader as reader
-from utility import add_arguments, print_arguments
-from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+from utility import add_arguments, print_arguments, _download, _decompress
+from single_distiller import merge, l2_loss, soft_label_loss, fsp_loss
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ add_arg('log_period', int, 20, "Log period in batches.")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
-add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "Whether to use pretrained model.")
+add_arg('teacher_pretrained_model', str, "./ResNet50_pretrained", "Whether to use pretrained model.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 # yapf: enable
@@ -77,11 +77,11 @@ def create_optimizer(args):
 def compress(args):
     if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        train_reader = reader.train()
-        val_reader = reader.test()
+        import paddle.dataset.cifar as reader
+        train_reader = reader.train10()
+        val_reader = reader.test10()
         class_dim = 10
-        image_shape = "1,28,28"
+        image_shape = "3,32,32"
     elif args.data == "imagenet":
         import imagenet_reader as reader
         train_reader = reader.train()
@@ -132,7 +132,7 @@ def compress(args):
         val_reader, batch_size=args.batch_size, drop_last=True)
     val_program = student_program.clone(for_test=True)
-    places = fluid.cuda_places()
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
     train_loader.set_sample_list_generator(train_reader, places)
     valid_loader.set_sample_list_generator(val_reader, place)
@@ -140,6 +140,8 @@ def compress(args):
     # define teacher program
     teacher_program = fluid.Program()
     t_startup = fluid.Program()
+    teacher_scope = fluid.Scope()
+    with fluid.scope_guard(teacher_scope):
         with fluid.program_guard(teacher_program, t_startup):
             with fluid.unique_name.guard():
                 image = fluid.layers.data(
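The dedicated `fluid.Scope()` added here keeps every teacher variable in its own scope, so teacher weights cannot collide with, or be clobbered by, student variables of the same name when both programs run on one executor. Below is a minimal sketch of the isolation pattern, assuming the Paddle 1.x `fluid` APIs this demo uses; the network body is a placeholder, not the demo's ResNet50.

```python
import paddle.fluid as fluid

# Sketch: build and initialize a program inside its own variable scope.
teacher_scope = fluid.Scope()
teacher_program = fluid.Program()
t_startup = fluid.Program()

with fluid.scope_guard(teacher_scope):          # vars are created in teacher_scope
    with fluid.program_guard(teacher_program, t_startup):
        with fluid.unique_name.guard():         # keep generated names separate
            image = fluid.layers.data(
                name='image', shape=[3, 32, 32], dtype='float32')
            out = fluid.layers.fc(input=image, size=10)  # placeholder network

exe = fluid.Executor(fluid.CPUPlace())
with fluid.scope_guard(teacher_scope):
    exe.run(t_startup)   # parameters are initialized inside teacher_scope
```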
@@ -151,6 +153,8 @@ def compress(args):
             #     print(v.name, v.shape)
     exe.run(t_startup)
+    _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
+    _decompress('./ResNet50_pretrained.tar')
     assert args.teacher_pretrained_model and os.path.exists(
         args.teacher_pretrained_model
     ), "teacher_pretrained_model should be set when teacher_model is not None."
@@ -158,7 +162,7 @@ def compress(args):
     def if_exist(var):
         return os.path.exists(
             os.path.join(args.teacher_pretrained_model, var.name)
-        ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
+        ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
     fluid.io.load_vars(
         exe,
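`if_exist` acts as a predicate for `fluid.io.load_vars`: only teacher variables whose checkpoint file exists under `teacher_pretrained_model` are restored, and the teacher's final FC parameters (`fc_0.w_0`, `fc_0.b_0`) are excluded, presumably because the classification head is rebuilt for the target dataset's class count. The call is truncated above; the following is a hypothetical sketch of how it plausibly continues, following the fluid 1.x `load_vars` signature.

```python
# Hypothetical completion of the truncated call above; argument names follow
# fluid.io.load_vars(executor, dirname, main_program=None, vars=None, predicate=None).
fluid.io.load_vars(
    exe,
    args.teacher_pretrained_model,   # directory holding one file per persistable var
    main_program=teacher_program,
    predicate=if_exist)              # skip missing files and the excluded FC weights
```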
@@ -171,21 +175,12 @@ def compress(args):
         teacher_program,
         student_program,
         data_name_map,
-        place)
-    #print("="*50+"teacher_vars"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    if '_generated_var' not in v.name and 'fetch' not in v.name and 'feed' not in v.name:
-    #        print(v.name, v.shape)
-    #return
+        place,
+        teacher_scope=teacher_scope)

     with fluid.program_guard(main, s_startup):
         l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
-        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
-                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
-                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
-                              main)
-        loss = avg_cost + l2_loss_v + fsp_loss_v
+        loss = avg_cost + l2_loss_v
         opt = create_optimizer(args)
         opt.minimize(loss)
     exe.run(s_startup)
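With the `fsp_loss` term removed, the distillation recipe reduces to two calls: `merge` grafts the teacher program onto the student program (prefixing teacher variable names with `teacher_` and reading teacher weights from `teacher_scope`), and `l2_loss` builds an L2 distance between a teacher tensor and a student tensor referenced by name. A condensed sketch of that pattern follows; it assumes `teacher_program`, `student_program`, `main`, `s_startup`, `avg_cost`, `place`, and `teacher_scope` were built as in the surrounding demo, and the `'image': 'image'` entry is an assumed example mapping.

```python
import paddle.fluid as fluid
from single_distiller import merge, l2_loss

# Map each teacher input to the student input that should feed it
# ('image' -> 'image' is an assumed example mapping).
data_name_map = {'image': 'image'}

merge(teacher_program,
      student_program,
      data_name_map,
      place,
      teacher_scope=teacher_scope)  # teacher weights come from their own scope

with fluid.program_guard(main, s_startup):
    # Distance between the merged teacher's logits ("teacher_" prefix added
    # by merge) and the student's logits, plus the ordinary class loss.
    l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
    loss = avg_cost + l2_loss_v
```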
@@ -196,17 +191,16 @@ def compress(args):
     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
-            loss_1, loss_2, loss_3, loss_4 = exe.run(
+            loss_1, loss_2, loss_3 = exe.run(
                 parallel_main,
                 feed=data,
                 fetch_list=[
-                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
+                    loss.name, avg_cost.name, l2_loss_v.name
                 ])
             if step_id % args.log_period == 0:
                 _logger.info(
-                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
-                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
-                           loss_4[0]))
+                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}".
+                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0]))
         val_acc1s = []
         val_acc5s = []
         for step_id, data in enumerate(valid_loader):
@@ -20,6 +20,12 @@ import distutils.util
 import os
 import numpy as np
 import six
+import requests
+import shutil
+import tqdm
+import hashlib
+import tarfile
+import zipfile
 import logging
 import paddle.fluid as fluid
 import paddle.compat as cpt
@@ -30,6 +36,7 @@ logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
+DOWNLOAD_RETRY_LIMIT = 3

 def print_arguments(args):
     """Print argparse's arguments.
@@ -154,3 +161,122 @@ def load_persistable_nodes(executor, dirname, graph):
         else:
             _logger.info("Cannot find the var %s!!!" % (node.name()))
     fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+
+    url (str): download url
+    path (str): download to given path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    fname = os.path.split(url)[-1]
+    fullname = os.path.join(path, fname)
+    retry_cnt = 0
+
+    while not (os.path.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
+                               "Retry limit reached".format(url))
+
+        _logger.info("Downloading {} from {}".format(fname, url))
+
+        req = requests.get(url, stream=True)
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # To protect against interrupted downloads, write to
+        # tmp_fullname first and move tmp_fullname to fullname
+        # once the download has finished.
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                for chunk in tqdm.tqdm(
+                        req.iter_content(chunk_size=1024),
+                        total=(int(total_size) + 1023) // 1024,
+                        unit='KB'):
+                    f.write(chunk)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+def _md5check(fullname, md5sum=None):
+    if md5sum is None:
+        return True
+
+    _logger.info("File {} md5 checking...".format(fullname))
+    md5 = hashlib.md5()
+    with open(fullname, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+
+    if calc_md5sum != md5sum:
+        _logger.info("File {} md5 check failed, {}(calc) != "
+                     "{}(base)".format(fullname, calc_md5sum, md5sum))
+        return False
+    return True
+
+
+def _decompress(fname):
+    """
+    Decompress zip and tar files.
+    """
+    _logger.info("Decompressing {}...".format(fname))
+
+    # To protect against interrupted decompression, extract into the
+    # fpath_tmp directory first; if decompression succeeds, move the
+    # extracted files to fpath, delete fpath_tmp and remove the
+    # downloaded archive.
+    fpath = os.path.split(fname)[0]
+    fpath_tmp = os.path.join(fpath, 'tmp')
+    if os.path.isdir(fpath_tmp):
+        shutil.rmtree(fpath_tmp)
+    os.makedirs(fpath_tmp)
+
+    if fname.find('tar') >= 0:
+        with tarfile.open(fname) as tf:
+            tf.extractall(path=fpath_tmp)
+    elif fname.find('zip') >= 0:
+        with zipfile.ZipFile(fname) as zf:
+            zf.extractall(path=fpath_tmp)
+    else:
+        raise TypeError("Unsupport compress file type {}".format(fname))
+
+    for f in os.listdir(fpath_tmp):
+        src_dir = os.path.join(fpath_tmp, f)
+        dst_dir = os.path.join(fpath, f)
+        _move_and_merge_tree(src_dir, dst_dir)
+
+    shutil.rmtree(fpath_tmp)
+    os.remove(fname)
+
+
+def _move_and_merge_tree(src, dst):
+    """
+    Move the src directory to dst; if dst already exists,
+    merge src into dst.
+    """
+    if not os.path.exists(dst):
+        shutil.move(src, dst)
+    else:
+        for fp in os.listdir(src):
+            src_fp = os.path.join(src, fp)
+            dst_fp = os.path.join(dst, fp)
+            if os.path.isdir(src_fp):
+                if os.path.isdir(dst_fp):
+                    _move_and_merge_tree(src_fp, dst_fp)
+                else:
+                    shutil.move(src_fp, dst_fp)
+            elif os.path.isfile(src_fp) and \
+                    not os.path.isfile(dst_fp):
+                shutil.move(src_fp, dst_fp)
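Together these helpers make downloads and unpacking restartable: `_download` retries up to `DOWNLOAD_RETRY_LIMIT` times, streams into a `*_tmp` file before renaming, and can verify an MD5 digest via `_md5check`; `_decompress` extracts into a temporary directory and merges it into place with `_move_and_merge_tree`. A usage sketch mirroring the demo's call site (the digest shown is a placeholder, not the tarball's real checksum):

```python
from utility import _download, _decompress

# Fetch the teacher checkpoint used by the demo into the current directory.
# md5sum=None skips verification; a real digest would catch corrupt downloads.
tar_path = _download(
    'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar',
    '.',
    md5sum=None)  # e.g. md5sum='0123abcd...' (placeholder, hypothetical)

# Unpack next to the tarball; on success the .tar file is deleted, leaving
# ./ResNet50_pretrained/ with one file per persistable teacher variable.
_decompress(tar_path)
```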
-## [Distillation]()
+## [Distillation](../demo/distillation/distillation_demo.py)