Commit e6c14427 authored by R Ross Wightman

More appropriate/correct loss name

Parent commit 99ab1b12
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
\ No newline at end of file
from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
\ No newline at end of file
......@@ -26,10 +26,10 @@ class LabelSmoothingCrossEntropy(nn.Module):
return loss.mean()
class SparseLabelCrossEntropy(nn.Module):
class SoftTargetCrossEntropy(nn.Module):
def __init__(self):
super(SparseLabelCrossEntropy, self).__init__()
super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target):
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
......
......@@ -13,7 +13,7 @@ except ImportError:
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer
from scheduler import create_scheduler
......@@ -261,7 +261,7 @@ def main():
if args.mixup > 0.:
# smoothing is handled with mixup label transform
train_loss_fn = SparseLabelCrossEntropy().cuda()
train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment