diff --git a/demo/distillation/train.py b/demo/distillation/distillation_demo.py
similarity index 83%
rename from demo/distillation/train.py
rename to demo/distillation/distillation_demo.py
index e17678a70c73cfd1af746118ae7d0685317e96aa..3f47553e541ff86ae0a6f4d86c046a1dee66a03f 100644
--- a/demo/distillation/train.py
+++ b/demo/distillation/distillation_demo.py
@@ -13,8 +13,7 @@ import numpy as np
 import paddle.fluid as fluid
 sys.path.append(sys.path[0] + "/../")
 import models
-import imagenet_reader as reader
-from utility import add_arguments, print_arguments
+from utility import add_arguments, print_arguments, _download, _decompress
 from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
 
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
@@ -33,12 +32,12 @@ add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay
 add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
 add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
 add_arg('num_epochs', int, 120, "The number of total epochs.")
-add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
+add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'")
 add_arg('log_period', int, 20, "Log period in batches.")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
-add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "Whether to use pretrained model.")
+add_arg('teacher_pretrained_model', str, "./ResNet50_pretrained", "The path of the teacher's pretrained model.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 # yapf: enable
@@ -76,12 +75,12 @@ def create_optimizer(args):
 
 
 def compress(args):
-    if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        train_reader = reader.train()
-        val_reader = reader.test()
+    if args.data == "cifar10":
+        import paddle.dataset.cifar as reader
+        train_reader = reader.train10()
+        val_reader = reader.test10()
         class_dim = 10
-        image_shape = "1,28,28"
+        image_shape = "3,32,32"
     elif args.data == "imagenet":
         import imagenet_reader as reader
         train_reader = reader.train()
@@ -132,7 +131,7 @@ def compress(args):
         val_reader, batch_size=args.batch_size, drop_last=True)
 
     val_program = student_program.clone(for_test=True)
-    places = fluid.cuda_places()
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
     train_loader.set_sample_list_generator(train_reader, places)
     valid_loader.set_sample_list_generator(val_reader, place)
 
@@ -146,11 +145,13 @@ def compress(args):
             name='image', shape=image_shape, dtype='float32')
         predict = teacher_model.net(image, class_dim=class_dim)
 
-    #print("="*50+"teacher_model_params"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    print(v.name, v.shape)
+        #print("="*50+"teacher_model_params"+"="*50)
+        #for v in teacher_program.list_vars():
+        #    print(v.name, v.shape)
 
     exe.run(t_startup)
+    _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
+    _decompress('./ResNet50_pretrained.tar')
     assert args.teacher_pretrained_model and os.path.exists(
         args.teacher_pretrained_model
     ), "teacher_pretrained_model should be set when teacher_model is not None."
@@ -158,7 +159,7 @@ def compress(args):
     def if_exist(var):
         return os.path.exists(
             os.path.join(args.teacher_pretrained_model, var.name)
-        ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
+        ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
 
     fluid.io.load_vars(
         exe,
@@ -173,19 +174,9 @@ def compress(args):
         data_name_map, place)
 
-    #print("="*50+"teacher_vars"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    if '_generated_var' not in v.name and 'fetch' not in v.name and 'feed' not in v.name:
-    #        print(v.name, v.shape)
-    #return
 
     with fluid.program_guard(main, s_startup):
         l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
-        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
-                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
-                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
-                              main)
-        loss = avg_cost + l2_loss_v + fsp_loss_v
+        loss = avg_cost + l2_loss_v
         opt = create_optimizer(args)
         opt.minimize(loss)
         exe.run(s_startup)
@@ -196,17 +187,16 @@ def compress(args):
 
     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
-            loss_1, loss_2, loss_3, loss_4 = exe.run(
+            loss_1, loss_2, loss_3 = exe.run(
                 parallel_main,
                 feed=data,
                 fetch_list=[
-                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
+                    loss.name, avg_cost.name, l2_loss_v.name
                 ])
             if step_id % args.log_period == 0:
                 _logger.info(
-                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
-                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
-                           loss_4[0]))
+                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}".
+                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0]))
         val_acc1s = []
         val_acc5s = []
         for step_id, data in enumerate(valid_loader):
diff --git a/demo/utility.py b/demo/utility.py
index dd52f69457c9f8d94920b85dc09b58ff8e605a64..475468f2777ae40427465327ff7b78355cfcbbeb 100644
--- a/demo/utility.py
+++ b/demo/utility.py
@@ -20,6 +20,12 @@ import distutils.util
 import os
 import numpy as np
 import six
+import requests
+import shutil
+import tqdm
+import hashlib
+import tarfile
+import zipfile
 import logging
 import paddle.fluid as fluid
 import paddle.compat as cpt
@@ -30,6 +36,7 @@
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
+DOWNLOAD_RETRY_LIMIT = 3
 
 def print_arguments(args):
     """Print argparse's arguments.
@@ -154,3 +161,122 @@ def load_persistable_nodes(executor, dirname, graph):
         else:
             _logger.info("Cannot find the var %s!!!" % (node.name()))
     fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+
+    url (str): download url
+    path (str): download to given path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    fname = os.path.split(url)[-1]
+    fullname = os.path.join(path, fname)
+    retry_cnt = 0
+
+    while not (os.path.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
" + "Retry limit reached".format(url)) + + _logger.info("Downloading {} from {}".format(fname, url)) + + req = requests.get(url, stream=True) + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size: + for chunk in tqdm.tqdm( + req.iter_content(chunk_size=1024), + total=(int(total_size) + 1023) // 1024, + unit='KB'): + f.write(chunk) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + + return fullname + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + + _logger.info("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + _logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True + +def _decompress(fname): + """ + Decompress for zip and tar file + """ + _logger.info("Decompressing {}...".format(fname)) + + # For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + fpath = os.path.split(fname)[0] + fpath_tmp = os.path.join(fpath, 'tmp') + if os.path.isdir(fpath_tmp): + shutil.rmtree(fpath_tmp) + os.makedirs(fpath_tmp) + + if fname.find('tar') >= 0: + with tarfile.open(fname) as tf: + tf.extractall(path=fpath_tmp) + elif fname.find('zip') >= 0: + with zipfile.ZipFile(fname) as zf: + zf.extractall(path=fpath_tmp) + else: + raise TypeError("Unsupport compress file type {}".format(fname)) + + for f in os.listdir(fpath_tmp): + src_dir = os.path.join(fpath_tmp, f) + dst_dir = os.path.join(fpath, f) + _move_and_merge_tree(src_dir, dst_dir) + + shutil.rmtree(fpath_tmp) + os.remove(fname) + +def _move_and_merge_tree(src, dst): + """ + Move src directory to dst, if dst is already exists, + merge src to dst + """ + if not os.path.exists(dst): + shutil.move(src, dst) + else: + for fp in os.listdir(src): + src_fp = os.path.join(src, fp) + dst_fp = os.path.join(dst, fp) + if os.path.isdir(src_fp): + if os.path.isdir(dst_fp): + _move_and_merge_tree(src_fp, dst_fp) + else: + shutil.move(src_fp, dst_fp) + elif os.path.isfile(src_fp) and \ + not os.path.isfile(dst_fp): + shutil.move(src_fp, dst_fp) diff --git a/doc/demo_guide.md b/doc/demo_guide.md index affbd8c246d2d08db8f535068202c3058cc1b80b..e468db907d7bb859be7122cd59d5815ff066aa8c 100644 --- a/doc/demo_guide.md +++ b/doc/demo_guide.md @@ -1,2 +1,9 @@ -## [蒸馏]() +## [蒸馏](../demo/distillation/distillation_demo.py) + +蒸馏demo默认使用ResNet50作为teacher网络,MobileNet作为student网络,此外还支持将teacher和student换成[models目录](../demo/models)支持的任意模型。 + +demo中对teahcer模型和student模型的一层特征图添加了l2_loss的蒸馏损失函数,使用时也可根据需要选择fsp_loss, soft_label_loss以及自定义的loss函数。 + +训练默认使用的是cifar10数据集,piecewise_decay学习率衰减策略,momentum优化器进行120轮蒸馏训练。使用者也可以简单地用args参数切换为使用ImageNet数据集,cosine_decay学习率衰减策略等其他训练配置。 +