Commit fbb36fd3 authored by littletomatodonkey

add ssld code

Parent 2ee646eb
......@@ -42,3 +42,6 @@ from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_1
from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
# distillation model
from .distillation_models import ResNet50_vd_distill_MobileNetV3_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
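The two distillation architectures exported here wrap a teacher and a student network and return both predictions, which the JSDivLoss and trainer changes below rely on. A minimal sketch of building one of them in the fluid static graph, assuming the same net(input, class_dim=...) convention as the other architectures in this package (the exact constructor and return ordering are not shown in this diff):

import paddle.fluid as fluid
from ppcls.modeling import architectures

# assumed interface: same .net() convention as the other PaddleClas architectures
image = fluid.data(name="image", shape=[None, 3, 224, 224], dtype="float32")
model = architectures.ResNet50_vd_distill_MobileNetV3_x1_0()
# the trainer below asserts that distillation models return exactly two outputs
teacher_out, student_out = model.net(input=image, class_dim=1000)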
......@@ -15,7 +15,7 @@
import paddle
import paddle.fluid as fluid
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss']
__all__ = ['CELoss', 'MixCELoss', 'GoogLeNetLoss', 'JSDivLoss']
class Loss(object):
......@@ -34,8 +34,11 @@ class Loss(object):
self._label_smoothing = False
def _labelsmoothing(self, target):
one_hot_target = fluid.layers.one_hot(
input=target, depth=self._class_dim)
if target.shape[-1] != self._class_dim:
one_hot_target = fluid.layers.one_hot(
input=target, depth=self._class_dim)
else:
one_hot_target = target
soft_target = fluid.layers.label_smooth(
label=one_hot_target, epsilon=self._epsilon, dtype="float32")
return soft_target
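This guard lets _labelsmoothing accept targets that are already one-hot or soft (last dimension equal to class_dim) as well as the usual hard integer labels, which previously always went through one_hot. A small illustration of the two label shapes the check distinguishes (names and the 1000-class width are illustrative):

import paddle.fluid as fluid

# hard labels: shape [N, 1], last dim != class_dim, so one_hot() is applied first
hard_label = fluid.data(name="label", shape=[None, 1], dtype="int64")

# one-hot / soft targets: shape [N, class_dim], passed straight to label_smooth()
soft_target = fluid.data(name="soft_target", shape=[None, 1000], dtype="float32")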
......@@ -49,6 +52,19 @@ class Loss(object):
avg_cost = fluid.layers.mean(cost)
return avg_cost
def _kldiv(self, input, target):
cost = target * fluid.layers.log(target / input) * self._class_dim
cost = fluid.layers.sum(cost)
return cost
def _jsdiv(self, input, target):
input = fluid.layers.softmax(input, use_cudnn=False)
target = fluid.layers.softmax(target, use_cudnn=False)
cost = self._kldiv(input, target) + self._kldiv(target, input)
cost = cost / 2
avg_cost = fluid.layers.mean(cost)
return avg_cost
def __call__(self, input, target):
pass
......@@ -97,3 +113,16 @@ class GoogLeNetLoss(Loss):
cost = cost0 + 0.3 * cost1 + 0.3 * cost2
avg_cost = fluid.layers.mean(cost)
return avg_cost
class JSDivLoss(Loss):
"""
JSDiv loss
"""
def __init__(self, class_dim=1000, epsilon=None):
super(JSDivLoss, self).__init__(class_dim, epsilon)
def __call__(self, input, target):
cost = self._jsdiv(input, target)
return cost
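JSDivLoss applies softmax to both arguments inside _jsdiv and then symmetrizes _kldiv, so it can be fed raw logits from the two heads of a distillation model, and the result does not depend on argument order. A minimal usage sketch under the fluid static graph (tensor names and the 1000-class shape are illustrative, not taken from the repo):

import paddle.fluid as fluid
from ppcls.modeling.loss import JSDivLoss

# illustrative placeholders for the teacher and student logits
teacher_logits = fluid.data(name="teacher_logits", shape=[None, 1000], dtype="float32")
student_logits = fluid.data(name="student_logits", shape=[None, 1000], dtype="float32")

loss_fn = JSDivLoss(class_dim=1000)
# softmax is applied internally, so raw logits are passed in
avg_cost = loss_fn(student_logits, teacher_logits)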
......@@ -14,6 +14,7 @@
import os
import logging
logging.basicConfig()
import random
DEBUG = logging.DEBUG #10
......
......@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument("-m", "--model", type=str)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("-o", "--output_path", type=str)
parser.add_argument("--class_dim", type=int)
return parser.parse_args()
......@@ -57,7 +58,7 @@ def main():
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
image = create_input()
out = create_model(args, model, image)
out = create_model(args, model, image, class_dim=args.class_dim)
infer_prog = infer_prog.clone(for_test=True)
fluid.load(
......
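The new --class_dim argument added above is now forwarded to create_model instead of relying on the model's default class dimension. A hypothetical invocation of the export script (the script path and weight/output paths are assumptions; only the flags and model name come from this commit):

python export_model.py \
    -m ResNet50_vd_distill_MobileNetV3_x1_0 \
    -p ./path/to/pretrained_weights \
    -o ./inference_output \
    --class_dim 1000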
......@@ -31,6 +31,7 @@ from ppcls.optimizer import OptimizerBuilder
from ppcls.modeling import architectures
from ppcls.modeling.loss import CELoss
from ppcls.modeling.loss import MixCELoss
from ppcls.modeling.loss import JSDivLoss
from ppcls.modeling.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
......@@ -39,13 +40,13 @@ from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
def create_feeds(image_shape, mix=None):
def create_feeds(image_shape, use_mix=None):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
feeds(dict): dict of model input variables
......@@ -53,7 +54,7 @@ def create_feeds(image_shape, mix=None):
feeds = OrderedDict()
feeds['image'] = fluid.data(
name="feed_image", shape=[None] + image_shape, dtype="float32")
if mix:
if use_mix:
feeds['feed_y_a'] = fluid.data(
name="feed_y_a", shape=[None, 1], dtype="int64")
feeds['feed_y_b'] = fluid.data(
......@@ -112,7 +113,8 @@ def create_loss(out,
architecture,
classes_num=1000,
epsilon=None,
mix=False):
use_mix=False,
use_distillation=False):
"""
Create a loss for optimization, such as:
1. CrossEntropy loss
......@@ -127,7 +129,7 @@ def create_loss(out,
architecture(dict): architecture information, name(such as ResNet50) is needed
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
loss(variable): loss variable
......@@ -138,7 +140,14 @@ def create_loss(out,
target = feeds['label']
return loss(out[0], out[1], out[2], target)
if mix:
if use_distillation:
assert len(
out) == 2, "distillation output length must be 2 but got {}".format(
len(out))
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out[1], out[0])
if use_mix:
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
feed_y_a = feeds['feed_y_a']
feed_y_b = feeds['feed_y_b']
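The new use_distillation branch expects the model to return exactly two outputs; together with the metric change below, which keeps out[1] as the student prediction, the call loss(out[1], out[0]) treats out[0] as the teacher's soft target. A condensed, illustrative sketch of the dispatch after this change (assuming the [teacher_out, student_out] ordering; the mix branch is omitted here because its arguments are collapsed in the diff):

def dispatch_loss(out, feeds, classes_num=1000, epsilon=None, use_distillation=False):
    # illustrative only, not a verbatim copy of create_loss
    if use_distillation:
        assert len(out) == 2, "distillation output length must be 2 but got {}".format(len(out))
        loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
        return loss(out[1], out[0])  # student distribution vs. teacher soft targets
    loss = CELoss(class_dim=classes_num, epsilon=epsilon)
    return loss(out, feeds['label'])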
......@@ -150,7 +159,8 @@ def create_loss(out,
return loss(out, target)
def create_metric(out, feeds, topk=5, classes_num=1000):
def create_metric(out, feeds, topk=5, classes_num=1000,
use_distillation=False):
"""
Create measures of model accuracy, such as top1 and top5
......@@ -163,6 +173,9 @@ def create_metric(out, feeds, topk=5, classes_num=1000):
Returns:
fetchs(dict): dict of measures
"""
# just need student label to get metrics
if use_distillation:
out = out[1]
fetchs = OrderedDict()
label = feeds['label']
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
......@@ -182,10 +195,11 @@ def create_fetchs(out,
topk=5,
classes_num=1000,
epsilon=None,
mix=False):
use_mix=False,
use_distillation=False):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if mix).
will call create_loss and create_metric(if use_mix).
Args:
out(variable): model output variable
......@@ -194,16 +208,17 @@ def create_fetchs(out,
topk(int): usually top5
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
mix(bool): whether to use mix(include mixup, cutmix, fmix)
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
fetchs = OrderedDict()
loss = create_loss(out, feeds, architecture, classes_num, epsilon, mix)
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', ':2.4f', True))
if not mix:
metric = create_metric(out, feeds, topk, classes_num)
if not use_mix:
metric = create_metric(out, feeds, topk, classes_num, use_distillation)
fetchs.update(metric)
return fetchs
......@@ -293,7 +308,8 @@ def build(config, main_prog, startup_prog, is_train=True):
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
use_mix = config.get('use_mix') and is_train
feeds = create_feeds(config.image_shape, mix=use_mix)
use_distillation = config.get('use_distillation')
feeds = create_feeds(config.image_shape, use_mix=use_mix)
dataloader = create_dataloader(feeds.values())
out = create_model(config.ARCHITECTURE, feeds['image'],
config.classes_num)
......@@ -304,7 +320,8 @@ def build(config, main_prog, startup_prog, is_train=True):
config.topk,
config.classes_num,
epsilon=config.get('ls_epsilon'),
mix=use_mix)
use_mix=use_mix,
use_distillation=use_distillation)
if is_train:
optimizer = create_optimizer(config)
lr = optimizer._global_learning_rate()
......
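build() now also reads a use_distillation switch from the config (next to use_mix) and threads it through create_fetchs, so both the loss and the metrics know whether the output is a two-headed distillation result. Sketched as a Python dict for illustration, the configuration fields this function touches would look roughly like the following; the repo's actual configs are YAML files not shown in this diff, and the key spellings simply mirror the config accesses above:

# illustrative config snapshot; keys mirror the accesses in build()
config = {
    "ARCHITECTURE": {"name": "ResNet50_vd_distill_MobileNetV3_x1_0"},
    "image_shape": [3, 224, 224],
    "classes_num": 1000,
    "topk": 5,
    "ls_epsilon": 0.1,          # optional label-smoothing epsilon
    "use_mix": False,           # mixup / cutmix / fmix targets
    "use_distillation": True,   # route the two-headed output through JSDivLoss
}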