From 3b4f5f4dfcabd3abc143d5859d54ce3ec8be5d10 Mon Sep 17 00:00:00 2001
From: littletomatodonkey <2120160898@bit.edu.cn>
Date: Thu, 10 Jun 2021 16:30:05 +0800
Subject: [PATCH] add distillation and fix some apis (#810)

* fix save load and imagenet dataset

* refine trainer
---
 ppcls/arch/__init__.py                        |  46 ++-
 ppcls/arch/loss_metrics/__init__.py           |  91 -----
 ...mv3_large_x1_0_distill_mv3_small_x1_0.yaml | 145 ++++++++
 ppcls/data/dataloader/imagenet_dataset.py     |   2 -
 ppcls/engine/trainer.py                       |  13 +-
 ppcls/loss/__init__.py                        |   5 +
 ppcls/loss/celoss.py                          | 132 ++-----
 ppcls/metric/__init__.py                      |   6 +-
 ppcls/metric/metrics.py                       |  21 +-
 ppcls/utils/download.py                       | 319 ++++++++++++++++++
 ppcls/utils/save_load.py                      |  52 +--
 11 files changed, 585 insertions(+), 247 deletions(-)
 delete mode 100644 ppcls/arch/loss_metrics/__init__.py
 create mode 100644 ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml
 create mode 100644 ppcls/utils/download.py

diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py
index 18004a77..67df1974 100644
--- a/ppcls/arch/__init__.py
+++ b/ppcls/arch/__init__.py
@@ -21,8 +21,9 @@ from . import backbone, gears
 from .backbone import *
 from .gears import build_gear
 from .utils import *
+from ppcls.utils.save_load import load_dygraph_pretrain
 
-__all__ = ["build_model", "RecModel"]
+__all__ = ["build_model", "RecModel", "DistillationModel"]
 
 
 def build_model(config):
@@ -62,3 +63,46 @@ class RecModel(nn.Layer):
         else:
             y = None
         return {"features": x, "logits": y}
+
+
+class DistillationModel(nn.Layer):
+    def __init__(self,
+                 models=None,
+                 pretrained_list=None,
+                 freeze_params_list=None):
+        super().__init__()
+        assert isinstance(models, list)
+        self.model_list = []
+        self.model_name_list = []
+        if pretrained_list is not None:
+            assert len(pretrained_list) == len(models)
+
+        if freeze_params_list is None:
+            freeze_params_list = [False] * len(models)
+        assert len(freeze_params_list) == len(models)
+        for idx, model_config in enumerate(models):
+            assert len(model_config) == 1
+            key = list(model_config.keys())[0]
+            model_config = model_config[key]
+            model_name = model_config.pop("name")
+            model = eval(model_name)(**model_config)
+
+            if freeze_params_list[idx]:
+                for param in model.parameters():
+                    param.trainable = False
+            self.model_list.append(self.add_sublayer(key, model))
+            self.model_name_list.append(key)
+
+        if pretrained_list is not None:
+            for idx, pretrained in enumerate(pretrained_list):
+                if pretrained is not None:
+                    load_dygraph_pretrain(
+                        self.model_list[idx], path=pretrained)
+
+    def forward(self, x, label=None):
+        result_dict = dict()
+        for idx, model_name in enumerate(self.model_name_list):
+            # label is kept for API compatibility; the wrapped
+            # sub-models only consume the input tensor here
+            result_dict[model_name] = self.model_list[idx](x)
+        return result_dict
diff --git a/ppcls/arch/loss_metrics/__init__.py b/ppcls/arch/loss_metrics/__init__.py
deleted file mode 100644
index 934fbd82..00000000
--- a/ppcls/arch/loss_metrics/__init__.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
-#
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
-
-import copy
-import sys
-
-import paddle
-import paddle.nn as nn
-import paddle.nn.functional as F
-
-
-# TODO: fix the format
-class CELoss(nn.Layer):
-    """
-    """
-
-    def __init__(self, name="loss", epsilon=None):
-        super().__init__()
-        self.name = name
-        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
-            epsilon = None
-        self.epsilon = epsilon
-
-    def _labelsmoothing(self, target, class_num):
-        if target.shape[-1] != class_num:
-            one_hot_target = F.one_hot(target, class_num)
-        else:
-            one_hot_target = target
-        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
-        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
-        return soft_target
-
-    def forward(self, logits, label, mode="train"):
-        loss_dict = {}
-        if self.epsilon is not None:
-            class_num = logits.shape[-1]
-            label = self._labelsmoothing(label, class_num)
-            x = -F.log_softmax(logits, axis=-1)
-            loss = paddle.sum(logits * label, axis=-1)
-        else:
-            if label.shape[-1] == logits.shape[-1]:
-                label = F.softmax(label, axis=-1)
-                soft_label = True
-            else:
-                soft_label = False
-            loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
-        loss_dict[self.name] = paddle.mean(loss)
-        return loss_dict
-
-
-# TODO: fix the format
-class Topk(nn.Layer):
-    def __init__(self, topk=[1, 5]):
-        super().__init__()
-        assert isinstance(topk, (int, list))
-        if isinstance(topk, int):
-            topk = [topk]
-        self.topk = topk
-
-    def forward(self, x, label):
-        if isinstance(x, dict):
-            x = x["logits"]
-
-        metric_dict = dict()
-        for k in self.topk:
-            metric_dict["top{}".format(k)] = paddle.metric.accuracy(
-                x, label, k=k)
-        return metric_dict
-
-
-# TODO: fix the format
-def build_loss(config):
-    loss_func = CELoss()
-    return loss_func
-
-
-# TODO: fix the format
-def build_metrics(config):
-    metrics_func = Topk()
-    return metrics_func
diff --git a/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml b/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml
new file mode 100644
index 00000000..f46d7be1
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml
@@ -0,0 +1,145 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: "./output/"
+  device: "gpu"
+  class_num: 1000
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 120
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: "./inference"
+
+# model architecture
+Arch:
+  name: "DistillationModel"
+  # if not null, its length should be the same as models
+  pretrained_list:
+  # if not null, its length should be the same as models
+  freeze_params_list:
+  - True
+  - False
+  models:
+    - Teacher:
+        name: MobileNetV3_large_x1_0
+        pretrained: True
+        use_ssld: True
+    - Student:
+        name: MobileNetV3_small_x1_0
+        pretrained: False
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationCELoss:
+        weight: 1.0
+        model_name_pairs:
+        - ["Student", "Teacher"]
+  Eval:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 1.3
+    warmup_epoch: 5
+  regularizer:
+    name: 'L2'
+    coeff: 0.00001
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
+      transform_ops:
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - AutoAugment:
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 512
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 6
+      use_shared_memory: True
+
+  Eval:
+    # TODO: modify to the latest trainer
+    dataset:
+      name: ImageNetDataset
+      image_root: "./dataset/ILSVRC2012/"
+      cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
+      transform_ops:
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 0.00392157
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 6
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: "docs/images/whl/demo.jpg"
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
+
+Metric:
+  Train:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
+  Eval:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py
index 08846ba8..e084bb74 100644
--- a/ppcls/data/dataloader/imagenet_dataset.py
+++ b/ppcls/data/dataloader/imagenet_dataset.py
@@ -31,8 +31,6 @@ class ImageNetDataset(CommonDataset):
             lines = fd.readlines()
             if seed is not None:
                 np.random.RandomState(seed).shuffle(lines)
-            else:
-                np.random.shuffle(lines)
             for l in lines:
                 l = l.strip().split(" ")
                 self.images.append(os.path.join(self._img_root, l[0]))
diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py
index 513fba1e..83749357 100644
--- a/ppcls/engine/trainer.py
+++ b/ppcls/engine/trainer.py
@@ -235,6 +235,8 @@ class Trainer(object):
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model")
+                logger.info("[Eval][Epoch {}][best metric: {}]".format(
+                    epoch_id, acc))
             self.model.train()
 
             # save model
@@ -245,14 +247,21 @@ class Trainer(object):
                                 "epoch": epoch_id},
                     self.output_dir,
                     model_name=self.config["Arch"]["name"],
-                    prefix="ppcls_epoch_{}".format(epoch_id))
+                    prefix="epoch_{}".format(epoch_id))
+            # save the latest model
+            save_load.save_model(
+                self.model,
+                optimizer, {"metric": acc,
+                            "epoch": epoch_id},
+                self.output_dir,
+                model_name=self.config["Arch"]["name"],
+                prefix="latest")
 
     def build_avg_metrics(self, info_dict):
         return {key: AverageMeter(key, '7.5f') for key in info_dict}
 
     @paddle.no_grad()
     def eval(self, epoch_id=0):
-
         self.model.eval()
         if self.eval_loss_func is None:
             loss_config = self.config.get("Loss", None)
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index c49a5355..d90ff345 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -13,7 +13,12 @@ from .trihardloss import TriHardLoss
 from .triplet import TripletLoss, TripletLossV2
 from .supconloss import SupConLoss
 from .pairwisecosface import PairwiseCosface
+from .dmlloss import DMLLoss
+from .distanceloss import DistanceLoss
+from .distillationloss import DistillationCELoss
+from .distillationloss import DistillationGTCELoss
+from .distillationloss import DistillationDMLLoss
 
 
 class CombinedLoss(nn.Layer):
diff --git a/ppcls/loss/celoss.py b/ppcls/loss/celoss.py
index 257c41e1..14ead3b6 100644
--- a/ppcls/loss/celoss.py
+++ b/ppcls/loss/celoss.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,113 +13,41 @@
 # limitations under the License.
 
 import paddle
+import paddle.nn as nn
 import paddle.nn.functional as F
 
-__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']
 
+class CELoss(nn.Layer):
+    def __init__(self, epsilon=None):
+        super().__init__()
+        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
+            epsilon = None
+        self.epsilon = epsilon
 
-class Loss(object):
-    """
-    Loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
-        self._class_dim = class_dim
-        if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
-            self._epsilon = epsilon
-            self._label_smoothing = True  #use label smoothing.(Actually, it is softmax label)
-        else:
-            self._epsilon = None
-            self._label_smoothing = False
-
-    #do label_smoothing
-    def _labelsmoothing(self, target):
-        if target.shape[-1] != self._class_dim:
-            one_hot_target = F.one_hot(
-                target,
-                self._class_dim)  #do ont hot(23,34,46)-> 3 * _class_dim
+    def _labelsmoothing(self, target, class_num):
+        if target.shape[-1] != class_num:
+            one_hot_target = F.one_hot(target, class_num)
         else:
             one_hot_target = target
-
-        #do label_smooth
-        soft_target = F.label_smooth(
-            one_hot_target,
-            epsilon=self._epsilon)  #(1 - epsilon) * input + eposilon / K.
-        soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
+        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
+        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
         return soft_target
 
-    def _crossentropy(self, input, target, use_pure_fp16=False):
-        if self._label_smoothing:
-            target = self._labelsmoothing(target)
-            input = -F.log_softmax(input, axis=-1)  #softmax and do log
-            cost = paddle.sum(target * input, axis=-1)  #sum
-        else:
-            cost = F.cross_entropy(input=input, label=target)
-
-        if use_pure_fp16:
-            avg_cost = paddle.sum(cost)
-        else:
-            avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def _kldiv(self, input, target, name=None):
-        eps = 1.0e-10
-        cost = target * paddle.log(
-            (target + eps) / (input + eps)) * self._class_dim
-        return cost
-
-    def _jsdiv(self, input,
-               target):  #so the input and target is the fc output; no softmax
-        input = F.softmax(input)
-        target = F.softmax(target)
-
-        #two distribution
-        cost = self._kldiv(input, target) + self._kldiv(target, input)
-        cost = cost / 2
-        avg_cost = paddle.mean(cost)
-        return avg_cost
-
-    def __call__(self, input, target):
-        pass
-
-
-class CELoss(Loss):
-    """
-    Cross entropy loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(CELoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target, use_pure_fp16=False):
-        if type(input) is dict:
-            logits = input["logits"]
+    def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+        if self.epsilon is not None:
+            class_num = x.shape[-1]
+            label = self._labelsmoothing(label, class_num)
+            x = -F.log_softmax(x, axis=-1)
+            loss = paddle.sum(x * label, axis=-1)
+            # reduce to a scalar so both branches return the same shape
+            loss = loss.mean()
         else:
-            logits = input
-        cost = self._crossentropy(logits, target, use_pure_fp16)
-        return {"CELoss": cost}
-
-
-class JSDivLoss(Loss):
-    """
-    JSDiv loss
-    """
-
-    def __init__(self, class_dim=1000, epsilon=None):
-        super(JSDivLoss, self).__init__(class_dim, epsilon)
-
-    def __call__(self, input, target):
-        cost = self._jsdiv(input, target)
-        return cost
-
-
-class KLDivLoss(paddle.nn.Layer):
-    def __init__(self):
-        super(KLDivLoss, self).__init__()
-
-    def __call__(self, p, q, is_logit=True):
-        if is_logit:
-            p = paddle.nn.functional.softmax(p)
-            q = paddle.nn.functional.softmax(q)
-        return -(p * paddle.log(q + 1e-8)).sum(1).mean()
+            if label.shape[-1] == x.shape[-1]:
+                label = F.softmax(label, axis=-1)
+                soft_label = True
+            else:
+                soft_label = False
+            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
+        return {"CELoss": loss}
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index 696e8f85..95f86e4a 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -17,6 +17,8 @@ import copy
 from collections import OrderedDict
 
 from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
+from .metrics import DistillationTopkAcc
+
 
 class CombinedMetrics(nn.Layer):
     def __init__(self, config_list):
@@ -24,7 +26,7 @@ class CombinedMetrics(nn.Layer):
         self.metric_func_list = []
         assert isinstance(config_list, list), (
             'operator config should be a list')
-
+        self.retri_config = dict()  # retrieval metrics config
         for config in config_list:
             assert isinstance(config,
@@ -35,7 +37,7 @@ class CombinedMetrics(nn.Layer):
                 continue
             metric_params = config[metric_name]
             self.metric_func_list.append(eval(metric_name)(**metric_params))
-
+
         if self.retri_config:
             self.metric_func_list.append(RetriMetric(self.retri_config))
diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index 8ec438ec..76cab0f4 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -18,7 +18,6 @@ import paddle.nn as nn
 
 from functools import lru_cache
 
-# TODO: fix the format
 class TopkAcc(nn.Layer):
     def __init__(self, topk=(1, 5)):
         super().__init__()
@@ -84,6 +83,7 @@ class Recallk(nn.Layer):
             metric_dict["recall{}".format(k)] = all_cmc[k - 1]
         return metric_dict
 
+
 # retrieval metrics
 class RetriMetric(nn.Layer):
     def __init__(self, config):
@@ -93,8 +93,8 @@ class RetriMetric(nn.Layer):
 
     def forward(self, similarities_matrix, query_img_id, gallery_img_id):
         metric_dict = dict()
-        all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
-                                               gallery_img_id, self.max_rank)
+        all_cmc, all_AP, all_INP = get_metrics(
+            similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
         if "Recallk" in self.config.keys():
             topk = self.config['Recallk']['topk']
             assert isinstance(topk, (int, list, tuple))
@@ -109,7 +109,7 @@ class RetriMetric(nn.Layer):
             mINP = np.mean(all_INP)
             metric_dict["mINP"] = mINP
         return metric_dict
-
+
 
 @lru_cache()
 def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
@@ -155,3 +155,16 @@ def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
         all_cmc = all_cmc.sum(0) / num_valid_q
 
     return all_cmc, all_AP, all_INP
+
+
+class DistillationTopkAcc(TopkAcc):
+    def __init__(self, model_key, feature_key=None, topk=(1, 5)):
+        super().__init__(topk=topk)
+        self.model_key = model_key
+        self.feature_key = feature_key
+
+    def forward(self, x, label):
+        x = x[self.model_key]
+        if self.feature_key is not None:
+            x = x[self.feature_key]
+        return super().forward(x, label)
diff --git a/ppcls/utils/download.py b/ppcls/utils/download.py
new file mode 100644
index 00000000..9c457504
--- /dev/null
+++ b/ppcls/utils/download.py
@@ -0,0 +1,319 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import os.path as osp
+import shutil
+import requests
+import hashlib
+import tarfile
+import zipfile
+import time
+from collections import OrderedDict
+from tqdm import tqdm
+
+from ppcls.utils import logger
+
+__all__ = ['get_weights_path_from_url']
+
+WEIGHTS_HOME = osp.expanduser("~/.paddleclas/weights")
+
+DOWNLOAD_RETRY_LIMIT = 3
+
+
+def is_url(path):
+    """
+    Whether path is a URL.
+    Args:
+        path (string): the path to check.
+    """
+    return path.startswith('http://') or path.startswith('https://')
+
+
+def get_weights_path_from_url(url, md5sum=None):
+    """Get weights path from WEIGHTS_HOME; if it does not exist,
+    download it from url.
+
+    Args:
+        url (str): download url
+        md5sum (str): md5 sum of download package
+
+    Returns:
+        str: a local path to save downloaded weights.
+
+    Examples:
+        .. code-block:: python
+
+            from ppcls.utils.download import get_weights_path_from_url
+
+            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
+            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
+
+    """
+    path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
+    return path
+
+
+def _map_path(url, root_dir):
+    # parse path after download under root_dir
+    fname = osp.split(url)[-1]
+    fpath = fname
+    return osp.join(root_dir, fpath)
+
+
+def _get_unique_endpoints(trainer_endpoints):
+    # Sorting is to avoid different environmental variables for each card
+    trainer_endpoints.sort()
+    ips = set()
+    unique_endpoints = set()
+    for endpoint in trainer_endpoints:
+        ip = endpoint.split(":")[0]
+        if ip in ips:
+            continue
+        ips.add(ip)
+        unique_endpoints.add(endpoint)
+    logger.info("unique_endpoints {}".format(unique_endpoints))
+    return unique_endpoints
+
+
+def get_path_from_url(url,
+                      root_dir,
+                      md5sum=None,
+                      check_exist=True,
+                      decompress=True):
+    """ Download from given url to root_dir.
+    If the file or directory specified by url exists under
+    root_dir, return the path directly; otherwise download
+    from url, decompress it and return the path.
+
+    Args:
+        url (str): download url
+        root_dir (str): root dir for downloading, it should be
+                        WEIGHTS_HOME or DATASET_HOME
+        md5sum (str): md5 sum of download package
+
+    Returns:
+        str: a local path to save downloaded models & weights & datasets.
+    """
+
+    from paddle.fluid.dygraph.parallel import ParallelEnv
+
+    assert is_url(url), "downloading from {} not a url".format(url)
+    # parse path after download to decompress under root_dir
+    fullpath = _map_path(url, root_dir)
+    # Mainly used to solve the problem of downloading data from different
+    # machines in the case of multiple machines. Different ips will download
+    # data, and the same ip will only download data once.
+    unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                             .trainer_endpoints[:])
+    if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
+        logger.info("Found {}".format(fullpath))
+    else:
+        if ParallelEnv().current_endpoint in unique_endpoints:
+            fullpath = _download(url, root_dir, md5sum)
+        else:
+            while not os.path.exists(fullpath):
+                time.sleep(1)
+
+    if ParallelEnv().current_endpoint in unique_endpoints:
+        if decompress and (tarfile.is_tarfile(fullpath) or
+                           zipfile.is_zipfile(fullpath)):
+            fullpath = _decompress(fullpath)
+
+    return fullpath
+
+
+def _download(url, path, md5sum=None):
+    """
+    Download from url, save to path.
+
+    url (str): download url
+    path (str): download to given path
+    """
+    if not osp.exists(path):
+        os.makedirs(path)
+
+    fname = osp.split(url)[-1]
+    fullname = osp.join(path, fname)
+    retry_cnt = 0
+
+    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
+        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
+            retry_cnt += 1
+        else:
+            raise RuntimeError("Download from {} failed. "
+                               "Retry limit reached".format(url))
+
+        logger.info("Downloading {} from {}".format(fname, url))
+
+        try:
+            req = requests.get(url, stream=True)
+        except Exception as e:  # requests.exceptions.ConnectionError
+            logger.info(
+                "Downloading {} from {} failed {} times with exception {}".
+                format(fname, url, retry_cnt + 1, str(e)))
+            time.sleep(1)
+            continue
+
+        if req.status_code != 200:
+            raise RuntimeError("Downloading from {} failed with code "
+                               "{}!".format(url, req.status_code))
+
+        # To protect against an interrupted download, download to
+        # tmp_fullname first, then move tmp_fullname to fullname
+        # once the download has finished
+        tmp_fullname = fullname + "_tmp"
+        total_size = req.headers.get('content-length')
+        with open(tmp_fullname, 'wb') as f:
+            if total_size:
+                with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
+                    for chunk in req.iter_content(chunk_size=1024):
+                        f.write(chunk)
+                        pbar.update(1)
+            else:
+                for chunk in req.iter_content(chunk_size=1024):
+                    if chunk:
+                        f.write(chunk)
+        shutil.move(tmp_fullname, fullname)
+
+    return fullname
+
+
+def _md5check(fullname, md5sum=None):
+    if md5sum is None:
+        return True
+
+    logger.info("File {} md5 checking...".format(fullname))
+    md5 = hashlib.md5()
+    with open(fullname, 'rb') as f:
+        for chunk in iter(lambda: f.read(4096), b""):
+            md5.update(chunk)
+    calc_md5sum = md5.hexdigest()
+
+    if calc_md5sum != md5sum:
+        logger.info("File {} md5 check failed, {}(calc) != "
+                    "{}(base)".format(fullname, calc_md5sum, md5sum))
+        return False
+    return True
+
+
+def _decompress(fname):
+    """
+    Decompress for zip and tar file
+    """
+    logger.info("Decompressing {}...".format(fname))
+
+    # To protect against an interrupted decompression,
+    # decompress to a fpath_tmp directory first; if decompression
+    # succeeds, move the decompressed files to fpath, delete
+    # fpath_tmp and remove the downloaded compressed file.
+
+    if tarfile.is_tarfile(fname):
+        uncompressed_path = _uncompress_file_tar(fname)
+    elif zipfile.is_zipfile(fname):
+        uncompressed_path = _uncompress_file_zip(fname)
+    else:
+        raise TypeError("Unsupported compressed file type {}".format(fname))
+
+    return uncompressed_path
+
+
+def _uncompress_file_zip(filepath):
+    files = zipfile.ZipFile(filepath, 'r')
+    file_list = files.namelist()
+
+    file_dir = os.path.dirname(filepath)
+
+    if _is_a_single_file(file_list):
+        rootpath = file_list[0]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+
+        for item in file_list:
+            files.extract(item, file_dir)
+
+    elif _is_a_single_dir(file_list):
+        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+
+        for item in file_list:
+            files.extract(item, file_dir)
+
+    else:
+        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+        if not os.path.exists(uncompressed_path):
+            os.makedirs(uncompressed_path)
+        for item in file_list:
+            files.extract(item, os.path.join(file_dir, rootpath))
+
+    files.close()
+
+    return uncompressed_path
+
+
+def _uncompress_file_tar(filepath, mode="r:*"):
+    files = tarfile.open(filepath, mode)
+    file_list = files.getnames()
+
+    file_dir = os.path.dirname(filepath)
+
+    if _is_a_single_file(file_list):
+        rootpath = file_list[0]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+        for item in file_list:
+            files.extract(item, file_dir)
+    elif _is_a_single_dir(file_list):
+        rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+        for item in file_list:
+            files.extract(item, file_dir)
+    else:
+        rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
+        uncompressed_path = os.path.join(file_dir, rootpath)
+        if not os.path.exists(uncompressed_path):
+            os.makedirs(uncompressed_path)
+
+        for item in file_list:
+            files.extract(item, os.path.join(file_dir, rootpath))
+
+    files.close()
+
+    return uncompressed_path
+
+
+def _is_a_single_file(file_list):
+    if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
+        return True
+    return False
+
+
+def _is_a_single_dir(file_list):
+    new_file_list = []
+    for file_path in file_list:
+        if '/' in file_path:
+            file_path = file_path.replace('/', os.sep)
+        elif '\\' in file_path:
+            file_path = file_path.replace('\\', os.sep)
+        new_file_list.append(file_path)
+
+    file_name = new_file_list[0].split(os.sep)[0]
+    for i in range(1, len(new_file_list)):
+        if file_name != new_file_list[i].split(os.sep)[0]:
+            return False
+    return True
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index d8d80639..c878e71e 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -23,10 +23,8 @@ import shutil
 import tempfile
 
 import paddle
-from paddle.static import load_program_state
-from paddle.utils.download import get_weights_path_from_url
-
 from ppcls.utils import logger
+from .download import get_weights_path_from_url
 
 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
 
@@ -47,70 +45,42 @@ def _mkdir_if_not_exist(path):
                 raise OSError('Failed to mkdir {}'.format(path))
 
 
-def load_dygraph_pretrain(model, path=None, load_static_weights=False):
+def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            if weight_name in pre_state_dict.keys():
-                logger.info('Load weight: {}, shape: {}'.format(
-                    weight_name, pre_state_dict[weight_name].shape))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_dict(param_state_dict)
-        return
-
     param_state_dict = paddle.load(path + ".pdparams")
     model.set_dict(param_state_dict)
     return
 
 
-def load_dygraph_pretrain_from_url(model,
-                                   pretrained_url,
-                                   use_ssld,
-                                   load_static_weights=False):
+def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld):
     if use_ssld:
         pretrained_url = pretrained_url.replace("_pretrained",
                                                 "_ssld_pretrained")
     local_weight_path = get_weights_path_from_url(pretrained_url).replace(
         ".pdparams", "")
-    load_dygraph_pretrain(
-        model, path=local_weight_path, load_static_weights=load_static_weights)
+    load_dygraph_pretrain(model, path=local_weight_path)
    return
 
 
-def load_distillation_model(model, pretrained_model, load_static_weights):
+def load_distillation_model(model, pretrained_model):
     logger.info("In distillation mode, teacher model will be "
                 "loaded firstly before student model.")
 
     if not isinstance(pretrained_model, list):
         pretrained_model = [pretrained_model]
 
-    if not isinstance(load_static_weights, list):
-        load_static_weights = [load_static_weights] * len(pretrained_model)
-
     teacher = model.teacher if hasattr(model,
                                        "teacher") else model._layers.teacher
     student = model.student if hasattr(model,
                                        "student") else model._layers.student
-    load_dygraph_pretrain(
-        teacher,
-        path=pretrained_model[0],
-        load_static_weights=load_static_weights[0])
+    load_dygraph_pretrain(teacher, path=pretrained_model[0])
     logger.info("Finish initing teacher model from {}".format(
         pretrained_model))
     # load student model
     if len(pretrained_model) >= 2:
-        load_dygraph_pretrain(
-            student,
-            path=pretrained_model[1],
-            load_static_weights=load_static_weights[1])
+ load_dygraph_pretrain(student, path=pretrained_model[1]) logger.info("Finish initing student model from {}".format( pretrained_model)) @@ -134,16 +104,12 @@ def init_model(config, net, optimizer=None): return metric_dict pretrained_model = config.get('pretrained_model') - load_static_weights = config.get('load_static_weights', False) use_distillation = config.get('use_distillation', False) if pretrained_model: if use_distillation: - load_distillation_model(net, pretrained_model, load_static_weights) + load_distillation_model(net, pretrained_model) else: # common load - load_dygraph_pretrain( - net, - path=pretrained_model, - load_static_weights=load_static_weights) + load_dygraph_pretrain(net, path=pretrained_model) logger.info( logger.coloring("Finish load pretrained model from {}".format( pretrained_model), "HEADER")) -- GitLab
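
A usage sketch, not part of the commit above: the patch adds DistillationModel, DistillationTopkAcc and the download helper they rely on, but never shows them working together. The snippet below wires a frozen MobileNetV3 teacher to a trainable student the same way the new YAML config does. It assumes a PaddleClas checkout with this patch applied is on PYTHONPATH and that the SSLD teacher weights can be fetched over the network; the batch size and label values are arbitrary.

    import paddle

    from ppcls.arch import DistillationModel
    from ppcls.metric.metrics import DistillationTopkAcc

    # same structure the trainer builds from the "Arch.models" config list
    model = DistillationModel(
        models=[
            {"Teacher": {"name": "MobileNetV3_large_x1_0",
                         "pretrained": True, "use_ssld": True}},
            {"Student": {"name": "MobileNetV3_small_x1_0",
                         "pretrained": False}},
        ],
        freeze_params_list=[True, False])  # teacher frozen, student trainable

    x = paddle.rand([4, 3, 224, 224])  # arbitrary image-shaped batch
    out = model(x)  # dict keyed by sub-model name: {"Teacher": ..., "Student": ...}

    # the metric unwraps the dict with model_key, then runs the plain TopkAcc
    metric = DistillationTopkAcc(model_key="Student", topk=[1, 5])
    label = paddle.randint(0, 1000, shape=[4, 1])
    print(metric(out, label))  # e.g. {"top1": ..., "top5": ...}

Note that DistillationCELoss, DistillationGTCELoss and DistillationDMLLoss are imported from a distillationloss.py module that this patch does not include, so their constructor signatures are not shown here; the Loss section of the YAML config above is the only ground truth in this patch for their config keys.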