未验证 提交 9cdbdca4 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add support for unlabel training (#2103)

上级 5cc6c50c
......@@ -236,8 +236,13 @@ class DistillationDKDLoss(DKDLoss):
temperature=1.0,
alpha=1.0,
beta=1.0,
use_target_as_gt=False,
name="loss_dkd"):
super().__init__(temperature=temperature, alpha=alpha, beta=beta)
super().__init__(
temperature=temperature,
alpha=alpha,
beta=beta,
use_target_as_gt=use_target_as_gt)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
......
......@@ -10,13 +10,20 @@ class DKDLoss(nn.Layer):
Code was heavily based on https://github.com/megvii-research/mdistiller
"""
def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
def __init__(self,
temperature=1.0,
alpha=1.0,
beta=1.0,
use_target_as_gt=False):
    """Configure the Decoupled Knowledge Distillation (DKD) loss.

    Args:
        temperature (float): softmax temperature applied to the student
            (and presumably teacher) logits before the KD terms are
            computed (see the ``logits_student / self.temperature`` use
            in ``forward``). Default 1.0.
        alpha (float): weight of one of the two decoupled loss terms —
            per the upstream mdistiller DKD formulation this is the
            target-class (TCKD) weight; confirm against the full
            ``forward`` body. Default 1.0.
        beta (float): weight of the other decoupled term (non-target
            class, NCKD, in the upstream formulation — confirm in
            ``forward``). Default 1.0.
        use_target_as_gt (bool): when True, ``forward`` ignores any
            provided ``target`` and uses the teacher's argmax prediction
            as the ground-truth label instead. Added to support training
            on unlabeled data (targets may then be omitted entirely).
            Default False, preserving the original labeled behavior.
    """
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta
# NOTE: with use_target_as_gt=True (or target=None at call time),
# the teacher model's prediction stands in for the ground-truth label.
self.use_target_as_gt = use_target_as_gt
def forward(self, logits_student, logits_teacher, target):
def forward(self, logits_student, logits_teacher, target=None):
if target is None or self.use_target_as_gt:
target = logits_teacher.argmax(axis=-1)
gt_mask = _get_gt_mask(logits_student, target)
other_mask = 1 - gt_mask
pred_student = F.softmax(logits_student / self.temperature, axis=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册