未验证 提交 3b4f5f4d 编写于 作者: L littletomatodonkey 提交者: GitHub

add distillation and fix some apis (#810)

* fix save load and imagenet dataset
* refine trainer
上级 b9786424
......@@ -21,8 +21,9 @@ from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
from ppcls.utils.save_load import load_dygraph_pretrain
__all__ = ["build_model", "RecModel"]
__all__ = ["build_model", "RecModel", "DistillationModel"]
def build_model(config):
......@@ -62,3 +63,48 @@ class RecModel(nn.Layer):
else:
y = None
return {"features": x, "logits": y}
class DistillationModel(nn.Layer):
def __init__(self,
models=None,
pretrained_list=None,
freeze_params_list=None):
super().__init__()
assert isinstance(models, list)
self.model_list = []
self.model_name_list = []
if pretrained_list is not None:
assert len(pretrained_list) == len(models)
if freeze_params_list is None:
freeze_params_list = [False] * len(models)
assert len(freeze_params_list) == len(models)
for idx, model_config in enumerate(models):
assert len(model_config) == 1
key = list(model_config.keys())[0]
model_config = model_config[key]
print(model_config)
model_name = model_config.pop("name")
model = eval(model_name)(**model_config)
if freeze_params_list[idx]:
for param in model.parameters():
param.trainable = False
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
if pretrained_list is not None:
for idx, pretrained in enumerate(pretrained_list):
if pretrained is not None:
load_dygraph_pretrain(
self.model_name_list[idx], path=pretrained)
def forward(self, x, label=None):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
if label is None:
result_dict[model_name] = self.model_list[idx](x)
else:
result_dict[model_name] = self.model_list[idx](x)
return result_dict
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import copy
import sys
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# TODO: fix the format
class CELoss(nn.Layer):
"""
"""
def __init__(self, name="loss", epsilon=None):
super().__init__()
self.name = name
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, logits, label, mode="train"):
loss_dict = {}
if self.epsilon is not None:
class_num = logits.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(logits, axis=-1)
loss = paddle.sum(logits * label, axis=-1)
else:
if label.shape[-1] == logits.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(logits, label=label, soft_label=soft_label)
loss_dict[self.name] = paddle.mean(loss)
return loss_dict
# TODO: fix the format
class Topk(nn.Layer):
def __init__(self, topk=[1, 5]):
super().__init__()
assert isinstance(topk, (int, list))
if isinstance(topk, int):
topk = [topk]
self.topk = topk
def forward(self, x, label):
if isinstance(x, dict):
x = x["logits"]
metric_dict = dict()
for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
return metric_dict
# TODO: fix the format
def build_loss(config):
loss_func = CELoss()
return loss_func
# TODO: fix the format
def build_metrics(config):
metrics_func = Topk()
return metrics_func
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
class_num: 1000
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 120
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "DistillationModel"
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
models:
- Teacher:
name: MobileNetV3_large_x1_0
pretrained: True
use_ssld: True
- Student:
name: MobileNetV3_small_x1_0
pretrained: False
# loss function config for traing/eval process
Loss:
Train:
- DistillationCELoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 1.3
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00001
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/train_list.txt"
transform_ops:
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- AutoAugment:
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: True
loader:
num_workers: 6
use_shared_memory: True
Eval:
# TOTO: modify to the latest trainer
dataset:
name: ImageNetDataset
image_root: "./dataset/ILSVRC2012/"
cls_label_path: "./dataset/ILSVRC2012/val_list.txt"
transform_ops:
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: True
Infer:
infer_imgs: "docs/images/whl/demo.jpg"
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
......@@ -31,8 +31,6 @@ class ImageNetDataset(CommonDataset):
lines = fd.readlines()
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
else:
np.random.shuffle(lines)
for l in lines:
l = l.strip().split(" ")
self.images.append(os.path.join(self._img_root, l[0]))
......
......@@ -235,6 +235,8 @@ class Trainer(object):
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="best_model")
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, acc))
self.model.train()
# save model
......@@ -245,14 +247,21 @@ class Trainer(object):
"epoch": epoch_id},
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="ppcls_epoch_{}".format(epoch_id))
prefix="epoch_{}".format(epoch_id))
# save the latest model
save_load.save_model(
self.model,
optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
model_name=self.config["Arch"]["name"],
prefix="latest")
def build_avg_metrics(self, info_dict):
return {key: AverageMeter(key, '7.5f') for key in info_dict}
@paddle.no_grad()
def eval(self, epoch_id=0):
self.model.eval()
if self.eval_loss_func is None:
loss_config = self.config.get("Loss", None)
......
......@@ -13,7 +13,12 @@ from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
class CombinedLoss(nn.Layer):
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,113 +13,39 @@
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['CELoss', 'JSDivLoss', 'KLDivLoss']
class CELoss(nn.Layer):
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
class Loss(object):
"""
Loss
"""
def __init__(self, class_dim=1000, epsilon=None):
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
self._class_dim = class_dim
if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
self._epsilon = epsilon
self._label_smoothing = True #use label smoothing.(Actually, it is softmax label)
else:
self._epsilon = None
self._label_smoothing = False
#do label_smoothing
def _labelsmoothing(self, target):
if target.shape[-1] != self._class_dim:
one_hot_target = F.one_hot(
target,
self._class_dim) #do ont hot(23,34,46)-> 3 * _class_dim
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
#do label_smooth
soft_target = F.label_smooth(
one_hot_target,
epsilon=self._epsilon) #(1 - epsilon) * input + eposilon / K.
soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def _crossentropy(self, input, target, use_pure_fp16=False):
if self._label_smoothing:
target = self._labelsmoothing(target)
input = -F.log_softmax(input, axis=-1) #softmax and do log
cost = paddle.sum(target * input, axis=-1) #sum
else:
cost = F.cross_entropy(input=input, label=target)
if use_pure_fp16:
avg_cost = paddle.sum(cost)
else:
avg_cost = paddle.mean(cost)
return avg_cost
def _kldiv(self, input, target, name=None):
eps = 1.0e-10
cost = target * paddle.log(
(target + eps) / (input + eps)) * self._class_dim
return cost
def _jsdiv(self, input,
target): #so the input and target is the fc output; no softmax
input = F.softmax(input)
target = F.softmax(target)
#two distribution
cost = self._kldiv(input, target) + self._kldiv(target, input)
cost = cost / 2
avg_cost = paddle.mean(cost)
return avg_cost
def __call__(self, input, target):
pass
class CELoss(Loss):
"""
Cross entropy loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(CELoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target, use_pure_fp16=False):
if type(input) is dict:
logits = input["logits"]
def forward(self, x, label):
if isinstance(x, dict):
x = x["logits"]
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
logits = input
cost = self._crossentropy(logits, target, use_pure_fp16)
return {"CELoss": cost}
class JSDivLoss(Loss):
"""
JSDiv loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(JSDivLoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target):
cost = self._jsdiv(input, target)
return cost
class KLDivLoss(paddle.nn.Layer):
def __init__(self):
super(KLDivLoss, self).__init__()
def __call__(self, p, q, is_logit=True):
if is_logit:
p = paddle.nn.functional.softmax(p)
q = paddle.nn.functional.softmax(q)
return -(p * paddle.log(q + 1e-8)).sum(1).mean()
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return {"CELoss": loss}
......@@ -17,6 +17,8 @@ import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk, RetriMetric
from .metrics import DistillationTopkAcc
class CombinedMetrics(nn.Layer):
def __init__(self, config_list):
......@@ -24,7 +26,7 @@ class CombinedMetrics(nn.Layer):
self.metric_func_list = []
assert isinstance(config_list, list), (
'operator config should be a list')
self.retri_config = dict() # retrieval metrics config
for config in config_list:
assert isinstance(config,
......@@ -35,7 +37,7 @@ class CombinedMetrics(nn.Layer):
continue
metric_params = config[metric_name]
self.metric_func_list.append(eval(metric_name)(**metric_params))
if self.retri_config:
self.metric_func_list.append(RetriMetric(self.retri_config))
......
......@@ -18,7 +18,6 @@ import paddle.nn as nn
from functools import lru_cache
# TODO: fix the format
class TopkAcc(nn.Layer):
def __init__(self, topk=(1, 5)):
super().__init__()
......@@ -84,6 +83,7 @@ class Recallk(nn.Layer):
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
return metric_dict
# retrieval metrics
class RetriMetric(nn.Layer):
def __init__(self, config):
......@@ -93,8 +93,8 @@ class RetriMetric(nn.Layer):
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
metric_dict = dict()
all_cmc, all_AP, all_INP = get_metrics(similarities_matrix, query_img_id,
gallery_img_id, self.max_rank)
all_cmc, all_AP, all_INP = get_metrics(
similarities_matrix, query_img_id, gallery_img_id, self.max_rank)
if "Recallk" in self.config.keys():
topk = self.config['Recallk']['topk']
assert isinstance(topk, (int, list, tuple))
......@@ -109,7 +109,7 @@ class RetriMetric(nn.Layer):
mINP = np.mean(all_INP)
metric_dict["mINP"] = mINP
return metric_dict
@lru_cache()
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
......@@ -155,3 +155,16 @@ def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
all_cmc = all_cmc.sum(0) / num_valid_q
return all_cmc, all_AP, all_INP
class DistillationTopkAcc(TopkAcc):
def __init__(self, model_key, feature_key=None, topk=(1, 5)):
super().__init__(topk=topk)
self.model_key = model_key
self.feature_key = feature_key
def forward(self, x, label):
x = x[self.model_key]
if self.feature_key is not None:
x = x[self.feature_key]
return super().forward(x, label)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import os.path as osp
import shutil
import requests
import hashlib
import tarfile
import zipfile
import time
from collections import OrderedDict
from tqdm import tqdm
from ppcls.utils import logger
__all__ = ['get_weights_path_from_url']
WEIGHTS_HOME = osp.expanduser("~/.paddleclas/weights")
DOWNLOAD_RETRY_LIMIT = 3
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def get_weights_path_from_url(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded weights.
Examples:
.. code-block:: python
from paddle.utils.download import get_weights_path_from_url
resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
"""
path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
return path
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _download(url, path, md5sum=None):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logger.info("Downloading {} from {}".format(fname, url))
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info(
"Downloading {} from {} failed {} times with exception {}".
format(fname, url, retry_cnt + 1, str(e)))
time.sleep(1)
continue
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
......@@ -23,10 +23,8 @@ import shutil
import tempfile
import paddle
from paddle.static import load_program_state
from paddle.utils.download import get_weights_path_from_url
from ppcls.utils import logger
from .download import get_weights_path_from_url
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
......@@ -47,70 +45,42 @@ def _mkdir_if_not_exist(path):
raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
if load_static_weights:
pre_state_dict = load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys():
logger.info('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
return
param_state_dict = paddle.load(path + ".pdparams")
model.set_dict(param_state_dict)
return
def load_dygraph_pretrain_from_url(model,
pretrained_url,
use_ssld,
load_static_weights=False):
def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld):
if use_ssld:
pretrained_url = pretrained_url.replace("_pretrained",
"_ssld_pretrained")
local_weight_path = get_weights_path_from_url(pretrained_url).replace(
".pdparams", "")
load_dygraph_pretrain(
model, path=local_weight_path, load_static_weights=load_static_weights)
load_dygraph_pretrain(model, path=local_weight_path)
return
def load_distillation_model(model, pretrained_model, load_static_weights):
def load_distillation_model(model, pretrained_model):
logger.info("In distillation mode, teacher model will be "
"loaded firstly before student model.")
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights] * len(pretrained_model)
teacher = model.teacher if hasattr(model,
"teacher") else model._layers.teacher
student = model.student if hasattr(model,
"student") else model._layers.student
load_dygraph_pretrain(
teacher,
path=pretrained_model[0],
load_static_weights=load_static_weights[0])
load_dygraph_pretrain(teacher, path=pretrained_model[0])
logger.info("Finish initing teacher model from {}".format(
pretrained_model))
# load student model
if len(pretrained_model) >= 2:
load_dygraph_pretrain(
student,
path=pretrained_model[1],
load_static_weights=load_static_weights[1])
load_dygraph_pretrain(student, path=pretrained_model[1])
logger.info("Finish initing student model from {}".format(
pretrained_model))
......@@ -134,16 +104,12 @@ def init_model(config, net, optimizer=None):
return metric_dict
pretrained_model = config.get('pretrained_model')
load_static_weights = config.get('load_static_weights', False)
use_distillation = config.get('use_distillation', False)
if pretrained_model:
if use_distillation:
load_distillation_model(net, pretrained_model, load_static_weights)
load_distillation_model(net, pretrained_model)
else: # common load
load_dygraph_pretrain(
net,
path=pretrained_model,
load_static_weights=load_static_weights)
load_dygraph_pretrain(net, path=pretrained_model)
logger.info(
logger.coloring("Finish load pretrained model from {}".format(
pretrained_model), "HEADER"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册