diff --git a/ppdet/engine/trainer_ssod.py b/ppdet/engine/trainer_ssod.py
index 90b8a9f7f8089f7102ab18162079004187aa8f14..ef2409b09b4e599200abf3e670577926bd07792d 100644
--- a/ppdet/engine/trainer_ssod.py
+++ b/ppdet/engine/trainer_ssod.py
@@ -31,7 +31,7 @@ from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
 import ppdet.utils.stats as stats
 from ppdet.utils import profiler
-from ppdet.modeling.ssod_utils import align_weak_strong_shape
+from ppdet.modeling.ssod.utils import align_weak_strong_shape
 from .trainer import Trainer
 from ppdet.utils.logger import setup_logger
 
@@ -317,10 +317,10 @@ class Trainer_DenseTeacher(Trainer):
                 train_cfg['curr_iter'] = curr_iter
                 train_cfg['st_iter'] = st_iter
                 if self._nranks > 1:
-                    loss_dict_unsup = self.model._layers.get_ssod_distill_loss(
+                    loss_dict_unsup = self.model._layers.get_ssod_loss(
                         student_preds, teacher_preds, train_cfg)
                 else:
-                    loss_dict_unsup = self.model.get_ssod_distill_loss(
+                    loss_dict_unsup = self.model.get_ssod_loss(
                         student_preds, teacher_preds, train_cfg)
 
                 fg_num = loss_dict_unsup["fg_sum"]
diff --git a/ppdet/modeling/__init__.py b/ppdet/modeling/__init__.py
index 601b14f0d4f72fe75c1bb47ae774f643ccec496c..fc7caf4403318f0ff37fecc1a4a032c468009fb0 100644
--- a/ppdet/modeling/__init__.py
+++ b/ppdet/modeling/__init__.py
@@ -30,7 +30,7 @@ from . import mot
 from . import transformers
 from . import assigners
 from . import rbox_utils
-from . import ssod_utils
+from . import ssod
 
 from .ops import *
 from .backbones import *
@@ -46,4 +46,4 @@ from .mot import *
 from .transformers import *
 from .assigners import *
 from .rbox_utils import *
-from .ssod_utils import *
+from .ssod import *
diff --git a/ppdet/modeling/architectures/fcos.py b/ppdet/modeling/architectures/fcos.py
index 4a892c836f5322e4972aab8fdee65a81ed37624e..efebb6efb8a20558540772f5c31994e15ff8d09c 100644
--- a/ppdet/modeling/architectures/fcos.py
+++ b/ppdet/modeling/architectures/fcos.py
@@ -16,12 +16,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import paddle
-import paddle.nn.functional as F
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
-from ..ssod_utils import permute_to_N_HWA_K, QFLv2
-from ..losses import GIoULoss
 
 __all__ = ['FCOS']
 
@@ -35,16 +31,25 @@ class FCOS(BaseArch):
         backbone (object): backbone instance
         neck (object): 'FPN' instance
         fcos_head (object): 'FCOSHead' instance
+        ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det (SSOD)
     """
 
     __category__ = 'architecture'
+    __inject__ = ['ssod_loss']
 
-    def __init__(self, backbone, neck='FPN', fcos_head='FCOSHead'):
+    def __init__(self,
+                 backbone='ResNet',
+                 neck='FPN',
+                 fcos_head='FCOSHead',
+                 ssod_loss='SSODFCOSLoss'):
         super(FCOS, self).__init__()
         self.backbone = backbone
         self.neck = neck
         self.fcos_head = fcos_head
+
+        # for ssod, semi-det
         self.is_teacher = False
+        self.ssod_loss = ssod_loss
 
     @classmethod
     def from_config(cls, cfg, *args, **kwargs):
@@ -85,90 +90,7 @@ class FCOS(BaseArch):
     def get_loss_keys(self):
         return ['loss_cls', 'loss_box', 'loss_quality']
 
-    def get_ssod_distill_loss(self, student_head_outs, teacher_head_outs,
-                              train_cfg):
-        student_logits, student_deltas, student_quality = student_head_outs
-        teacher_logits, teacher_deltas, teacher_quality = teacher_head_outs
-        nc = student_logits[0].shape[1]
-
-        student_logits = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, nc])
-                for _ in student_logits
-            ],
-            axis=0)
-        teacher_logits = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, nc])
-                for _ in teacher_logits
-            ],
-            axis=0)
-
-        student_deltas = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, 4])
-                for _ in student_deltas
-            ],
-            axis=0)
-        teacher_deltas = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, 4])
-                for _ in teacher_deltas
-            ],
-            axis=0)
-
-        student_quality = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, 1])
-                for _ in student_quality
-            ],
-            axis=0)
-        teacher_quality = paddle.concat(
-            [
-                _.transpose([0, 2, 3, 1]).reshape([-1, 1])
-                for _ in teacher_quality
-            ],
-            axis=0)
-
-        ratio = train_cfg.get('ratio', 0.01)
-        with paddle.no_grad():
-            # Region Selection
-            count_num = int(teacher_logits.shape[0] * ratio)
-            teacher_probs = F.sigmoid(teacher_logits)
-            max_vals = paddle.max(teacher_probs, 1)
-            sorted_vals, sorted_inds = paddle.topk(max_vals,
-                                                   teacher_logits.shape[0])
-            mask = paddle.zeros_like(max_vals)
-            mask[sorted_inds[:count_num]] = 1.
-            fg_num = sorted_vals[:count_num].sum()
-            b_mask = mask > 0
-
-        # distill_loss_cls
-        loss_logits = QFLv2(
-            F.sigmoid(student_logits),
-            teacher_probs,
-            weight=mask,
-            reduction="sum") / fg_num
-
-        # distill_loss_box
-        inputs = paddle.concat(
-            (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
-            axis=-1)
-        targets = paddle.concat(
-            (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
-            axis=-1)
-        iou_loss = GIoULoss(reduction='mean')
-        loss_deltas = iou_loss(inputs, targets)
-
-        # distill_loss_quality
-        loss_quality = F.binary_cross_entropy(
-            F.sigmoid(student_quality[b_mask]),
-            F.sigmoid(teacher_quality[b_mask]),
-            reduction='mean')
-
-        return {
-            "distill_loss_cls": loss_logits,
-            "distill_loss_box": loss_deltas,
-            "distill_loss_quality": loss_quality,
-            "fg_sum": fg_num,
-        }
+    def get_ssod_loss(self, student_head_outs, teacher_head_outs, train_cfg):
+        ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs,
+                                     train_cfg)
+        return ssod_losses
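Note: after this change FCOS no longer computes the unsupervised distillation terms itself; `get_ssod_loss` simply forwards the student/teacher head outputs to the `ssod_loss` layer resolved through `__inject__`. A minimal standalone sketch of that delegation pattern, using toy stand-in classes rather than the real ppdet `create`/config machinery:

```python
import paddle.nn as nn


class ToySSODLoss(nn.Layer):
    """Stand-in for SSODFCOSLoss / SSODPPYOLOELoss: maps head outputs to a loss dict."""

    def forward(self, student_head_outs, teacher_head_outs, train_cfg):
        # a real loss returns tensors such as distill_loss_cls / distill_loss_box / fg_sum
        return {"distill_loss_demo": 0.0, "fg_sum": 0.0}


class ToyArch(nn.Layer):
    """Mimics the refactored FCOS/PPYOLOE: the SSOD loss is injected, not hard-coded."""

    def __init__(self, ssod_loss):
        super(ToyArch, self).__init__()
        self.ssod_loss = ssod_loss

    def get_ssod_loss(self, student_head_outs, teacher_head_outs, train_cfg):
        # pure delegation, mirroring the new fcos.py / ppyoloe.py methods
        return self.ssod_loss(student_head_outs, teacher_head_outs, train_cfg)


arch = ToyArch(ssod_loss=ToySSODLoss())
print(arch.get_ssod_loss(None, None, {"ratio": 0.01}))
```

The benefit is that PPYOLOE below can reuse the same hook with a different injected loss, and new semi-supervised losses can be registered without touching the architecture files.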
+ """ + __category__ = 'architecture' __shared__ = ['for_distill'] - __inject__ = ['post_process'] + __inject__ = ['post_process', 'ssod_loss'] def __init__(self, backbone='CSPResNet', neck='CustomCSPPAN', yolo_head='PPYOLOEHead', post_process='BBoxPostProcess', + ssod_loss='SSODPPYOLOELoss', for_distill=False, feat_distill_place='neck_feats', for_mot=False): - """ - PPYOLOE network, see https://arxiv.org/abs/2203.16250 - - Args: - backbone (nn.Layer): backbone instance - neck (nn.Layer): neck instance - yolo_head (nn.Layer): anchor_head instance - post_process (object): `BBoxPostProcess` instance - for_mot (bool): whether return other features for multi-object tracking - models, default False in pure object detection models. - """ super(PPYOLOE, self).__init__() self.backbone = backbone self.neck = neck @@ -62,8 +63,9 @@ class PPYOLOE(BaseArch): self.post_process = post_process self.for_mot = for_mot - # semi-det + # for ssod, semi-det self.is_teacher = False + self.ssod_loss = ssod_loss # distill self.for_distill = for_distill @@ -73,14 +75,11 @@ class PPYOLOE(BaseArch): @classmethod def from_config(cls, cfg, *args, **kwargs): - # backbone backbone = create(cfg['backbone']) - # fpn kwargs = {'input_shape': backbone.out_shape} neck = create(cfg['neck'], **kwargs) - # head kwargs = {'input_shape': neck.out_shape} yolo_head = create(cfg['yolo_head'], **kwargs) @@ -134,106 +133,10 @@ class PPYOLOE(BaseArch): def get_loss_keys(self): return ['loss_cls', 'loss_iou', 'loss_dfl', 'loss_contrast'] - def get_ssod_distill_loss(self, student_head_outs, teacher_head_outs, - train_cfg): - # for semi-det distill - # student_probs: already sigmoid - student_probs, student_deltas, student_dfl = student_head_outs - teacher_probs, teacher_deltas, teacher_dfl = teacher_head_outs - bs, l, nc = student_probs.shape[:] - student_probs = student_probs.reshape([-1, nc]) - teacher_probs = teacher_probs.reshape([-1, nc]) - student_deltas = student_deltas.reshape([-1, 4]) - teacher_deltas = teacher_deltas.reshape([-1, 4]) - student_dfl = student_dfl.reshape([-1, 4, self.yolo_head.reg_channels]) - teacher_dfl = teacher_dfl.reshape([-1, 4, self.yolo_head.reg_channels]) - - ratio = train_cfg.get('ratio', 0.01) - - # for contrast loss - curr_iter = train_cfg['curr_iter'] - st_iter = train_cfg['st_iter'] - if curr_iter == st_iter + 1: - # start semi-det training - self.queue_ptr = 0 - self.queue_size = int(bs * l * ratio) - self.queue_feats = paddle.zeros([self.queue_size, nc]) - self.queue_probs = paddle.zeros([self.queue_size, nc]) - contrast_loss_cfg = train_cfg['contrast_loss'] - temperature = contrast_loss_cfg.get('temperature', 0.2) - alpha = contrast_loss_cfg.get('alpha', 0.9) - smooth_iter = contrast_loss_cfg.get('smooth_iter', 100) + st_iter - - with paddle.no_grad(): - # Region Selection - count_num = int(teacher_probs.shape[0] * ratio) - max_vals = paddle.max(teacher_probs, 1) - sorted_vals, sorted_inds = paddle.topk(max_vals, - teacher_probs.shape[0]) - mask = paddle.zeros_like(max_vals) - mask[sorted_inds[:count_num]] = 1. - fg_num = sorted_vals[:count_num].sum() - b_mask = mask > 0. 
- - # for contrast loss - probs = teacher_probs[b_mask].detach() - if curr_iter > smooth_iter: # memory-smoothing - A = paddle.exp( - paddle.mm(teacher_probs[b_mask], self.queue_probs.t()) / - temperature) - A = A / A.sum(1, keepdim=True) - probs = alpha * probs + (1 - alpha) * paddle.mm( - A, self.queue_probs) - n = student_probs[b_mask].shape[0] - # update memory bank - self.queue_feats[self.queue_ptr:self.queue_ptr + - n, :] = teacher_probs[b_mask].detach() - self.queue_probs[self.queue_ptr:self.queue_ptr + - n, :] = teacher_probs[b_mask].detach() - self.queue_ptr = (self.queue_ptr + n) % self.queue_size - - # embedding similarity - sim = paddle.exp( - paddle.mm(student_probs[b_mask], teacher_probs[b_mask].t()) / 0.2) - sim_probs = sim / sim.sum(1, keepdim=True) - # pseudo-label graph with self-loop - Q = paddle.mm(probs, probs.t()) - Q.fill_diagonal_(1) - pos_mask = (Q >= 0.5).astype('float32') - Q = Q * pos_mask - Q = Q / Q.sum(1, keepdim=True) - # contrastive loss - loss_contrast = -(paddle.log(sim_probs + 1e-7) * Q).sum(1) - loss_contrast = loss_contrast.mean() - - # distill_loss_cls - loss_cls = QFLv2( - student_probs, teacher_probs, weight=mask, reduction="sum") / fg_num - - # distill_loss_iou - inputs = paddle.concat( - (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]), - -1) - targets = paddle.concat( - (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]), - -1) - iou_loss = GIoULoss(reduction='mean') - loss_iou = iou_loss(inputs, targets) - - # distill_loss_dfl - loss_dfl = F.cross_entropy( - student_dfl[b_mask].reshape([-1, self.yolo_head.reg_channels]), - teacher_dfl[b_mask].reshape([-1, self.yolo_head.reg_channels]), - soft_label=True, - reduction='mean') - - return { - "distill_loss_cls": loss_cls, - "distill_loss_iou": loss_iou, - "distill_loss_dfl": loss_dfl, - "distill_loss_contrast": loss_contrast, - "fg_sum": fg_num, - } + def get_ssod_loss(self, student_head_outs, teacher_head_outs, train_cfg): + ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs, + train_cfg) + return ssod_losses @register diff --git a/ppdet/modeling/ssod/__init__.py b/ppdet/modeling/ssod/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7588577e943fcac4bbe1f6ea8e1dd17c4ca8362 --- /dev/null +++ b/ppdet/modeling/ssod/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import utils +from . import losses + +from .utils import * +from .losses import * diff --git a/ppdet/modeling/ssod/losses.py b/ppdet/modeling/ssod/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c5038d4b4c6657f8351ccaa3238d639b53d3f9 --- /dev/null +++ b/ppdet/modeling/ssod/losses.py @@ -0,0 +1,236 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
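Both loss classes in the new `ppdet/modeling/ssod/losses.py` below begin with the same region-selection step: rank every prediction point by the teacher's highest class probability, keep only the top `ratio` fraction as a mask, and use the summed scores of the kept points (`fg_num`) as a soft normalizer for the classification term. A self-contained sketch of just that step on random tensors (the sizes here are invented for illustration):

```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)
num_points, num_classes, ratio = 1000, 80, 0.01

teacher_logits = paddle.randn([num_points, num_classes])
teacher_probs = F.sigmoid(teacher_logits)

# rank every point by its highest class probability
max_vals = paddle.max(teacher_probs, axis=1)
count_num = int(num_points * ratio)
sorted_vals, sorted_inds = paddle.topk(max_vals, num_points)

# binary mask over the top-`ratio` points; fg_num is the soft normalizer
mask = paddle.zeros_like(max_vals)
mask[sorted_inds[:count_num]] = 1.
fg_num = sorted_vals[:count_num].sum()
b_mask = mask > 0

print(int(b_mask.astype('int32').sum()), float(fg_num))  # 10 points kept
```

With the default `ratio=0.01`, only the 1% of points the teacher is most confident about contribute to the distillation losses.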
diff --git a/ppdet/modeling/ssod/losses.py b/ppdet/modeling/ssod/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4c5038d4b4c6657f8351ccaa3238d639b53d3f9
--- /dev/null
+++ b/ppdet/modeling/ssod/losses.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from ppdet.core.workspace import register
+from ppdet.modeling.losses.iou_loss import GIoULoss
+from .utils import QFLv2
+
+from ppdet.utils.logger import setup_logger
+logger = setup_logger(__name__)
+
+__all__ = [
+    'SSODFCOSLoss',
+    'SSODPPYOLOELoss',
+]
+
+
+@register
+class SSODFCOSLoss(nn.Layer):
+    def __init__(self, loss_weight=1.0):
+        super(SSODFCOSLoss, self).__init__()
+        self.loss_weight = loss_weight
+
+    def forward(self, student_head_outs, teacher_head_outs, train_cfg):
+        # for semi-det distill
+        student_logits, student_deltas, student_quality = student_head_outs
+        teacher_logits, teacher_deltas, teacher_quality = teacher_head_outs
+        nc = student_logits[0].shape[1]
+
+        student_logits = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, nc])
+                for _ in student_logits
+            ],
+            axis=0)
+        teacher_logits = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, nc])
+                for _ in teacher_logits
+            ],
+            axis=0)
+
+        student_deltas = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, 4])
+                for _ in student_deltas
+            ],
+            axis=0)
+        teacher_deltas = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, 4])
+                for _ in teacher_deltas
+            ],
+            axis=0)
+
+        student_quality = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, 1])
+                for _ in student_quality
+            ],
+            axis=0)
+        teacher_quality = paddle.concat(
+            [
+                _.transpose([0, 2, 3, 1]).reshape([-1, 1])
+                for _ in teacher_quality
+            ],
+            axis=0)
+
+        ratio = train_cfg.get('ratio', 0.01)
+        with paddle.no_grad():
+            # Region Selection
+            count_num = int(teacher_logits.shape[0] * ratio)
+            teacher_probs = F.sigmoid(teacher_logits)
+            max_vals = paddle.max(teacher_probs, 1)
+            sorted_vals, sorted_inds = paddle.topk(max_vals,
+                                                   teacher_logits.shape[0])
+            mask = paddle.zeros_like(max_vals)
+            mask[sorted_inds[:count_num]] = 1.
+            fg_num = sorted_vals[:count_num].sum()
+            b_mask = mask > 0
+
+        # distill_loss_cls
+        loss_logits = QFLv2(
+            F.sigmoid(student_logits),
+            teacher_probs,
+            weight=mask,
+            reduction="sum") / fg_num
+
+        # distill_loss_box
+        inputs = paddle.concat(
+            (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
+            axis=-1)
+        targets = paddle.concat(
+            (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
+            axis=-1)
+        iou_loss = GIoULoss(reduction='mean')
+        loss_deltas = iou_loss(inputs, targets)
+
+        # distill_loss_quality
+        loss_quality = F.binary_cross_entropy(
+            F.sigmoid(student_quality[b_mask]),
+            F.sigmoid(teacher_quality[b_mask]),
+            reduction='mean')
+
+        return {
+            "distill_loss_cls": loss_logits,
+            "distill_loss_box": loss_deltas,
+            "distill_loss_quality": loss_quality,
+            "fg_sum": fg_num,
+        }
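The `distill_loss_contrast` term in `SSODPPYOLOELoss` below follows a CoMatch-style construction: an embedding-similarity graph built from the student/teacher probabilities of the selected points is pulled toward a pseudo-label graph built from (optionally memory-smoothed) teacher probabilities. A reduced sketch of the graph construction, omitting the memory bank and the top-`ratio` selection:

```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)
n, nc, temperature = 8, 80, 0.2  # n selected points, nc classes (invented sizes)

student_probs = F.sigmoid(paddle.randn([n, nc]))
teacher_probs = F.sigmoid(paddle.randn([n, nc]))

# embedding-similarity graph between student and teacher predictions
sim = paddle.exp(paddle.mm(student_probs, teacher_probs.t()) / temperature)
sim_probs = sim / sim.sum(1, keepdim=True)

# pseudo-label graph from the teacher, with self-loops and a 0.5 threshold
Q = paddle.mm(teacher_probs, teacher_probs.t())
Q.fill_diagonal_(1)
Q = Q * (Q >= 0.5).astype('float32')
Q = Q / Q.sum(1, keepdim=True)

# row-wise cross-entropy pulls the similarity graph toward the label graph
loss_contrast = -(paddle.log(sim_probs + 1e-7) * Q).sum(1).mean()
print(float(loss_contrast))
```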
+
+
+@register
+class SSODPPYOLOELoss(nn.Layer):
+    def __init__(self, loss_weight=1.0):
+        super(SSODPPYOLOELoss, self).__init__()
+        self.loss_weight = loss_weight
+
+    def forward(self, student_head_outs, teacher_head_outs, train_cfg):
+        # for semi-det distill
+        # student_probs: already sigmoid
+        student_probs, student_deltas, student_dfl = student_head_outs
+        teacher_probs, teacher_deltas, teacher_dfl = teacher_head_outs
+        bs, l, nc = student_probs.shape[:]  # bs, l, num_classes
+        bs, l, _, reg_ch = student_dfl.shape[:]  # bs, l, 4, reg_ch
+        student_probs = student_probs.reshape([-1, nc])
+        teacher_probs = teacher_probs.reshape([-1, nc])
+        student_deltas = student_deltas.reshape([-1, 4])
+        teacher_deltas = teacher_deltas.reshape([-1, 4])
+        student_dfl = student_dfl.reshape([-1, 4, reg_ch])
+        teacher_dfl = teacher_dfl.reshape([-1, 4, reg_ch])
+
+        ratio = train_cfg.get('ratio', 0.01)
+
+        # for contrast loss
+        curr_iter = train_cfg['curr_iter']
+        st_iter = train_cfg['st_iter']
+        if curr_iter == st_iter + 1:
+            # start semi-det training
+            self.queue_ptr = 0
+            self.queue_size = int(bs * l * ratio)
+            self.queue_feats = paddle.zeros([self.queue_size, nc])
+            self.queue_probs = paddle.zeros([self.queue_size, nc])
+        contrast_loss_cfg = train_cfg['contrast_loss']
+        temperature = contrast_loss_cfg.get('temperature', 0.2)
+        alpha = contrast_loss_cfg.get('alpha', 0.9)
+        smooth_iter = contrast_loss_cfg.get('smooth_iter', 100) + st_iter
+
+        with paddle.no_grad():
+            # Region Selection
+            count_num = int(teacher_probs.shape[0] * ratio)
+            max_vals = paddle.max(teacher_probs, 1)
+            sorted_vals, sorted_inds = paddle.topk(max_vals,
+                                                   teacher_probs.shape[0])
+            mask = paddle.zeros_like(max_vals)
+            mask[sorted_inds[:count_num]] = 1.
+            fg_num = sorted_vals[:count_num].sum()
+            b_mask = mask > 0.
+
+            # for contrast loss
+            probs = teacher_probs[b_mask].detach()
+            if curr_iter > smooth_iter:  # memory-smoothing
+                A = paddle.exp(
+                    paddle.mm(teacher_probs[b_mask], self.queue_probs.t()) /
+                    temperature)
+                A = A / A.sum(1, keepdim=True)
+                probs = alpha * probs + (1 - alpha) * paddle.mm(
+                    A, self.queue_probs)
+            n = student_probs[b_mask].shape[0]
+            # update memory bank
+            self.queue_feats[self.queue_ptr:self.queue_ptr +
+                             n, :] = teacher_probs[b_mask].detach()
+            self.queue_probs[self.queue_ptr:self.queue_ptr +
+                             n, :] = teacher_probs[b_mask].detach()
+            self.queue_ptr = (self.queue_ptr + n) % self.queue_size
+
+        # embedding similarity
+        sim = paddle.exp(
+            paddle.mm(student_probs[b_mask], teacher_probs[b_mask].t()) / 0.2)
+        sim_probs = sim / sim.sum(1, keepdim=True)
+        # pseudo-label graph with self-loop
+        Q = paddle.mm(probs, probs.t())
+        Q.fill_diagonal_(1)
+        pos_mask = (Q >= 0.5).astype('float32')
+        Q = Q * pos_mask
+        Q = Q / Q.sum(1, keepdim=True)
+        # contrastive loss
+        loss_contrast = -(paddle.log(sim_probs + 1e-7) * Q).sum(1)
+        loss_contrast = loss_contrast.mean()
+
+        # distill_loss_cls
+        loss_cls = QFLv2(
+            student_probs, teacher_probs, weight=mask, reduction="sum") / fg_num
+
+        # distill_loss_iou
+        inputs = paddle.concat(
+            (-student_deltas[b_mask][..., :2], student_deltas[b_mask][..., 2:]),
+            -1)
+        targets = paddle.concat(
+            (-teacher_deltas[b_mask][..., :2], teacher_deltas[b_mask][..., 2:]),
+            -1)
+        iou_loss = GIoULoss(reduction='mean')
+        loss_iou = iou_loss(inputs, targets)
+
+        # distill_loss_dfl
+        loss_dfl = F.cross_entropy(
+            student_dfl[b_mask].reshape([-1, reg_ch]),
+            teacher_dfl[b_mask].reshape([-1, reg_ch]),
+            soft_label=True,
+            reduction='mean')
+
+        return {
+            "distill_loss_cls": loss_cls,
+            "distill_loss_iou": loss_iou,
+            "distill_loss_dfl": loss_dfl,
+            "distill_loss_contrast": loss_contrast,
+            "fg_sum": fg_num,
+        }
diff --git a/ppdet/modeling/ssod_utils.py b/ppdet/modeling/ssod/utils.py
similarity index 88%
rename from ppdet/modeling/ssod_utils.py
rename to ppdet/modeling/ssod/utils.py
index 3f29ef3f4f6d9e2855d7bd9b1bf7bc057bcf9487..09753abfeddd4a017cb64ec8560ad0da1e585708 100644
--- a/ppdet/modeling/ssod_utils.py
+++ b/ppdet/modeling/ssod/utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -58,17 +58,6 @@ def align_weak_strong_shape(data_weak, data_strong):
     return data_weak, data_strong
 
 
-def permute_to_N_HWA_K(tensor, K):
-    """
-    Transpose/reshape a tensor from (N, (A x K), H, W) to (N, (HxWxA), K)
-    """
-    assert tensor.dim() == 4, tensor.shape
-    N, _, H, W = tensor.shape
-    tensor = tensor.reshape([N, -1, K, H, W]).transpose([0, 3, 4, 1, 2])
-    tensor = tensor.reshape([N, -1, K])
-    return tensor
-
-
 def QFLv2(pred_sigmoid,
           teacher_sigmoid,
           weight=None,
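For reference, the per-level flattening used by `SSODFCOSLoss` (and previously inline in `fcos.py`) turns each FPN level's `(N, C, H, W)` map into rows of length `C` and concatenates the levels, so student and teacher predictions can be compared point by point. A shape-only sketch with invented level sizes:

```python
import paddle

nc = 80  # number of classes (illustrative)
# three fake FPN levels for a batch of 2 images
level_shapes = [(2, nc, 32, 32), (2, nc, 16, 16), (2, nc, 8, 8)]
logits_per_level = [paddle.randn(s) for s in level_shapes]

flat = paddle.concat(
    [
        lvl.transpose([0, 2, 3, 1]).reshape([-1, nc])
        for lvl in logits_per_level
    ],
    axis=0)

# 2 * (32*32 + 16*16 + 8*8) = 2688 points, each with `nc` class logits
print(flat.shape)  # [2688, 80]
```

This inline reshaping covers the same job as the deleted `permute_to_N_HWA_K` helper, which is presumably why that unused function could be dropped from `ssod/utils.py`.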