From 9cdbdca4ee9edfc39e178fcdb38c20c4a9361c9c Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 28 Jun 2022 15:58:07 +0800
Subject: [PATCH] add support for unlabel training (#2103)

---
 ppcls/loss/distillationloss.py |  7 ++++++-
 ppcls/loss/dkdloss.py          | 11 +++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index 4ca58d26..4f72777f 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -236,8 +236,13 @@ class DistillationDKDLoss(DKDLoss):
                  temperature=1.0,
                  alpha=1.0,
                  beta=1.0,
+                 use_target_as_gt=False,
                  name="loss_dkd"):
-        super().__init__(temperature=temperature, alpha=alpha, beta=beta)
+        super().__init__(
+            temperature=temperature,
+            alpha=alpha,
+            beta=beta,
+            use_target_as_gt=use_target_as_gt)
         self.key = key
         self.model_name_pairs = model_name_pairs
         self.name = name
diff --git a/ppcls/loss/dkdloss.py b/ppcls/loss/dkdloss.py
index 9ce2c56d..bf9224e3 100644
--- a/ppcls/loss/dkdloss.py
+++ b/ppcls/loss/dkdloss.py
@@ -10,13 +10,20 @@ class DKDLoss(nn.Layer):
     Code was heavily based on https://github.com/megvii-research/mdistiller
     """
 
-    def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
+    def __init__(self,
+                 temperature=1.0,
+                 alpha=1.0,
+                 beta=1.0,
+                 use_target_as_gt=False):
         super().__init__()
         self.temperature = temperature
         self.alpha = alpha
         self.beta = beta
+        self.use_target_as_gt = use_target_as_gt
 
-    def forward(self, logits_student, logits_teacher, target):
+    def forward(self, logits_student, logits_teacher, target=None):
+        if target is None or self.use_target_as_gt:
+            target = logits_teacher.argmax(axis=-1)
         gt_mask = _get_gt_mask(logits_student, target)
         other_mask = 1 - gt_mask
         pred_student = F.softmax(logits_student / self.temperature, axis=1)
-- 
GitLab
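
Note (not part of the patch): a minimal usage sketch of the new behavior. When `target` is omitted or `use_target_as_gt=True`, the patched `DKDLoss.forward` substitutes the teacher's argmax as a pseudo ground-truth label, so distillation can run on unlabeled batches. The batch size, class count, and hyperparameter values below are illustrative assumptions, not values prescribed by this PR.

import paddle
from ppcls.loss.dkdloss import DKDLoss

# Illustrative hyperparameters; this PR does not prescribe specific values.
loss_fn = DKDLoss(temperature=1.0, alpha=1.0, beta=1.0, use_target_as_gt=True)

# An unlabeled batch: only student and teacher logits are available.
logits_student = paddle.randn([8, 1000])
logits_teacher = paddle.randn([8, 1000])

# No `target` passed: forward() falls back to
# logits_teacher.argmax(axis=-1) as the pseudo ground truth.
loss = loss_fn(logits_student, logits_teacher)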