diff --git a/demo/distillation/train.py b/demo/distillation/distillation_demo.py
similarity index 77%
rename from demo/distillation/train.py
rename to demo/distillation/distillation_demo.py
index e17678a70c73cfd1af746118ae7d0685317e96aa..d0dd181409adb6e58c8884621345477669ee0dc0 100644
--- a/demo/distillation/train.py
+++ b/demo/distillation/distillation_demo.py
@@ -14,8 +14,8 @@ import paddle.fluid as fluid
 sys.path.append(sys.path[0] + "/../")
 import models
 import imagenet_reader as reader
-from utility import add_arguments, print_arguments
-from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+from utility import add_arguments, print_arguments, _download, _decompress
+from single_distiller import merge, l2_loss, soft_label_loss, fsp_loss
 
 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
@@ -38,7 +38,7 @@ add_arg('log_period', int, 20, "Log period in batches.")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
-add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "Whether to use pretrained model.")
+add_arg('teacher_pretrained_model', str, "./ResNet50_pretrained", "Path of the teacher's pretrained model.")
 parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
 # yapf: enable
 
@@ -77,11 +77,11 @@ def create_optimizer(args):
 
 def compress(args):
     if args.data == "mnist":
-        import paddle.dataset.mnist as reader
-        train_reader = reader.train()
-        val_reader = reader.test()
+        import paddle.dataset.cifar as reader
+        train_reader = reader.train10()
+        val_reader = reader.test10()
         class_dim = 10
-        image_shape = "1,28,28"
+        image_shape = "3,32,32"
     elif args.data == "imagenet":
         import imagenet_reader as reader
         train_reader = reader.train()
@@ -132,7 +132,7 @@ def compress(args):
         val_reader, batch_size=args.batch_size, drop_last=True)
     val_program = student_program.clone(for_test=True)
 
-    places = fluid.cuda_places()
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_loader.set_sample_list_generator(train_reader, places)
    valid_loader.set_sample_list_generator(val_reader, place)
 
@@ -140,52 +140,47 @@ def compress(args):
     # define teacher program
     teacher_program = fluid.Program()
     t_startup = fluid.Program()
-    with fluid.program_guard(teacher_program, t_startup):
-        with fluid.unique_name.guard():
-            image = fluid.layers.data(
-                name='image', shape=image_shape, dtype='float32')
-            predict = teacher_model.net(image, class_dim=class_dim)
-
-    #print("="*50+"teacher_model_params"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    print(v.name, v.shape)
-
-    exe.run(t_startup)
-    assert args.teacher_pretrained_model and os.path.exists(
-        args.teacher_pretrained_model
-    ), "teacher_pretrained_model should be set when teacher_model is not None."
-
-    def if_exist(var):
-        return os.path.exists(
-            os.path.join(args.teacher_pretrained_model, var.name)
-        ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
-
-    fluid.io.load_vars(
-        exe,
-        args.teacher_pretrained_model,
-        main_program=teacher_program,
-        predicate=if_exist)
+    teacher_scope = fluid.Scope()
+    with fluid.scope_guard(teacher_scope):
+        with fluid.program_guard(teacher_program, t_startup):
+            with fluid.unique_name.guard():
+                image = fluid.layers.data(
+                    name='image', shape=image_shape, dtype='float32')
+                predict = teacher_model.net(image, class_dim=class_dim)
+
+        #print("="*50+"teacher_model_params"+"="*50)
+        #for v in teacher_program.list_vars():
+        #    print(v.name, v.shape)
+
+        exe.run(t_startup)
+        _download('http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar', '.')
+        _decompress('./ResNet50_pretrained.tar')
+        assert args.teacher_pretrained_model and os.path.exists(
+            args.teacher_pretrained_model
+        ), "teacher_pretrained_model should be set when teacher_model is not None."
+
+        def if_exist(var):
+            return os.path.exists(
+                os.path.join(args.teacher_pretrained_model, var.name)
+            ) and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
+
+        fluid.io.load_vars(
+            exe,
+            args.teacher_pretrained_model,
+            main_program=teacher_program,
+            predicate=if_exist)
 
     data_name_map = {'image': 'image'}
     main = merge(
         teacher_program,
         student_program,
         data_name_map,
-        place)
-
-    #print("="*50+"teacher_vars"+"="*50)
-    #for v in teacher_program.list_vars():
-    #    if '_generated_var' not in v.name and 'fetch' not in v.name and 'feed' not in v.name:
-    #        print(v.name, v.shape)
-    #return
+        place,
+        teacher_scope=teacher_scope)
 
     with fluid.program_guard(main, s_startup):
         l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
-        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
-                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
-                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
-                              main)
-        loss = avg_cost + l2_loss_v + fsp_loss_v
+        loss = avg_cost + l2_loss_v
         opt = create_optimizer(args)
         opt.minimize(loss)
         exe.run(s_startup)
@@ -196,17 +191,16 @@ def compress(args):
 
     for epoch_id in range(args.num_epochs):
         for step_id, data in enumerate(train_loader):
-            loss_1, loss_2, loss_3, loss_4 = exe.run(
+            loss_1, loss_2, loss_3 = exe.run(
                 parallel_main,
                 feed=data,
                 fetch_list=[
-                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
+                    loss.name, avg_cost.name, l2_loss_v.name
                 ])
             if step_id % args.log_period == 0:
                 _logger.info(
-                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
-                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
-                           loss_4[0]))
+                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}".
+                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0]))
         val_acc1s = []
         val_acc5s = []
         for step_id, data in enumerate(valid_loader):
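
The heart of the change above is this pattern: build the teacher under its own `fluid.Scope`, graft it onto the student program with `merge`, then define a single L2 distillation loss between matching tensors. A minimal sketch, assuming the demo's local `single_distiller` module and toy one-layer networks; the tensor names (`fc_0.tmp_0` and its `teacher_`-prefixed copy) are taken from the demo, and the real names depend on the networks you build:

```python
import paddle.fluid as fluid
from single_distiller import merge, l2_loss  # local module used by the demo

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Toy student: a single fc layer standing in for MobileNet.
student_program, s_startup = fluid.Program(), fluid.Program()
with fluid.program_guard(student_program, s_startup):
    image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
    fluid.layers.fc(image, size=10)  # output tensor assumed to be 'fc_0.tmp_0'

# Toy teacher, built in its own scope so its weights stay separate.
teacher_program, t_startup = fluid.Program(), fluid.Program()
teacher_scope = fluid.Scope()
with fluid.scope_guard(teacher_scope):
    with fluid.program_guard(teacher_program, t_startup):
        with fluid.unique_name.guard():
            image = fluid.layers.data(
                name='image', shape=[3, 32, 32], dtype='float32')
            fluid.layers.fc(image, size=10)
    exe.run(t_startup)

# merge() grafts the teacher graph onto the student program; teacher
# variables come back under a "teacher_" prefix, read from teacher_scope.
data_name_map = {'image': 'image'}  # teacher input name -> student input name
main = merge(teacher_program, student_program, data_name_map, place,
             teacher_scope=teacher_scope)

with fluid.program_guard(main, s_startup):
    # Distillation loss: L2 distance between teacher and student fc outputs.
    distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0', main)
```
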
diff --git a/demo/utility.py b/demo/utility.py
index dd52f69457c9f8d94920b85dc09b58ff8e605a64..475468f2777ae40427465327ff7b78355cfcbbeb 100644
--- a/demo/utility.py
+++ b/demo/utility.py
@@ -20,6 +20,12 @@ import distutils.util
 import os
 import numpy as np
 import six
+import requests
+import shutil
+import tqdm
+import hashlib
+import tarfile
+import zipfile
 import logging
 import paddle.fluid as fluid
 import paddle.compat as cpt
@@ -30,6 +36,7 @@ logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
 _logger = logging.getLogger(__name__)
 _logger.setLevel(logging.INFO)
 
+DOWNLOAD_RETRY_LIMIT = 3
 
 def print_arguments(args):
     """Print argparse's arguments.
@@ -154,3 +161,122 @@ def load_persistable_nodes(executor, dirname, graph):
         else:
             _logger.info("Cannot find the var %s!!!" % (node.name()))
     fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+
+    url (str): download url
+    path (str): download to given path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+    fname = os.path.split(url)[-1]
+    fullname = os.path.join(path, fname)
+    retry_cnt = 0
+
+    while not (os.path.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
+                               "Retry limit reached".format(url))
+
+        _logger.info("Downloading {} from {}".format(fname, url))
+
+        req = requests.get(url, stream=True)
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # To guard against an interrupted download, write to
+        # tmp_fullname first, then move tmp_fullname to fullname
+        # once the download has finished.
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                for chunk in tqdm.tqdm(
+                        req.iter_content(chunk_size=1024),
+                        total=(int(total_size) + 1023) // 1024,
+                        unit='KB'):
+                    f.write(chunk)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+def _md5check(fullname, md5sum=None):
+    if md5sum is None:
+        return True
+
+    _logger.info("Checking md5 of file {}...".format(fullname))
+    md5 = hashlib.md5()
+    with open(fullname, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+
+    if calc_md5sum != md5sum:
+        _logger.info("File {} md5 check failed, {}(calc) != "
+                     "{}(base)".format(fullname, calc_md5sum, md5sum))
+        return False
+    return True
+
+def _decompress(fname):
+    """
+    Decompress zip and tar files.
+    """
+    _logger.info("Decompressing {}...".format(fname))
+
+    # To guard against an interrupted decompression,
+    # decompress into the fpath_tmp directory first; if that
+    # succeeds, move the decompressed files to fpath, then
+    # delete fpath_tmp and remove the downloaded archive.
+    fpath = os.path.split(fname)[0]
+    fpath_tmp = os.path.join(fpath, 'tmp')
+    if os.path.isdir(fpath_tmp):
+        shutil.rmtree(fpath_tmp)
+    os.makedirs(fpath_tmp)
+
+    if fname.find('tar') >= 0:
+        with tarfile.open(fname) as tf:
+            tf.extractall(path=fpath_tmp)
+    elif fname.find('zip') >= 0:
+        with zipfile.ZipFile(fname) as zf:
+            zf.extractall(path=fpath_tmp)
+    else:
+        raise TypeError("Unsupported compress file type {}".format(fname))
+
+    for f in os.listdir(fpath_tmp):
+        src_dir = os.path.join(fpath_tmp, f)
+        dst_dir = os.path.join(fpath, f)
+        _move_and_merge_tree(src_dir, dst_dir)
+
+    shutil.rmtree(fpath_tmp)
+    os.remove(fname)
+
+def _move_and_merge_tree(src, dst):
+    """
+    Move the src directory to dst; if dst already exists,
+    merge src into dst.
+    """
+    if not os.path.exists(dst):
+        shutil.move(src, dst)
+    else:
+        for fp in os.listdir(src):
+            src_fp = os.path.join(src, fp)
+            dst_fp = os.path.join(dst, fp)
+            if os.path.isdir(src_fp):
+                if os.path.isdir(dst_fp):
+                    _move_and_merge_tree(src_fp, dst_fp)
+                else:
+                    shutil.move(src_fp, dst_fp)
+            elif os.path.isfile(src_fp) and \
+                    not os.path.isfile(dst_fp):
+                shutil.move(src_fp, dst_fp)
diff --git a/doc/demo_guide.md b/doc/demo_guide.md
index affbd8c246d2d08db8f535068202c3058cc1b80b..9722dcc8de94a265491dced85e11e305c00073bc 100644
--- a/doc/demo_guide.md
+++ b/doc/demo_guide.md
@@ -1,2 +1,3 @@
-## [蒸馏]()
+## [Distillation](../demo/distillation/distillation_demo.py)
+
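
As a usage note, the two new helpers in `demo/utility.py` compose as below. This is a minimal sketch assuming it is run from the `demo/` directory so that `utility.py` is importable; the URL is the one hard-coded in the demo, and `md5sum` is omitted just as the demo omits it:

```python
from utility import _download, _decompress

url = 'http://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar'

# _download writes to a "<name>_tmp" file first and renames it only after the
# transfer completes; it retries up to DOWNLOAD_RETRY_LIMIT (3) times, and a
# non-None md5sum would make it re-download until the checksum matches.
archive = _download(url, '.')

# _decompress unpacks into a temporary directory next to the archive, merges
# the result into place via _move_and_merge_tree, then deletes the archive.
_decompress(archive)  # leaves ./ResNet50_pretrained/ ready for the demo
```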