diff --git a/configs/few-shot/README.md b/configs/few-shot/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b8b4a3226230481fd2e0d5f7e763c7215bcad82c --- /dev/null +++ b/configs/few-shot/README.md @@ -0,0 +1,47 @@ +# Co-tuning for Transfer Learning
Supervised Contrastive Learning + +## Data preparation +以[Kaggle数据集](https://www.kaggle.com/andrewmvd/road-sign-detection) 比赛数据为例,说明如何准备自定义数据。 +Kaggle上的 [road-sign-detection](https://www.kaggle.com/andrewmvd/road-sign-detection) 比赛数据包含877张图像,数据类别4类:crosswalk,speedlimit,stop,trafficlight。 +可从Kaggle上下载,也可以从[下载链接](https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar) 下载。 +分别从原始数据集中每类选取相同样本(例如:10shots即每类都有十个训练样本)训练即可。
+工业数据集使用PKU-Market-PCB,该数据集用于印刷电路板(PCB)的瑕疵检测,提供了6种常见的PCB缺陷,详见[下载链接](../ppyoloe/application/README.md)。 + +## Model Zoo +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 每类样本个数 | Box AP | 下载 | 配置文件 | +| :------------------- | :------------- | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| ResNet50-vd | Faster | 1 | 10 | 60.1 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams) | [配置文件](./faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml) | +| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | 17.8 | [下载链接](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_s_80e_contrast_pcb.pdparams) |[配置文件](./ppyoloe_plus_crn_s_80e_contrast_pcb.yml) | + +## Compare-cotuning +| 骨架网络 | 网络类型 | 每张GPU图片个数 |每类样本个数 | Cotuning | Box AP | +| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | +| ResNet50-vd | Faster | 1 | 10 | False | 56.7 | +| ResNet50-vd | Faster | 1 | 10 | True | 60.1 | + +## Compare-contrast +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 每类样本个数 | Contrast | Box AP | +| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | +| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | False | 15.4 | +| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | True | 17.8 | + +## Citations +``` +@article{you2020co, + title={Co-tuning for transfer learning}, + author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={17236--17246}, + year={2020} +} + +@article{khosla2020supervised, + title={Supervised contrastive learning}, + author={Khosla, Prannay and Teterwak, Piotr and Wang, Chen and Sarna, Aaron and Tian, Yonglong and Isola, Phillip and Maschinot, Aaron and Liu, Ce and Krishnan, Dilip}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={18661--18673}, + year={2020} +} +``` \ No newline at end of file diff --git 
a/configs/few-shot/_base_/faster_fpn_reader.yml b/configs/few-shot/_base_/faster_fpn_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..9b9abccd63e499bfa9402f3038425470e4a6e953 --- /dev/null +++ b/configs/few-shot/_base_/faster_fpn_reader.yml @@ -0,0 +1,40 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} + - RandomFlip: {prob: 0.5} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: true + drop_last: true + collate_batch: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + shuffle: false + drop_last: false diff --git a/configs/few-shot/_base_/faster_rcnn_r50.yml b/configs/few-shot/_base_/faster_rcnn_r50.yml new file mode 100644 index 0000000000000000000000000000000000000000..fd29f5ea1a1df9e2599d3efcff344c5d3363945e --- /dev/null +++ b/configs/few-shot/_base_/faster_rcnn_r50.yml @@ -0,0 +1,66 @@ +architecture: FasterRCNN +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams + +FasterRCNN: + backbone: ResNet + rpn_head: RPNHead + bbox_head: BBoxHead + # post process + bbox_post_process: BBoxPostProcess + + +ResNet: + # index 
0 stands for res2 + depth: 50 + norm_type: bn + freeze_at: 0 + return_idx: [2] + num_stages: 3 + +RPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + anchor_sizes: [32, 64, 128, 256, 512] + strides: [16] + rpn_target_assign: + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + use_random: True + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 12000 + post_nms_top_n: 2000 + topk_after_collect: False + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 6000 + post_nms_top_n: 1000 + + +BBoxHead: + head: Res5Head + roi_extractor: + resolution: 14 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + with_pool: true + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + use_random: True + +BBoxPostProcess: + decode: RCNNBox + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 diff --git a/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml b/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml new file mode 100644 index 0000000000000000000000000000000000000000..38ee81def0cb528f3f67e8ed616b9589bd72de9e --- /dev/null +++ b/configs/few-shot/_base_/faster_rcnn_r50_fpn.yml @@ -0,0 +1,73 @@ +architecture: FasterRCNN +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams + +FasterRCNN: + backbone: ResNet + neck: FPN + rpn_head: RPNHead + bbox_head: BBoxHead + # post process + bbox_post_process: BBoxPostProcess + + +ResNet: + # index 0 stands for res2 + depth: 50 + norm_type: bn + freeze_at: 0 + return_idx: [0,1,2,3] + num_stages: 4 + +FPN: + out_channel: 256 + +RPNHead: + anchor_generator: + aspect_ratios: [0.5, 1.0, 2.0] + anchor_sizes: [[32], [64], [128], [256], [512]] + strides: [4, 8, 16, 32, 64] + rpn_target_assign: + batch_size_per_im: 256 + fg_fraction: 0.5 + negative_overlap: 0.3 + positive_overlap: 0.7 + use_random: True + train_proposal: + min_size: 
0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 1000 + topk_after_collect: True + test_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 + + +BBoxHead: + head: TwoFCHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + use_random: True + +TwoFCHead: + out_channel: 1024 + + +BBoxPostProcess: + decode: RCNNBox + nms: + name: MultiClassNMS + keep_top_k: 100 + score_threshold: 0.05 + nms_threshold: 0.5 diff --git a/configs/few-shot/_base_/faster_reader.yml b/configs/few-shot/_base_/faster_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..e1c1bb6bc262e86ea69ae78919064aa2b6834311 --- /dev/null +++ b/configs/few-shot/_base_/faster_reader.yml @@ -0,0 +1,40 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True} + - RandomFlip: {prob: 0.5} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: -1} + batch_size: 1 + shuffle: true + drop_last: true + collate_batch: false + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: -1} + batch_size: 1 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: -1} + batch_size: 1 + 
shuffle: false + drop_last: false diff --git a/configs/few-shot/_base_/optimizer_1x.yml b/configs/few-shot/_base_/optimizer_1x.yml new file mode 100644 index 0000000000000000000000000000000000000000..4caaa63bda15917137a9ac22b736ae83c3d04856 --- /dev/null +++ b/configs/few-shot/_base_/optimizer_1x.yml @@ -0,0 +1,19 @@ +epoch: 12 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [8, 11] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 diff --git a/configs/few-shot/_base_/optimizer_80e.yml b/configs/few-shot/_base_/optimizer_80e.yml new file mode 100644 index 0000000000000000000000000000000000000000..7a8773df15aa103f3194f56634604d84a2a084eb --- /dev/null +++ b/configs/few-shot/_base_/optimizer_80e.yml @@ -0,0 +1,18 @@ +epoch: 80 + +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. 
+ epochs: 5 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 diff --git a/configs/few-shot/_base_/ppyoloe_plus_crn.yml b/configs/few-shot/_base_/ppyoloe_plus_crn.yml new file mode 100644 index 0000000000000000000000000000000000000000..a83f35008f4797311689ed952abef15df0c0eea7 --- /dev/null +++ b/configs/few-shot/_base_/ppyoloe_plus_crn.yml @@ -0,0 +1,49 @@ +architecture: YOLOv3 +norm_type: sync_bn +use_ema: true +use_cot: False +ema_decay: 0.9998 +ema_black_list: ['proj_conv.weight'] +custom_black_list: ['reduce_mean'] + +YOLOv3: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEHead + post_process: ~ + +CSPResNet: + layers: [3, 6, 6, 3] + channels: [64, 128, 256, 512, 1024] + return_idx: [1, 2, 3] + use_large_stem: True + use_alpha: True + +CustomCSPPAN: + out_channels: [768, 384, 192] + stage_num: 1 + block_num: 3 + act: 'swish' + spp: true + +PPYOLOEHead: + fpn_strides: [32, 16, 8] + grid_cell_scale: 5.0 + grid_cell_offset: 0.5 + static_assigner_epoch: 30 + use_varifocal_loss: True + loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5} + static_assigner: + name: ATSSAssigner + topk: 9 + assigner: + name: TaskAlignedAssigner + topk: 13 + alpha: 1.0 + beta: 6.0 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 300 + score_threshold: 0.01 + nms_threshold: 0.7 diff --git a/configs/few-shot/_base_/ppyoloe_plus_reader.yml b/configs/few-shot/_base_/ppyoloe_plus_reader.yml new file mode 100644 index 0000000000000000000000000000000000000000..cd9cdeff8b9d46e41a4e6fb518339168dfd4b154 --- /dev/null +++ b/configs/few-shot/_base_/ppyoloe_plus_reader.yml @@ -0,0 +1,40 @@ +worker_num: 4 +eval_height: &eval_height 640 +eval_width: &eval_width 640 +eval_size: &eval_size [*eval_height, *eval_width] + +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {} + - RandomFlip: {} + batch_transforms: + - 
BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + - PadGT: {} + batch_size: 8 + shuffle: true + drop_last: true + use_shared_memory: true + collate_batch: true + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 2 + +TestReader: + inputs_def: + image_shape: [3, *eval_height, *eval_width] + sample_transforms: + - Decode: {} + - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 diff --git a/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml b/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml new file mode 100644 index 0000000000000000000000000000000000000000..75fd9e3d0ccaa56fd77d8851711b8a44720df566 --- /dev/null +++ b/configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml @@ -0,0 +1,67 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_1x.yml', + '_base_/faster_rcnn_r50_fpn.yml', + '_base_/faster_fpn_reader.yml', +] +pretrain_weights: https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams +weights: output/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign/model_final + +snapshot_epoch: 5 + +ResNet: + # index 0 stands for res2 + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [0,1,2,3] + num_stages: 4 + +epoch: 30 +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [8, 11] + - !LinearWarmup + start_factor: 0.1 + steps: 1000 + +use_cot: True +BBoxHead: + head: TwoFCHead + roi_extractor: + 
resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + cot_classes: 80 + loss_cot: + name: COTLoss + cot_lambda: 1 + cot_scale: 1 + +num_classes: 4 +metric: COCO +map_type: integral + +TrainDataset: + !COCODataSet + image_dir: images + anno_path: annotations/train_shots10.json + dataset_dir: dataset/roadsign_coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +EvalDataset: + !COCODataSet + image_dir: images + anno_path: annotations/roadsign_valid.json + dataset_dir: dataset/roadsign_coco + +TestDataset: + !ImageFolder + anno_path: annotations/roadsign_valid.json + dataset_dir: dataset/roadsign_coco \ No newline at end of file diff --git a/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml b/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml new file mode 100644 index 0000000000000000000000000000000000000000..05320089fcb1f19a690edd030f9b57b502909a38 --- /dev/null +++ b/configs/few-shot/ppyoloe_plus_crn_s_80e_contrast_pcb.yml @@ -0,0 +1,81 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + './_base_/optimizer_80e.yml', + './_base_/ppyoloe_plus_crn.yml', + './_base_/ppyoloe_plus_reader.yml', +] + +log_iter: 100 +snapshot_epoch: 10 +weights: output/ppyoloe_plus_crn_s_80e_contrast_pcb/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams +depth_mult: 0.33 +width_mult: 0.50 + +epoch: 80 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. 
+ epochs: 5 + +YOLOv3: + backbone: CSPResNet + neck: CustomCSPPAN + yolo_head: PPYOLOEContrastHead + post_process: ~ + +PPYOLOEContrastHead: + fpn_strides: [32, 16, 8] + grid_cell_scale: 5.0 + grid_cell_offset: 0.5 + static_assigner_epoch: 100 + use_varifocal_loss: True + loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, contrast: 0.2} + static_assigner: + name: ATSSAssigner + topk: 9 + assigner: + name: TaskAlignedAssigner + topk: 13 + alpha: 1.0 + beta: 6.0 + contrast_loss: + name: SupContrast + temperature: 100 + sample_num: 2048 + thresh: 0.75 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 300 + score_threshold: 0.01 + nms_threshold: 0.7 + +num_classes: 6 +metric: COCO +map_type: integral + +TrainDataset: + !COCODataSet + image_dir: images + anno_path: pcb_cocoanno/train_shots30.json + dataset_dir: dataset/pcb + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +EvalDataset: + !COCODataSet + image_dir: images + anno_path: pcb_cocoanno/val.json + dataset_dir: dataset/pcb + +TestDataset: + !ImageFolder + anno_path: pcb_cocoanno/val.json + dataset_dir: dataset/pcb \ No newline at end of file diff --git a/ppdet/engine/__init__.py b/ppdet/engine/__init__.py index 32cb85e224df9db965ee47a989c2aef0284dd9fd..91166e8764f521cb3dd78ba86c5681c4b531413c 100644 --- a/ppdet/engine/__init__.py +++ b/ppdet/engine/__init__.py @@ -15,6 +15,9 @@ from . import trainer from .trainer import * +from . import trainer_cot +from .trainer_cot import * + from . import callbacks from .callbacks import * diff --git a/ppdet/engine/trainer_cot.py b/ppdet/engine/trainer_cot.py new file mode 100644 index 0000000000000000000000000000000000000000..38d95fabfd0d19312af3cc40309cc8051ff538c3 --- /dev/null +++ b/ppdet/engine/trainer_cot.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TrainerCot(Trainer):
    """
    Trainer for label co-tuning.

    On construction it loads the pretrained (base-class) weights, asks the
    model to estimate the relationship between the base classes and the
    novel classes over ``self.loader``, hands that relationship to the
    model's co-tuning head and rebuilds the optimizer.

    Args:
        cfg: Training configuration; ``cfg['num_classes']`` and
            ``cfg.pretrain_weights`` are read here.
        mode (str): Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, cfg, mode='train'):
        super(TrainerCot, self).__init__(cfg, mode)
        self.cotuning_init()

    def cotuning_init(self):
        """Compute the class relationship and install it on the model."""
        # Function-scope import: this module does not import paddle at the
        # top level, and the name is only needed for the no_grad guard.
        import paddle

        num_classes_novel = self.cfg['num_classes']

        # The relationship is measured with the pretrained base-class head,
        # so the pretrained weights must be in place first.
        self.load_weights(self.cfg.pretrain_weights)

        self.model.eval()
        # relationship_learning only reads detached predictions, so running
        # it without autograd avoids building a graph for every batch.
        with paddle.no_grad():
            relationship = self.model.relationship_learning(
                self.loader, num_classes_novel)

        self.model.init_cot_head(relationship)
        # Rebuild the optimizer after the head has been (re)initialised.
        self.optimizer = create('OptimizerBuilder')(self.lr, self.model)
def target_bbox_forward(self, data):
    """Run the bbox head on the ground-truth boxes of ``data``.

    The ground-truth boxes in ``data['gt_bbox']`` are used as RoIs and the
    bbox head is called with ``cot=True``.

    Args:
        data (dict): A batch from the data loader; must contain
            ``'gt_bbox'`` (one box tensor per image).

    Returns:
        The first element returned by ``self.bbox_head`` (its predictions).
    """
    body_feats = self.backbone(data)
    if self.neck is not None:
        body_feats = self.neck(body_feats)
    rois = list(data['gt_bbox'])
    rois_num = paddle.concat([paddle.shape(roi)[0] for roi in rois])

    preds, _ = self.bbox_head(body_feats, rois, rois_num, None, cot=True)
    return preds


def relationship_learning(self, loader, num_classes_novel):
    """Estimate the average base-class response for every novel class.

    For each batch the ground-truth boxes are pushed through
    ``target_bbox_forward``; the resulting per-box probabilities (last
    column dropped) are averaged per ground-truth class.

    Args:
        loader: Iterable of batches; each batch provides ``'im_id'`` and
            ``'gt_class'`` (one label tensor per image).
        num_classes_novel (int): Number of novel classes. The result has at
            least this many rows.

    Returns:
        numpy.ndarray: Array of shape ``[num_rows, num_base_classes]`` where
        row ``c`` is the mean probability vector of the boxes labelled ``c``.
        Classes without any box get an all-zero row instead of NaN.

    Raises:
        ValueError: If ``loader`` yields no ground-truth boxes at all.
    """
    print('computing relationship')
    novel_label_chunks = []
    base_prob_chunks = []

    for data in loader:
        _, bbox_prob = self.target_bbox_forward(data)
        batch_size = data['im_id'].shape[0]
        for i in range(batch_size):
            # reshape(-1) accepts both [N, 1] and [N] label layouts.
            novel_label_chunks.append(
                data['gt_class'][i].numpy().reshape(-1))
        # bbox_prob already covers every box of the batch, so it is
        # collected once per batch; the last column is dropped.
        base_prob_chunks.append(bbox_prob.detach().numpy()[:, :-1])

    if not novel_label_chunks:
        raise ValueError(
            'relationship_learning needs at least one batch, '
            'but the loader is empty')

    labels = np.concatenate(novel_label_chunks, 0)
    probabilities = np.concatenate(base_prob_chunks, 0)
    if labels.size == 0:
        raise ValueError(
            'relationship_learning found no ground-truth boxes in the loader')

    # Always emit one row per configured novel class, even when the highest
    # class ids are missing from the few-shot training split.
    num_rows = max(int(num_classes_novel or 0), int(np.max(labels)) + 1)
    conditional = []
    for cls_id in range(num_rows):
        this_class = probabilities[labels == cls_id]
        if this_class.shape[0] == 0:
            # No sample for this class: a zero row contributes nothing
            # instead of propagating NaN from a mean over an empty slice.
            average = np.zeros(
                (1, probabilities.shape[1]), dtype=probabilities.dtype)
        else:
            average = np.mean(this_class, axis=0, keepdims=True)
        conditional.append(average)
    return np.concatenate(conditional)
import ppyoloe_contrast_head from .bbox_head import * from .mask_head import * @@ -63,3 +64,4 @@ from .fcosr_head import * from .ld_gfl_head import * from .ppyoloe_r_head import * from .yolof_head import * +from .ppyoloe_contrast_head import * \ No newline at end of file diff --git a/ppdet/modeling/heads/bbox_head.py b/ppdet/modeling/heads/bbox_head.py index e0041be371a23e31842165cd5e5a0e4d95265c8c..3ce47983cec3e7e5961be95e719a1187c7d7fc55 100644 --- a/ppdet/modeling/heads/bbox_head.py +++ b/ppdet/modeling/heads/bbox_head.py @@ -160,8 +160,8 @@ class XConvNormHead(nn.Layer): @register class BBoxHead(nn.Layer): - __shared__ = ['num_classes'] - __inject__ = ['bbox_assigner', 'bbox_loss'] + __shared__ = ['num_classes', 'use_cot'] + __inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot'] """ RCNN bbox head @@ -173,7 +173,10 @@ class BBoxHead(nn.Layer): box. with_pool (bool): Whether to use pooling for the RoI feature. num_classes (int): The number of classes - bbox_weight (List[float]): The weight to get the decode box + bbox_weight (List[float]): The weight to get the decode box + cot_classes (int): The number of base classes + loss_cot (object): The module of Label-cotuning + use_cot(bool): whether to use Label-cotuning """ def __init__(self, @@ -185,7 +188,10 @@ class BBoxHead(nn.Layer): num_classes=80, bbox_weight=[10., 10., 5., 5.], bbox_loss=None, - loss_normalize_pos=False): + loss_normalize_pos=False, + cot_classes=None, + loss_cot='COTLoss', + use_cot=False): super(BBoxHead, self).__init__() self.head = head self.roi_extractor = roi_extractor @@ -199,11 +205,29 @@ class BBoxHead(nn.Layer): self.bbox_loss = bbox_loss self.loss_normalize_pos = loss_normalize_pos - self.bbox_score = nn.Linear( - in_channel, - self.num_classes + 1, - weight_attr=paddle.ParamAttr(initializer=Normal( - mean=0.0, std=0.01))) + self.loss_cot = loss_cot + self.cot_relation = None + self.cot_classes = cot_classes + self.use_cot = use_cot + if use_cot: + self.cot_bbox_score = nn.Linear( 
+ in_channel, + self.num_classes + 1, + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.01))) + + self.bbox_score = nn.Linear( + in_channel, + self.cot_classes + 1, + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.01))) + self.cot_bbox_score.skip_quant = True + else: + self.bbox_score = nn.Linear( + in_channel, + self.num_classes + 1, + weight_attr=paddle.ParamAttr(initializer=Normal( + mean=0.0, std=0.01))) self.bbox_score.skip_quant = True self.bbox_delta = nn.Linear( @@ -215,6 +239,9 @@ class BBoxHead(nn.Layer): self.assigned_label = None self.assigned_rois = None + def init_cot_head(self, relationship): + self.cot_relation = relationship + @classmethod def from_config(cls, cfg, input_shape): roi_pooler = cfg['roi_extractor'] @@ -229,7 +256,7 @@ class BBoxHead(nn.Layer): 'in_channel': head.out_shape[0].channels } - def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None): + def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False): """ body_feats (list[Tensor]): Feature maps from backbone rois (list[Tensor]): RoIs generated from RPN module @@ -248,7 +275,11 @@ class BBoxHead(nn.Layer): feat = paddle.squeeze(feat, axis=[2, 3]) else: feat = bbox_feat - scores = self.bbox_score(feat) + if self.use_cot: + scores = self.cot_bbox_score(feat) + cot_scores = self.bbox_score(feat) + else: + scores = self.bbox_score(feat) deltas = self.bbox_delta(feat) if self.training: @@ -259,11 +290,19 @@ class BBoxHead(nn.Layer): rois, self.bbox_weight, loss_normalize_pos=self.loss_normalize_pos) + + if self.cot_relation is not None: + loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation) + loss.update(loss_cot) return loss, bbox_feat else: - pred = self.get_prediction(scores, deltas) + if cot: + pred = self.get_prediction(cot_scores, deltas) + else: + pred = self.get_prediction(scores, deltas) return pred, self.head + def get_loss(self, scores, deltas, diff --git 
@register
class PPYOLOEContrastHead(PPYOLOEHead):
    """PP-YOLOE head with an additional supervised-contrastive branch.

    On top of the classification and regression branches inherited from
    ``PPYOLOEHead``, one 3x3 convolution per FPN level projects the
    classification stem features to a 128-channel embedding. During training
    the embeddings of all anchors are passed to ``contrast_loss`` together
    with the assigned labels and scores.

    Args:
        contrast_loss (object): Injected loss module called as
            ``contrast_loss(features, labels, scores)``.
        loss_weight (dict): Weights of the loss terms. The ``'contrast'``
            entry is optional; when absent, 0.2 is used.
        All other arguments are forwarded positionally to ``PPYOLOEHead``.
    """
    __shared__ = [
        'num_classes', 'eval_size', 'trt', 'exclude_nms',
        'exclude_post_process', 'use_shared_conv'
    ]
    __inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss']

    # Used when ``loss_weight`` carries no 'contrast' entry (the default
    # ``loss_weight`` of this class does not), so training does not fail
    # with a KeyError on the first step.
    _DEFAULT_CONTRAST_WEIGHT = 0.2

    def __init__(self,
                 in_channels=[1024, 512, 256],
                 num_classes=80,
                 act='swish',
                 fpn_strides=(32, 16, 8),
                 grid_cell_scale=5.0,
                 grid_cell_offset=0.5,
                 reg_max=16,
                 reg_range=None,
                 static_assigner_epoch=4,
                 use_varifocal_loss=True,
                 static_assigner='ATSSAssigner',
                 assigner='TaskAlignedAssigner',
                 contrast_loss='SupContrast',
                 nms='MultiClassNMS',
                 eval_size=None,
                 loss_weight={
                     'class': 1.0,
                     'iou': 2.5,
                     'dfl': 0.5,
                 },
                 trt=False,
                 exclude_nms=False,
                 exclude_post_process=False,
                 use_shared_conv=True):
        super().__init__(in_channels,
                         num_classes,
                         act,
                         fpn_strides,
                         grid_cell_scale,
                         grid_cell_offset,
                         reg_max,
                         reg_range,
                         static_assigner_epoch,
                         use_varifocal_loss,
                         static_assigner,
                         assigner,
                         nms,
                         eval_size,
                         loss_weight,
                         trt,
                         exclude_nms,
                         exclude_post_process,
                         use_shared_conv)

        assert len(in_channels) > 0, "len(in_channels) should > 0"
        self.contrast_loss = contrast_loss
        # One embedding projection per FPN level.
        self.contrast_encoder = nn.LayerList()
        for in_c in self.in_channels:
            self.contrast_encoder.append(
                nn.Conv2D(
                    in_c, 128, 3, padding=1))
        self._init_contrast_encoder()

    def _init_contrast_encoder(self):
        """Initialise every contrast projection via ``constant_`` (weights)
        and ``bias_init_with_prob(0.01)`` (bias)."""
        bias_en = bias_init_with_prob(0.01)
        for en_ in self.contrast_encoder:
            constant_(en_.weight)
            constant_(en_.bias, bias_en)

    def forward_train(self, feats, targets):
        """Compute all training losses for a list of FPN feature maps.

        Args:
            feats (list[Tensor]): One feature map per FPN level.
            targets (dict): Ground-truth meta passed through to ``get_loss``.

        Returns:
            dict: The loss dictionary produced by ``get_loss``.
        """
        anchors, anchor_points, num_anchors_list, stride_tensor = \
            generate_anchors_for_grid_cell(
                feats, self.fpn_strides, self.grid_cell_scale,
                self.grid_cell_offset)

        cls_score_list, reg_distri_list = [], []
        contrast_encoder_list = []
        for i, feat in enumerate(feats):
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            # The classification and contrast branches consume the same
            # stem output, so it is evaluated once per level.
            cls_feat = self.stem_cls[i](feat, avg_feat) + feat
            cls_logit = self.pred_cls[i](cls_feat)
            reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
            contrast_logit = self.contrast_encoder[i](cls_feat)
            contrast_encoder_list.append(
                contrast_logit.flatten(2).transpose([0, 2, 1]))
            # cls and reg
            cls_score = F.sigmoid(cls_logit)
            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
            reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
        cls_score_list = paddle.concat(cls_score_list, axis=1)
        reg_distri_list = paddle.concat(reg_distri_list, axis=1)
        contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1)

        return self.get_loss([
            cls_score_list, reg_distri_list, contrast_encoder_list, anchors,
            anchor_points, num_anchors_list, stride_tensor
        ], targets)

    def get_loss(self, head_outs, gt_meta):
        """Assign labels and combine cls / iou / dfl / contrast losses.

        Args:
            head_outs (list): ``[pred_scores, pred_distri,
                pred_contrast_encoder, anchors, anchor_points,
                num_anchors_list, stride_tensor]`` from ``forward_train``.
            gt_meta (dict): Must provide ``'gt_class'``, ``'gt_bbox'``,
                ``'pad_gt_mask'`` and ``'epoch_id'``.

        Returns:
            dict: ``loss`` (weighted total) plus the individual terms
            ``loss_cls``, ``loss_iou``, ``loss_dfl``, ``loss_l1`` and
            ``loss_contrast``.
        """
        pred_scores, pred_distri, pred_contrast_encoder, anchors,\
        anchor_points, num_anchors_list, stride_tensor = head_outs

        anchor_points_s = anchor_points / stride_tensor
        pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)

        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        pad_gt_mask = gt_meta['pad_gt_mask']
        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores = \
                self.static_assigner(
                    anchors,
                    num_anchors_list,
                    gt_labels,
                    gt_bboxes,
                    pad_gt_mask,
                    bg_index=self.num_classes,
                    pred_bboxes=pred_bboxes.detach() * stride_tensor)
            alpha_l = 0.25
        else:
            # The two assigner call forms differ only in their fourth
            # positional argument.
            # NOTE(review): ``sm_use`` is not set in this class; presumably
            # it comes from PPYOLOEHead - confirm.
            fourth_arg = stride_tensor if self.sm_use else num_anchors_list
            assigned_labels, assigned_bboxes, assigned_scores = \
                self.assigner(
                    pred_scores.detach(),
                    pred_bboxes.detach() * stride_tensor,
                    anchor_points,
                    fourth_arg,
                    gt_labels,
                    gt_bboxes,
                    pad_gt_mask,
                    bg_index=self.num_classes)
            alpha_l = -1
        # rescale bbox
        assigned_bboxes /= stride_tensor
        # cls loss
        if self.use_varifocal_loss:
            one_hot_label = F.one_hot(assigned_labels,
                                      self.num_classes + 1)[..., :-1]
            loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
                                            one_hot_label)
        else:
            loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)

        assigned_scores_sum = assigned_scores.sum()
        if paddle.distributed.get_world_size() > 1:
            # Average the normaliser across workers.
            paddle.distributed.all_reduce(assigned_scores_sum)
            assigned_scores_sum /= paddle.distributed.get_world_size()
        assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
        loss_cls /= assigned_scores_sum

        loss_l1, loss_iou, loss_dfl = \
            self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
                            assigned_labels, assigned_bboxes, assigned_scores,
                            assigned_scores_sum)
        # contrast loss: one embedding, label and best score per anchor
        embed_dim = pred_contrast_encoder.shape[-1]
        loss_contrast = self.contrast_loss(
            pred_contrast_encoder.reshape([-1, embed_dim]),
            assigned_labels.reshape([-1]),
            assigned_scores.max(-1).reshape([-1]))

        contrast_weight = self.loss_weight.get(
            'contrast', self._DEFAULT_CONTRAST_WEIGHT)
        loss = self.loss_weight['class'] * loss_cls + \
               self.loss_weight['iou'] * loss_iou + \
               self.loss_weight['dfl'] * loss_dfl + \
               contrast_weight * loss_contrast

        out_dict = {
            'loss': loss,
            'loss_cls': loss_cls,
            'loss_iou': loss_iou,
            'loss_dfl': loss_dfl,
            'loss_l1': loss_l1,
            'loss_contrast': loss_contrast
        }
        return out_dict
@register
class COTLoss(nn.Layer):
    """Co-tuning classification loss.

    For every foreground sample, the loss is the cross entropy between a
    target distribution looked up in ``cot_relation`` (by the sample's label)
    and the softmax of the sample's scaled scores with the last score column
    removed.

    Args:
        num_classes (int): Number of foreground classes; a label equal to
            ``num_classes`` is treated as background.
        cot_scale (float): Factor applied to the scores before the softmax.
        cot_lambda (float): Weight applied to the final loss value.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 cot_scale=1,
                 cot_lambda=1):
        super(COTLoss, self).__init__()
        self.cot_scale = cot_scale
        self.cot_lambda = cot_lambda
        self.num_classes = num_classes

    def forward(self, scores, targets, cot_relation):
        """Compute the co-tuning loss.

        Args:
            scores: 2-D score tensor, one row per sample; the last column is
                dropped before the softmax.
            targets: Triple whose first element is a list of label tensors
                (concatenated when it holds more than one). The other two
                elements are not used here.
            cot_relation: Table indexed by an integer class id, yielding the
                target distribution for that class. Presumably shaped
                (num_classes, number of source categories) -- confirm
                against the caller.

        Returns:
            dict: ``{'loss_bbox_cls_cot': loss}``. The loss is a zero tensor
            of shape ``[1]`` when there is no foreground sample.
        """
        cls_name = 'loss_bbox_cls_cot'
        loss_bbox = {}

        tgt_labels, _, _ = targets
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]

        # A single mask drives both the emptiness check and the selection.
        # Previously the check looked at ``labels >= 0`` while the selection
        # used ``labels < num_classes``: a batch of only background labels
        # passed the check and produced the mean of an empty tensor (NaN),
        # and negative labels indexed ``cot_relation`` from the end.
        fg_mask = paddle.logical_and(tgt_labels >= 0,
                                     tgt_labels < self.num_classes)
        fg_inds = paddle.nonzero(fg_mask).flatten()
        if fg_inds.shape[0] == 0:
            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
            return loss_bbox

        # One host transfer instead of a tensor comparison per sample.
        fg_labels = tgt_labels[fg_mask].cast('int64').numpy().tolist()
        cot_targets = paddle.to_tensor(
            [cot_relation[label] for label in fg_labels])
        cot_targets.stop_gradient = True

        log_probs = F.log_softmax(
            scores[fg_mask][:, :-1] * self.cot_scale, axis=-1)
        per_sample = paddle.sum(-cot_targets * log_probs, axis=-1)
        loss_bbox[cls_name] = self.cot_lambda * paddle.mean(per_sample)
        return loss_bbox
@register
class SupContrast(nn.Layer):
    """Supervised contrastive loss over labelled feature vectors.

    All samples whose label is below ``num_classes`` are kept; samples whose
    label equals ``num_classes`` are randomly sub-sampled so that at most
    ``sample_num`` samples take part. The loss is averaged over the samples
    whose score exceeds ``thresh``.

    Args:
        num_classes (int): Number of foreground classes; a label equal to
            ``num_classes`` marks a negative sample.
        temperature (float): Divisor applied to the feature similarities.
        sample_num (int): Upper bound on positives plus sampled negatives.
        thresh (float): Only samples with ``score > thresh`` contribute.
    """
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 temperature=2.5,
                 sample_num=4096,
                 thresh=0.75):
        super(SupContrast, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.sample_num = sample_num
        self.thresh = thresh

    @staticmethod
    def _zero_loss(features):
        """Return a zero loss that is still connected to ``features``."""
        return features.sum() * 0.

    def forward(self, features, labels, scores):
        """Compute the contrastive loss.

        Args:
            features: 2-D tensor, one feature row per sample.
            labels: Per-sample labels, same leading size as ``features``.
            scores: Per-sample scores, same leading size as ``features``.

        Returns:
            Mean loss over the samples with ``score > thresh``, or a zero
            loss when fewer than two samples are available or no sample
            passes the threshold.
        """
        assert features.shape[0] == labels.shape[0] == scores.shape[0]
        all_features = features

        positive_mask = (labels < self.num_classes)
        positive_features = features[positive_mask]
        positive_labels = labels[positive_mask]
        positive_scores = scores[positive_mask]

        negative_mask = (labels == self.num_classes)
        negative_features = features[negative_mask]
        negative_labels = labels[negative_mask]
        negative_scores = scores[negative_mask]

        num_pos = positive_features.shape[0]
        num_neg = negative_features.shape[0]
        # random.sample raises ValueError for a negative count or a count
        # larger than the population, so clamp the request to [0, num_neg].
        num_sampled = max(0, min(num_neg, int(self.sample_num - num_pos)))

        # The contrast is undefined with fewer than two samples: the masked
        # softmax denominator would be zero.
        if num_pos + num_sampled < 2:
            return self._zero_loss(all_features)

        if num_sampled > 0:
            index = paddle.to_tensor(
                random.sample(range(num_neg), num_sampled), dtype='int32')
            negative_features = paddle.index_select(
                x=negative_features, index=index, axis=0)
            negative_labels = paddle.index_select(
                x=negative_labels, index=index, axis=0)
            negative_scores = paddle.index_select(
                x=negative_scores, index=index, axis=0)
            features = paddle.concat(
                [positive_features, negative_features], 0)
            labels = paddle.concat([positive_labels, negative_labels], 0)
            scores = paddle.concat([positive_scores, negative_scores], 0)
        else:
            features = positive_features
            labels = positive_labels
            scores = positive_scores

        if len(labels.shape) == 1:
            labels = labels.reshape([-1, 1])
        label_mask = paddle.equal(labels, labels.T).detach()
        similarity = (paddle.matmul(features, features.T) / self.temperature)

        # Subtract the row maximum for numerical stability of exp().
        sim_row_max = paddle.max(similarity, axis=1, keepdim=True)
        similarity = similarity - sim_row_max

        # Exclude every sample's similarity with itself.
        logits_mask = paddle.ones_like(similarity).detach()
        logits_mask.fill_diagonal_(0)

        exp_sim = paddle.exp(similarity) * logits_mask
        log_prob = similarity - paddle.log(exp_sim.sum(axis=1, keepdim=True))

        # NOTE(review): the denominator counts the sample itself (the
        # diagonal of label_mask) while the numerator excludes it. This is
        # kept as-is because it also prevents a division by zero -- confirm
        # it is the intended normalisation.
        per_label_log_prob = (log_prob * logits_mask *
                              label_mask).sum(1) / label_mask.sum(1)

        # Averaging an empty selection would yield NaN, so bail out early.
        keep_inds = paddle.nonzero(scores > self.thresh).flatten()
        if keep_inds.shape[0] == 0:
            return self._zero_loss(all_features)
        per_label_log_prob = paddle.index_select(
            x=per_label_log_prob, index=keep_inds, axis=0)
        loss = -per_label_log_prob

        return loss.mean()
from ppdet.utils.cli import ArgsParser, merge_args @@ -125,6 +127,7 @@ def run(FLAGS, cfg): if FLAGS.enable_ce: set_random_seed(0) + # build trainer ssod_method = cfg.get('ssod_method', None) if ssod_method is not None: if ssod_method == 'DenseTeacher': @@ -133,8 +136,9 @@ def run(FLAGS, cfg): raise ValueError( "Semi-Supervised Object Detection only support DenseTeacher now." ) + elif cfg.get('use_cot', False): + trainer = TrainerCot(cfg, mode='train') else: - # build trainer trainer = Trainer(cfg, mode='train') # load weights