diff --git a/configs/few-shot/README.md b/configs/few-shot/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b8b4a3226230481fd2e0d5f7e763c7215bcad82c
--- /dev/null
+++ b/configs/few-shot/README.md
@@ -0,0 +1,47 @@
+# Co-tuning for Transfer Learning
Supervised Contrastive Learning
+
+## Data preparation
+以[Kaggle数据集](https://www.kaggle.com/andrewmvd/road-sign-detection) 比赛数据为例,说明如何准备自定义数据。
+Kaggle上的 [road-sign-detection](https://www.kaggle.com/andrewmvd/road-sign-detection) 比赛数据包含877张图像,数据类别4类:crosswalk,speedlimit,stop,trafficlight。
+可从Kaggle上下载,也可以从[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar) 下载。
+分别从原始数据集中每类选取相同样本(例如:10shots即每类都有十个训练样本)训练即可。
+工业数据集使用PKU-Market-PCB,该数据集用于印刷电路板(PCB)的瑕疵检测,提供了6种常见的PCB缺陷[下载链接](./configs/ppyoloe/application/README.md)
+
+## Model Zoo
+| 骨架网络 | 网络类型 | 每张GPU图片个数 | 每类样本个数 | Box AP | 下载 | 配置文件 |
+| :------------------- | :------------- | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| ResNet50-vd | Faster | 1 | 10 | 60.1 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams) | [配置文件](./faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml) |
+| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | 17.8 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_s_80e_contrast_pcb.pdparams) |[配置文件](./ppyoloe_plus_crn_s_80e_contrast_pcb.yml) |
+
+## Compare-cotuning
+| 骨架网络 | 网络类型 | 每张GPU图片个数 |每类样本个数 | Cotuning | Box AP |
+| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: |
+| ResNet50-vd | Faster | 1 | 10 | False | 56.7 |
+| ResNet50-vd | Faster | 1 | 10 | True | 60.1 |
+
+## Compare-contrast
+| 骨架网络 | 网络类型 | 每张GPU图片个数 | 每类样本个数 | Contrast | Box AP |
+| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: |
+| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | False | 15.4 |
+| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | True | 17.8 |
+
+## Citations
+```
+@article{you2020co,
+ title={Co-tuning for transfer learning},
+ author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ pages={17236--17246},
+ year={2020}
+}
+
+@article{khosla2020supervised,
+ title={Supervised contrastive learning},
+ author={Khosla, Prannay and Teterwak, Piotr and Wang, Chen and Sarna, Aaron and Tian, Yonglong and Isola, Phillip and Maschinot, Aaron and Liu, Ce and Krishnan, Dilip},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ pages={18661--18673},
+ year={2020}
+}
+```
\ No newline at end of file
diff --git a/configs/few-shot/_base_/faster_fpn_reader.yml b/configs/few-shot/_base_/faster_fpn_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9b9abccd63e499bfa9402f3038425470e4a6e953
--- /dev/null
+++ b/configs/few-shot/_base_/faster_fpn_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 2
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
+ - RandomFlip: {prob: 0.5}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+ shuffle: true
+ drop_last: true
+ collate_batch: false
+
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+ shuffle: false
+ drop_last: false
+
+
+TestReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+ shuffle: false
+ drop_last: false
diff --git a/configs/few-shot/_base_/faster_rcnn_r50.yml b/configs/few-shot/_base_/faster_rcnn_r50.yml
new file mode 100644
index 0000000000000000000000000000000000000000..fd29f5ea1a1df9e2599d3efcff344c5d3363945e
--- /dev/null
+++ b/configs/few-shot/_base_/faster_rcnn_r50.yml
@@ -0,0 +1,66 @@
+architecture: FasterRCNN
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+FasterRCNN:
+ backbone: ResNet
+ rpn_head: RPNHead
+ bbox_head: BBoxHead
+ # post process
+ bbox_post_process: BBoxPostProcess
+
+
+ResNet:
+ # index 0 stands for res2
+ depth: 50
+ norm_type: bn
+ freeze_at: 0
+ return_idx: [2]
+ num_stages: 3
+
+RPNHead:
+ anchor_generator:
+ aspect_ratios: [0.5, 1.0, 2.0]
+ anchor_sizes: [32, 64, 128, 256, 512]
+ strides: [16]
+ rpn_target_assign:
+ batch_size_per_im: 256
+ fg_fraction: 0.5
+ negative_overlap: 0.3
+ positive_overlap: 0.7
+ use_random: True
+ train_proposal:
+ min_size: 0.0
+ nms_thresh: 0.7
+ pre_nms_top_n: 12000
+ post_nms_top_n: 2000
+ topk_after_collect: False
+ test_proposal:
+ min_size: 0.0
+ nms_thresh: 0.7
+ pre_nms_top_n: 6000
+ post_nms_top_n: 1000
+
+
+BBoxHead:
+ head: Res5Head
+ roi_extractor:
+ resolution: 14
+ sampling_ratio: 0
+ aligned: True
+ bbox_assigner: BBoxAssigner
+ with_pool: true
+
+BBoxAssigner:
+ batch_size_per_im: 512
+ bg_thresh: 0.5
+ fg_thresh: 0.5
+ fg_fraction: 0.25
+ use_random: True
+
+BBoxPostProcess:
+ decode: RCNNBox
+ nms:
+ name: MultiClassNMS
+ keep_top_k: 100
+ score_threshold: 0.05
+ nms_threshold: 0.5
diff --git a/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml b/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..38ee81def0cb528f3f67e8ed616b9589bd72de9e
--- /dev/null
+++ b/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml
@@ -0,0 +1,73 @@
+architecture: FasterRCNN
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
+
+FasterRCNN:
+ backbone: ResNet
+ neck: FPN
+ rpn_head: RPNHead
+ bbox_head: BBoxHead
+ # post process
+ bbox_post_process: BBoxPostProcess
+
+
+ResNet:
+ # index 0 stands for res2
+ depth: 50
+ norm_type: bn
+ freeze_at: 0
+ return_idx: [0,1,2,3]
+ num_stages: 4
+
+FPN:
+ out_channel: 256
+
+RPNHead:
+ anchor_generator:
+ aspect_ratios: [0.5, 1.0, 2.0]
+ anchor_sizes: [[32], [64], [128], [256], [512]]
+ strides: [4, 8, 16, 32, 64]
+ rpn_target_assign:
+ batch_size_per_im: 256
+ fg_fraction: 0.5
+ negative_overlap: 0.3
+ positive_overlap: 0.7
+ use_random: True
+ train_proposal:
+ min_size: 0.0
+ nms_thresh: 0.7
+ pre_nms_top_n: 2000
+ post_nms_top_n: 1000
+ topk_after_collect: True
+ test_proposal:
+ min_size: 0.0
+ nms_thresh: 0.7
+ pre_nms_top_n: 1000
+ post_nms_top_n: 1000
+
+
+BBoxHead:
+ head: TwoFCHead
+ roi_extractor:
+ resolution: 7
+ sampling_ratio: 0
+ aligned: True
+ bbox_assigner: BBoxAssigner
+
+BBoxAssigner:
+ batch_size_per_im: 512
+ bg_thresh: 0.5
+ fg_thresh: 0.5
+ fg_fraction: 0.25
+ use_random: True
+
+TwoFCHead:
+ out_channel: 1024
+
+
+BBoxPostProcess:
+ decode: RCNNBox
+ nms:
+ name: MultiClassNMS
+ keep_top_k: 100
+ score_threshold: 0.05
+ nms_threshold: 0.5
diff --git a/configs/few-shot/_base_/faster_reader.yml b/configs/few-shot/_base_/faster_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e1c1bb6bc262e86ea69ae78919064aa2b6834311
--- /dev/null
+++ b/configs/few-shot/_base_/faster_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 2
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
+ - RandomFlip: {prob: 0.5}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: -1}
+ batch_size: 1
+ shuffle: true
+ drop_last: true
+ collate_batch: false
+
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: -1}
+ batch_size: 1
+ shuffle: false
+ drop_last: false
+
+
+TestReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
+ - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: -1}
+ batch_size: 1
+ shuffle: false
+ drop_last: false
diff --git a/configs/few-shot/_base_/optimizer_1x.yml b/configs/few-shot/_base_/optimizer_1x.yml
new file mode 100644
index 0000000000000000000000000000000000000000..4caaa63bda15917137a9ac22b736ae83c3d04856
--- /dev/null
+++ b/configs/few-shot/_base_/optimizer_1x.yml
@@ -0,0 +1,19 @@
+epoch: 12
+
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [8, 11]
+ - !LinearWarmup
+ start_factor: 0.1
+ steps: 1000
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0001
+ type: L2
diff --git a/configs/few-shot/_base_/optimizer_80e.yml b/configs/few-shot/_base_/optimizer_80e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7a8773df15aa103f3194f56634604d84a2a084eb
--- /dev/null
+++ b/configs/few-shot/_base_/optimizer_80e.yml
@@ -0,0 +1,18 @@
+epoch: 80
+
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 96
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 5
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
diff --git a/configs/few-shot/_base_/ppyoloe_plus_crn.yml b/configs/few-shot/_base_/ppyoloe_plus_crn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a83f35008f4797311689ed952abef15df0c0eea7
--- /dev/null
+++ b/configs/few-shot/_base_/ppyoloe_plus_crn.yml
@@ -0,0 +1,49 @@
+architecture: YOLOv3
+norm_type: sync_bn
+use_ema: true
+use_cot: False
+ema_decay: 0.9998
+ema_black_list: ['proj_conv.weight']
+custom_black_list: ['reduce_mean']
+
+YOLOv3:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ post_process: ~
+
+CSPResNet:
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+ use_alpha: True
+
+CustomCSPPAN:
+ out_channels: [768, 384, 192]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: 30
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
diff --git a/configs/few-shot/_base_/ppyoloe_plus_reader.yml b/configs/few-shot/_base_/ppyoloe_plus_reader.yml
new file mode 100644
index 0000000000000000000000000000000000000000..cd9cdeff8b9d46e41a4e6fb518339168dfd4b154
--- /dev/null
+++ b/configs/few-shot/_base_/ppyoloe_plus_reader.yml
@@ -0,0 +1,40 @@
+worker_num: 4
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ batch_size: 2
+
+TestReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+ - Permute: {}
+ batch_size: 1
diff --git a/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml b/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml
new file mode 100644
index 0000000000000000000000000000000000000000..75fd9e3d0ccaa56fd77d8851711b8a44720df566
--- /dev/null
+++ b/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml
@@ -0,0 +1,67 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ '_base_/optimizer_1x.yml',
+ '_base_/faster_rcnn_r50_fpn.yml',
+ '_base_/faster_fpn_reader.yml',
+]
+pretrain_weights: https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams
+weights: output/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign/model_final
+
+snapshot_epoch: 5
+
+ResNet:
+ # index 0 stands for res2
+ depth: 50
+ variant: d
+ norm_type: bn
+ freeze_at: 0
+ return_idx: [0,1,2,3]
+ num_stages: 4
+
+epoch: 30
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [8, 11]
+ - !LinearWarmup
+ start_factor: 0.1
+ steps: 1000
+
+use_cot: True
+BBoxHead:
+ head: TwoFCHead
+ roi_extractor:
+ resolution: 7
+ sampling_ratio: 0
+ aligned: True
+ bbox_assigner: BBoxAssigner
+ cot_classes: 80
+ loss_cot:
+ name: COTLoss
+ cot_lambda: 1
+ cot_scale: 1
+
+num_classes: 4
+metric: COCO
+map_type: integral
+
+TrainDataset:
+ !COCODataSet
+ image_dir: images
+ anno_path: annotations/train_shots10.json
+ dataset_dir: dataset/roadsign_coco
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+ !COCODataSet
+ image_dir: images
+ anno_path: annotations/roadsign_valid.json
+ dataset_dir: dataset/roadsign_coco
+
+TestDataset:
+ !ImageFolder
+ anno_path: annotations/roadsign_valid.json
+ dataset_dir: dataset/roadsign_coco
\ No newline at end of file
diff --git a/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml b/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml
new file mode 100644
index 0000000000000000000000000000000000000000..05320089fcb1f19a690edd030f9b57b502909a38
--- /dev/null
+++ b/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml
@@ -0,0 +1,81 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ './_base_/optimizer_80e.yml',
+ './_base_/ppyoloe_plus_crn.yml',
+ './_base_/ppyoloe_plus_reader.yml',
+]
+
+log_iter: 100
+snapshot_epoch: 10
+weights: output/ppyoloe_plus_crn_s_80e_contrast_pcb/model_final
+
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.50
+
+epoch: 80
+
+LearningRate:
+ base_lr: 0.0001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 96
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 5
+
+YOLOv3:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEContrastHead
+ post_process: ~
+
+PPYOLOEContrastHead:
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: 100
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, contrast: 0.2}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ contrast_loss:
+ name: SupContrast
+ temperature: 100
+ sample_num: 2048
+ thresh: 0.75
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 300
+ score_threshold: 0.01
+ nms_threshold: 0.7
+
+num_classes: 6
+metric: COCO
+map_type: integral
+
+TrainDataset:
+ !COCODataSet
+ image_dir: images
+ anno_path: pcb_cocoanno/train_shots30.json
+ dataset_dir: dataset/pcb
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+ !COCODataSet
+ image_dir: images
+ anno_path: pcb_cocoanno/val.json
+ dataset_dir: dataset/pcb
+
+TestDataset:
+ !ImageFolder
+ anno_path: pcb_cocoanno/val.json
+ dataset_dir: dataset/pcb
\ No newline at end of file
diff --git a/ppdet/engine/__init__.py b/ppdet/engine/__init__.py
index 32cb85e224df9db965ee47a989c2aef0284dd9fd..91166e8764f521cb3dd78ba86c5681c4b531413c 100644
--- a/ppdet/engine/__init__.py
+++ b/ppdet/engine/__init__.py
@@ -15,6 +15,9 @@
from . import trainer
from .trainer import *
+from . import trainer_cot
+from .trainer_cot import *
+
from . import callbacks
from .callbacks import *
diff --git a/ppdet/engine/trainer_cot.py b/ppdet/engine/trainer_cot.py
new file mode 100644
index 0000000000000000000000000000000000000000..38d95fabfd0d19312af3cc40309cc8051ff538c3
--- /dev/null
+++ b/ppdet/engine/trainer_cot.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ppdet.core.workspace import create
+from ppdet.utils.logger import setup_logger
+logger = setup_logger('ppdet.engine')
+
+from . import Trainer
+__all__ = ['TrainerCot']
+
+class TrainerCot(Trainer):
+ """
+ Trainer for label-cotuning
+ calculate the relationship between base_classes and novel_classes
+ """
+ def __init__(self, cfg, mode='train'):
+ super(TrainerCot, self).__init__(cfg, mode)
+ self.cotuning_init()
+
+ def cotuning_init(self):
+ num_classes_novel = self.cfg['num_classes']
+
+ self.load_weights(self.cfg.pretrain_weights)
+
+ self.model.eval()
+ relationship = self.model.relationship_learning(self.loader, num_classes_novel)
+
+ self.model.init_cot_head(relationship)
+ self.optimizer = create('OptimizerBuilder')(self.lr, self.model)
+
+
diff --git a/ppdet/modeling/architectures/faster_rcnn.py b/ppdet/modeling/architectures/faster_rcnn.py
index ce9a8e4b57d2dfe54fde037fed2dc0156cb71b51..e5972b2d8f39b3667155e91851551c35c13a343e 100644
--- a/ppdet/modeling/architectures/faster_rcnn.py
+++ b/ppdet/modeling/architectures/faster_rcnn.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
+import numpy as np
__all__ = ['FasterRCNN']
@@ -51,6 +52,9 @@ class FasterRCNN(BaseArch):
self.bbox_head = bbox_head
self.bbox_post_process = bbox_post_process
+ def init_cot_head(self, relationship):
+ self.bbox_head.init_cot_head(relationship)
+
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
@@ -104,3 +108,38 @@ class FasterRCNN(BaseArch):
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
+
+ def target_bbox_forward(self, data):
+ body_feats = self.backbone(data)
+ if self.neck is not None:
+ body_feats = self.neck(body_feats)
+ rois = [roi for roi in data['gt_bbox']]
+ rois_num = paddle.concat([paddle.shape(roi)[0] for roi in rois])
+
+ preds, _ = self.bbox_head(body_feats, rois, rois_num, None, cot=True)
+ return preds
+
+ def relationship_learning(self, loader, num_classes_novel):
+ print('computing relationship')
+ train_labels_list = []
+ label_list = []
+
+ for step_id, data in enumerate(loader):
+ _, bbox_prob = self.target_bbox_forward(data)
+ batch_size = data['im_id'].shape[0]
+ for i in range(batch_size):
+ num_bbox = data['gt_class'][i].shape[0]
+ train_labels = data['gt_class'][i]
+ train_labels_list.append(train_labels.numpy().squeeze(1))
+ base_labels = bbox_prob.detach().numpy()[:,:-1]
+ label_list.append(base_labels)
+
+ labels = np.concatenate(train_labels_list, 0)
+ probabilities = np.concatenate(label_list, 0)
+ N_t = np.max(labels) + 1
+ conditional = []
+ for i in range(N_t):
+ this_class = probabilities[labels == i]
+ average = np.mean(this_class, axis=0, keepdims=True)
+ conditional.append(average)
+ return np.concatenate(conditional)
\ No newline at end of file
diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py
index f54bb03ea2919d5d70e56f577ca46b79c14ba81f..8856d7c182f4efcd41f9366fb807808160ad7e36 100644
--- a/ppdet/modeling/heads/__init__.py
+++ b/ppdet/modeling/heads/__init__.py
@@ -37,6 +37,7 @@ from . import fcosr_head
from . import ppyoloe_r_head
from . import ld_gfl_head
from . import yolof_head
+from . import ppyoloe_contrast_head
from .bbox_head import *
from .mask_head import *
@@ -63,3 +64,4 @@ from .fcosr_head import *
from .ld_gfl_head import *
from .ppyoloe_r_head import *
from .yolof_head import *
+from .ppyoloe_contrast_head import *
\ No newline at end of file
diff --git a/ppdet/modeling/heads/bbox_head.py b/ppdet/modeling/heads/bbox_head.py
index e0041be371a23e31842165cd5e5a0e4d95265c8c..3ce47983cec3e7e5961be95e719a1187c7d7fc55 100644
--- a/ppdet/modeling/heads/bbox_head.py
+++ b/ppdet/modeling/heads/bbox_head.py
@@ -160,8 +160,8 @@ class XConvNormHead(nn.Layer):
@register
class BBoxHead(nn.Layer):
- __shared__ = ['num_classes']
- __inject__ = ['bbox_assigner', 'bbox_loss']
+ __shared__ = ['num_classes', 'use_cot']
+ __inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot']
"""
RCNN bbox head
@@ -173,7 +173,10 @@ class BBoxHead(nn.Layer):
box.
with_pool (bool): Whether to use pooling for the RoI feature.
num_classes (int): The number of classes
- bbox_weight (List[float]): The weight to get the decode box
+ bbox_weight (List[float]): The weight to get the decode box
+ cot_classes (int): The number of base classes
+ loss_cot (object): The module of Label-cotuning
+ use_cot(bool): whether to use Label-cotuning
"""
def __init__(self,
@@ -185,7 +188,10 @@ class BBoxHead(nn.Layer):
num_classes=80,
bbox_weight=[10., 10., 5., 5.],
bbox_loss=None,
- loss_normalize_pos=False):
+ loss_normalize_pos=False,
+ cot_classes=None,
+ loss_cot='COTLoss',
+ use_cot=False):
super(BBoxHead, self).__init__()
self.head = head
self.roi_extractor = roi_extractor
@@ -199,11 +205,29 @@ class BBoxHead(nn.Layer):
self.bbox_loss = bbox_loss
self.loss_normalize_pos = loss_normalize_pos
- self.bbox_score = nn.Linear(
- in_channel,
- self.num_classes + 1,
- weight_attr=paddle.ParamAttr(initializer=Normal(
- mean=0.0, std=0.01)))
+ self.loss_cot = loss_cot
+ self.cot_relation = None
+ self.cot_classes = cot_classes
+ self.use_cot = use_cot
+ if use_cot:
+ self.cot_bbox_score = nn.Linear(
+ in_channel,
+ self.num_classes + 1,
+ weight_attr=paddle.ParamAttr(initializer=Normal(
+ mean=0.0, std=0.01)))
+
+ self.bbox_score = nn.Linear(
+ in_channel,
+ self.cot_classes + 1,
+ weight_attr=paddle.ParamAttr(initializer=Normal(
+ mean=0.0, std=0.01)))
+ self.cot_bbox_score.skip_quant = True
+ else:
+ self.bbox_score = nn.Linear(
+ in_channel,
+ self.num_classes + 1,
+ weight_attr=paddle.ParamAttr(initializer=Normal(
+ mean=0.0, std=0.01)))
self.bbox_score.skip_quant = True
self.bbox_delta = nn.Linear(
@@ -215,6 +239,9 @@ class BBoxHead(nn.Layer):
self.assigned_label = None
self.assigned_rois = None
+ def init_cot_head(self, relationship):
+ self.cot_relation = relationship
+
@classmethod
def from_config(cls, cfg, input_shape):
roi_pooler = cfg['roi_extractor']
@@ -229,7 +256,7 @@ class BBoxHead(nn.Layer):
'in_channel': head.out_shape[0].channels
}
- def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
+ def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False):
"""
body_feats (list[Tensor]): Feature maps from backbone
rois (list[Tensor]): RoIs generated from RPN module
@@ -248,7 +275,11 @@ class BBoxHead(nn.Layer):
feat = paddle.squeeze(feat, axis=[2, 3])
else:
feat = bbox_feat
- scores = self.bbox_score(feat)
+ if self.use_cot:
+ scores = self.cot_bbox_score(feat)
+ cot_scores = self.bbox_score(feat)
+ else:
+ scores = self.bbox_score(feat)
deltas = self.bbox_delta(feat)
if self.training:
@@ -259,11 +290,19 @@ class BBoxHead(nn.Layer):
rois,
self.bbox_weight,
loss_normalize_pos=self.loss_normalize_pos)
+
+ if self.cot_relation is not None:
+ loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation)
+ loss.update(loss_cot)
return loss, bbox_feat
else:
- pred = self.get_prediction(scores, deltas)
+ if cot:
+ pred = self.get_prediction(cot_scores, deltas)
+ else:
+ pred = self.get_prediction(scores, deltas)
return pred, self.head
+
def get_loss(self,
scores,
deltas,
diff --git a/ppdet/modeling/heads/ppyoloe_contrast_head.py b/ppdet/modeling/heads/ppyoloe_contrast_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..190c519cfe58cee624cec2a2b71b07003576cd2c
--- /dev/null
+++ b/ppdet/modeling/heads/ppyoloe_contrast_head.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from ppdet.core.workspace import register
+
+from ..bbox_utils import batch_distance2bbox
+from ..losses import GIoULoss
+from ..initializer import bias_init_with_prob, constant_, normal_
+from ..assigners.utils import generate_anchors_for_grid_cell
+from ppdet.modeling.backbones.cspresnet import ConvBNLayer
+from ppdet.modeling.ops import get_static_shape, get_act_fn
+from ppdet.modeling.layers import MultiClassNMS
+from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead
+__all__ = ['PPYOLOEContrastHead']
+
+@register
+class PPYOLOEContrastHead(PPYOLOEHead):
+ __shared__ = [
+ 'num_classes', 'eval_size', 'trt', 'exclude_nms',
+ 'exclude_post_process', 'use_shared_conv'
+ ]
+ __inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss']
+
+ def __init__(self,
+ in_channels=[1024, 512, 256],
+ num_classes=80,
+ act='swish',
+ fpn_strides=(32, 16, 8),
+ grid_cell_scale=5.0,
+ grid_cell_offset=0.5,
+ reg_max=16,
+ reg_range=None,
+ static_assigner_epoch=4,
+ use_varifocal_loss=True,
+ static_assigner='ATSSAssigner',
+ assigner='TaskAlignedAssigner',
+ contrast_loss='SupContrast',
+ nms='MultiClassNMS',
+ eval_size=None,
+ loss_weight={
+ 'class': 1.0,
+ 'iou': 2.5,
+ 'dfl': 0.5,
+ },
+ trt=False,
+ exclude_nms=False,
+ exclude_post_process=False,
+ use_shared_conv=True):
+ super().__init__(in_channels,
+ num_classes,
+ act,
+ fpn_strides,
+ grid_cell_scale,
+ grid_cell_offset,
+ reg_max,
+ reg_range,
+ static_assigner_epoch,
+ use_varifocal_loss,
+ static_assigner,
+ assigner,
+ nms,
+ eval_size,
+ loss_weight,
+ trt,
+ exclude_nms,
+ exclude_post_process,
+ use_shared_conv)
+
+ assert len(in_channels) > 0, "len(in_channels) should > 0"
+ self.contrast_loss = contrast_loss
+ self.contrast_encoder = nn.LayerList()
+ for in_c in self.in_channels:
+ self.contrast_encoder.append(
+ nn.Conv2D(
+ in_c, 128, 3, padding=1))
+ self._init_contrast_encoder()
+
+ def _init_contrast_encoder(self):
+ bias_en = bias_init_with_prob(0.01)
+ for en_ in self.contrast_encoder:
+ constant_(en_.weight)
+ constant_(en_.bias, bias_en)
+
+ def forward_train(self, feats, targets):
+ anchors, anchor_points, num_anchors_list, stride_tensor = \
+ generate_anchors_for_grid_cell(
+ feats, self.fpn_strides, self.grid_cell_scale,
+ self.grid_cell_offset)
+
+ cls_score_list, reg_distri_list = [], []
+ contrast_encoder_list = []
+ for i, feat in enumerate(feats):
+ avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
+ cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
+ feat)
+ reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
+ contrast_logit = self.contrast_encoder[i](self.stem_cls[i](feat, avg_feat) +
+ feat)
+ contrast_encoder_list.append(contrast_logit.flatten(2).transpose([0, 2, 1]))
+ # cls and reg
+ cls_score = F.sigmoid(cls_logit)
+ cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
+ reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
+ cls_score_list = paddle.concat(cls_score_list, axis=1)
+ reg_distri_list = paddle.concat(reg_distri_list, axis=1)
+ contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1)
+
+ return self.get_loss([
+ cls_score_list, reg_distri_list, contrast_encoder_list, anchors, anchor_points,
+ num_anchors_list, stride_tensor
+ ], targets)
+
+ def get_loss(self, head_outs, gt_meta):
+ pred_scores, pred_distri, pred_contrast_encoder, anchors,\
+ anchor_points, num_anchors_list, stride_tensor = head_outs
+
+ anchor_points_s = anchor_points / stride_tensor
+ pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
+
+ gt_labels = gt_meta['gt_class']
+ gt_bboxes = gt_meta['gt_bbox']
+ pad_gt_mask = gt_meta['pad_gt_mask']
+ # label assignment
+ if gt_meta['epoch_id'] < self.static_assigner_epoch:
+ assigned_labels, assigned_bboxes, assigned_scores = \
+ self.static_assigner(
+ anchors,
+ num_anchors_list,
+ gt_labels,
+ gt_bboxes,
+ pad_gt_mask,
+ bg_index=self.num_classes,
+ pred_bboxes=pred_bboxes.detach() * stride_tensor)
+ alpha_l = 0.25
+ else:
+ if self.sm_use:
+ assigned_labels, assigned_bboxes, assigned_scores = \
+ self.assigner(
+ pred_scores.detach(),
+ pred_bboxes.detach() * stride_tensor,
+ anchor_points,
+ stride_tensor,
+ gt_labels,
+ gt_bboxes,
+ pad_gt_mask,
+ bg_index=self.num_classes)
+ else:
+ assigned_labels, assigned_bboxes, assigned_scores = \
+ self.assigner(
+ pred_scores.detach(),
+ pred_bboxes.detach() * stride_tensor,
+ anchor_points,
+ num_anchors_list,
+ gt_labels,
+ gt_bboxes,
+ pad_gt_mask,
+ bg_index=self.num_classes)
+ alpha_l = -1
+ # rescale bbox
+ assigned_bboxes /= stride_tensor
+ # cls loss
+ if self.use_varifocal_loss:
+ one_hot_label = F.one_hot(assigned_labels,
+ self.num_classes + 1)[..., :-1]
+ loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
+ one_hot_label)
+ else:
+ loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
+
+ assigned_scores_sum = assigned_scores.sum()
+ if paddle.distributed.get_world_size() > 1:
+ paddle.distributed.all_reduce(assigned_scores_sum)
+ assigned_scores_sum /= paddle.distributed.get_world_size()
+ assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
+ loss_cls /= assigned_scores_sum
+
+ loss_l1, loss_iou, loss_dfl = \
+ self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
+ assigned_labels, assigned_bboxes, assigned_scores,
+ assigned_scores_sum)
+ # contrast loss
+ loss_contrast = self.contrast_loss(pred_contrast_encoder.reshape([-1, pred_contrast_encoder.shape[-1]]), \
+ assigned_labels.reshape([-1]), assigned_scores.max(-1).reshape([-1]))
+
+ loss = self.loss_weight['class'] * loss_cls + \
+ self.loss_weight['iou'] * loss_iou + \
+ self.loss_weight['dfl'] * loss_dfl + \
+ self.loss_weight['contrast'] * loss_contrast
+
+ out_dict = {
+ 'loss': loss,
+ 'loss_cls': loss_cls,
+ 'loss_iou': loss_iou,
+ 'loss_dfl': loss_dfl,
+ 'loss_l1': loss_l1,
+ 'loss_contrast': loss_contrast
+ }
+ return out_dict
diff --git a/ppdet/modeling/losses/__init__.py b/ppdet/modeling/losses/__init__.py
index 0c946d92827f61a9ac132587413ac6c8fc0df89e..13bcd49503d9a071a2d23250012b46b3c78f9e03 100644
--- a/ppdet/modeling/losses/__init__.py
+++ b/ppdet/modeling/losses/__init__.py
@@ -28,6 +28,8 @@ from . import sparsercnn_loss
from . import focal_loss
from . import smooth_l1_loss
from . import probiou_loss
+from . import cot_loss
+from . import supcontrast
from .yolo_loss import *
from .iou_aware_loss import *
@@ -46,3 +48,5 @@ from .focal_loss import *
from .smooth_l1_loss import *
from .pose3d_loss import *
from .probiou_loss import *
+from .cot_loss import *
+from .supcontrast import *
\ No newline at end of file
diff --git a/ppdet/modeling/losses/cot_loss.py b/ppdet/modeling/losses/cot_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..40f8f9acf9f92e3051c6e90c44604be3460e964a
--- /dev/null
+++ b/ppdet/modeling/losses/cot_loss.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+from ppdet.core.workspace import register
+
+__all__ = ['COTLoss']
+
+@register
+class COTLoss(nn.Layer):
+ __shared__ = ['num_classes']
+ def __init__(self,
+ num_classes=80,
+ cot_scale=1,
+ cot_lambda=1):
+ super(COTLoss, self).__init__()
+ self.cot_scale = cot_scale
+ self.cot_lambda = cot_lambda
+ self.num_classes = num_classes
+
+ def forward(self, scores, targets, cot_relation):
+ cls_name = 'loss_bbox_cls_cot'
+ loss_bbox = {}
+
+ tgt_labels, tgt_bboxes, tgt_gt_inds = targets
+ tgt_labels = paddle.concat(tgt_labels) if len(
+ tgt_labels) > 1 else tgt_labels[0]
+ mask = (tgt_labels < self.num_classes)
+ valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
+ if valid_inds.shape[0] == 0:
+ loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
+ else:
+ tgt_labels = tgt_labels.cast('int64')
+ valid_cot_targets = []
+ for i in range(tgt_labels.shape[0]):
+ train_label = tgt_labels[i]
+ if train_label < self.num_classes:
+ valid_cot_targets.append(cot_relation[train_label])
+ coco_targets = paddle.to_tensor(valid_cot_targets)
+ coco_targets.stop_gradient = True
+ coco_loss = - coco_targets * F.log_softmax(scores[mask][:, :-1] * self.cot_scale)
+ loss_bbox[cls_name] = self.cot_lambda * paddle.mean(paddle.sum(coco_loss, axis=-1))
+ return loss_bbox
diff --git a/ppdet/modeling/losses/supcontrast.py b/ppdet/modeling/losses/supcontrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e59f08124fc0acb974ee57d70d5a489e9cd5312
--- /dev/null
+++ b/ppdet/modeling/losses/supcontrast.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+import random
+from ppdet.core.workspace import register
+
+
+__all__ = ['SupContrast']
+
+
+@register
+class SupContrast(nn.Layer):
+ __shared__ = [
+ 'num_classes'
+ ]
+ def __init__(self, num_classes=80, temperature=2.5, sample_num=4096, thresh=0.75):
+ super(SupContrast, self).__init__()
+ self.num_classes = num_classes
+ self.temperature = temperature
+ self.sample_num = sample_num
+ self.thresh = thresh
+ def forward(self, features, labels, scores):
+
+ assert features.shape[0] == labels.shape[0] == scores.shape[0]
+ positive_mask = (labels < self.num_classes)
+ positive_features, positive_labels, positive_scores = features[positive_mask], labels[positive_mask], \
+ scores[positive_mask]
+
+ negative_mask = (labels == self.num_classes)
+ negative_features, negative_labels, negative_scores = features[negative_mask], labels[negative_mask], \
+ scores[negative_mask]
+
+ N = negative_features.shape[0]
+ S = self.sample_num - positive_mask.sum()
+ index = paddle.to_tensor(random.sample(range(N), int(S)), dtype='int32')
+
+ negative_features = paddle.index_select(x=negative_features, index=index, axis=0)
+ negative_labels = paddle.index_select(x=negative_labels, index=index, axis=0)
+ negative_scores = paddle.index_select(x=negative_scores, index=index, axis=0)
+
+ features = paddle.concat([positive_features, negative_features], 0)
+ labels = paddle.concat([positive_labels, negative_labels], 0)
+ scores = paddle.concat([positive_scores, negative_scores], 0)
+
+ if len(labels.shape) == 1:
+ labels = labels.reshape([-1, 1])
+ label_mask = paddle.equal(labels, labels.T).detach()
+ similarity = (paddle.matmul(features, features.T) / self.temperature)
+
+ sim_row_max = paddle.max(similarity, axis=1, keepdim=True)
+ similarity = similarity - sim_row_max
+
+ logits_mask = paddle.ones_like(similarity).detach()
+ logits_mask.fill_diagonal_(0)
+
+ exp_sim = paddle.exp(similarity) * logits_mask
+ log_prob = similarity - paddle.log(exp_sim.sum(axis=1, keepdim=True))
+
+ per_label_log_prob = (log_prob * logits_mask * label_mask).sum(1) / label_mask.sum(1)
+ keep = scores > self.thresh
+ per_label_log_prob = per_label_log_prob[keep]
+ loss = -per_label_log_prob
+
+ return loss.mean()
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
index 6f0d2a6d3591ca4c0f5a3d7b88f9009b3696c127..6604495a2b5c91e0c6257653f617a23d21c4e5e0 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -30,8 +30,10 @@ warnings.filterwarnings('ignore')
import paddle
from ppdet.core.workspace import load_config, merge_config
-from ppdet.engine import Trainer, init_parallel_env, set_random_seed, init_fleet_env
+
+from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env
from ppdet.engine.trainer_ssod import Trainer_DenseTeacher
+
from ppdet.slim import build_slim_model
from ppdet.utils.cli import ArgsParser, merge_args
@@ -125,6 +127,7 @@ def run(FLAGS, cfg):
if FLAGS.enable_ce:
set_random_seed(0)
+ # build trainer
ssod_method = cfg.get('ssod_method', None)
if ssod_method is not None:
if ssod_method == 'DenseTeacher':
@@ -133,8 +136,9 @@ def run(FLAGS, cfg):
raise ValueError(
"Semi-Supervised Object Detection only support DenseTeacher now."
)
+ elif cfg.get('use_cot', False):
+ trainer = TrainerCot(cfg, mode='train')
else:
- # build trainer
trainer = Trainer(cfg, mode='train')
# load weights