diff --git a/loss/__init__.py b/loss/__init__.py index 6eaa4c766db1ce4f826758b838da244d7c21e73d..43399e76166677b916712df59ed74c7eca535d8a 100644 --- a/loss/__init__.py +++ b/loss/__init__.py @@ -1 +1 @@ -from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy \ No newline at end of file +from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy \ No newline at end of file diff --git a/loss/cross_entropy.py b/loss/cross_entropy.py index 821b1fe312a3feabf2a67d71f7beac3bc78a7994..60bef646cc6c31fd734f234346dbc4255def6622 100644 --- a/loss/cross_entropy.py +++ b/loss/cross_entropy.py @@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module): return loss.mean() -class SparseLabelCrossEntropy(nn.Module): +class SoftTargetCrossEntropy(nn.Module): def __init__(self): - super(SparseLabelCrossEntropy, self).__init__() + super(SoftTargetCrossEntropy, self).__init__() def forward(self, x, target): loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) diff --git a/train.py b/train.py index 38d8699225c6222bc796567959f8479b5bbe8236..9a81eecb62cbbc0aefd9ecfed9d186132357720d 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,7 @@ except ImportError: from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target from models import create_model, resume_checkpoint from utils import * -from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy +from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from optim import create_optimizer from scheduler import create_scheduler @@ -261,7 +261,7 @@ def main(): if args.mixup > 0.: # smoothing is handled with mixup label transform - train_loss_fn = SparseLabelCrossEntropy().cuda() + train_loss_fn = SoftTargetCrossEntropy().cuda() validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()