Unverified commit a40d2d3a authored by Taojerxi, committed by GitHub

Few shot (#7503)

* finish_cotuning

* fix_bug

* add_readme

* finish contrast
Parent 194ad728
# Co-tuning for Transfer Learning & Supervised Contrastive Learning
## Data preparation
This section uses the [Kaggle road-sign-detection](https://www.kaggle.com/andrewmvd/road-sign-detection) competition data as an example of how to prepare a custom few-shot dataset.
The [road-sign-detection](https://www.kaggle.com/andrewmvd/road-sign-detection) dataset on Kaggle contains 877 images in 4 classes: crosswalk, speedlimit, stop, and trafficlight.
It can be downloaded from Kaggle or from this [mirror link](https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar).
For few-shot training, sample the same number of samples per class from the original dataset (for example, "10 shots" means ten training samples per class); a minimal sampling sketch is shown below.<br />
The industrial dataset is PKU-Market-PCB, used for defect detection on printed circuit boards (PCBs); it covers 6 common PCB defect types. See the [download link](./configs/ppyoloe/application/README.md).
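For reference, here is a minimal sketch of how a per-class N-shot annotation file such as `train_shots10.json` could be produced. It assumes COCO-format annotations; the file paths, the helper name `build_few_shot_subset`, and the greedy image-level sampling strategy are illustrative and not part of this repository.
```python
import json
import random
from collections import defaultdict

def build_few_shot_subset(anno_path, out_path, shots=10, seed=0):
    """Greedily pick images until every class has roughly `shots` instances."""
    with open(anno_path) as f:
        coco = json.load(f)

    # group annotations by image
    anns_per_img = defaultdict(list)
    for ann in coco['annotations']:
        anns_per_img[ann['image_id']].append(ann)

    random.seed(seed)
    image_ids = list(anns_per_img.keys())
    random.shuffle(image_ids)

    picked, per_class = set(), defaultdict(int)
    for img_id in image_ids:
        cats = {a['category_id'] for a in anns_per_img[img_id]}
        # take the image only if every class it contains still needs samples
        if all(per_class[c] < shots for c in cats):
            picked.add(img_id)
            for a in anns_per_img[img_id]:
                per_class[a['category_id']] += 1

    subset = {
        'images': [im for im in coco['images'] if im['id'] in picked],
        'annotations': [a for a in coco['annotations'] if a['image_id'] in picked],
        'categories': coco['categories'],
    }
    with open(out_path, 'w') as f:
        json.dump(subset, f)

build_few_shot_subset('dataset/roadsign_coco/annotations/train.json',
                      'dataset/roadsign_coco/annotations/train_shots10.json', shots=10)
```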
## Model Zoo
| Backbone | Model | Images/GPU | Shots per class | Box AP | Download | Config |
| :------------------- | :------------- | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| ResNet50-vd | Faster R-CNN | 1 | 10 | 60.1 | [model](https://bj.bcebos.com/v1/paddledet/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams) | [config](./faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml) |
| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | 17.8 | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_s_80e_contrast_pcb.pdparams) | [config](./ppyoloe_plus_crn_s_80e_contrast_pcb.yml) |
## Co-tuning comparison
| Backbone | Model | Images/GPU | Shots per class | Co-tuning | Box AP |
| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: |
| ResNet50-vd | Faster | 1 | 10 | False | 56.7 |
| ResNet50-vd | Faster | 1 | 10 | True | 60.1 |
## Contrastive-learning comparison
| Backbone | Model | Images/GPU | Shots per class | Contrast | Box AP |
| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: |
| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | False | 15.4 |
| PPYOLOE_crn_s | PPYOLOE | 1 | 30 | True | 17.8 |
## Citations
```
@article{you2020co,
title={Co-tuning for transfer learning},
author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin},
journal={Advances in Neural Information Processing Systems},
volume={33},
pages={17236--17246},
year={2020}
}
@article{khosla2020supervised,
title={Supervised contrastive learning},
author={Khosla, Prannay and Teterwak, Piotr and Wang, Chen and Sarna, Aaron and Tian, Yonglong and Isola, Phillip and Maschinot, Aaron and Liu, Ce and Krishnan, Dilip},
journal={Advances in Neural Information Processing Systems},
volume={33},
pages={18661--18673},
year={2020}
}
```
\ No newline at end of file
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: true
drop_last: true
collate_batch: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
TestReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
shuffle: false
drop_last: false
architecture: FasterRCNN
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
FasterRCNN:
backbone: ResNet
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [2]
num_stages: 3
RPNHead:
anchor_generator:
aspect_ratios: [0.5, 1.0, 2.0]
anchor_sizes: [32, 64, 128, 256, 512]
strides: [16]
rpn_target_assign:
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
use_random: True
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 12000
post_nms_top_n: 2000
topk_after_collect: False
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 6000
post_nms_top_n: 1000
BBoxHead:
head: Res5Head
roi_extractor:
resolution: 14
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
with_pool: true
BBoxAssigner:
batch_size_per_im: 512
bg_thresh: 0.5
fg_thresh: 0.5
fg_fraction: 0.25
use_random: True
BBoxPostProcess:
decode: RCNNBox
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
architecture: FasterRCNN
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
FasterRCNN:
backbone: ResNet
neck: FPN
rpn_head: RPNHead
bbox_head: BBoxHead
# post process
bbox_post_process: BBoxPostProcess
ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
FPN:
out_channel: 256
RPNHead:
anchor_generator:
aspect_ratios: [0.5, 1.0, 2.0]
anchor_sizes: [[32], [64], [128], [256], [512]]
strides: [4, 8, 16, 32, 64]
rpn_target_assign:
batch_size_per_im: 256
fg_fraction: 0.5
negative_overlap: 0.3
positive_overlap: 0.7
use_random: True
train_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 2000
post_nms_top_n: 1000
topk_after_collect: True
test_proposal:
min_size: 0.0
nms_thresh: 0.7
pre_nms_top_n: 1000
post_nms_top_n: 1000
BBoxHead:
head: TwoFCHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
BBoxAssigner:
batch_size_per_im: 512
bg_thresh: 0.5
fg_thresh: 0.5
fg_fraction: 0.25
use_random: True
TwoFCHead:
out_channel: 1024
BBoxPostProcess:
decode: RCNNBox
nms:
name: MultiClassNMS
keep_top_k: 100
score_threshold: 0.05
nms_threshold: 0.5
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], interp: 2, keep_ratio: True}
- RandomFlip: {prob: 0.5}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: -1}
batch_size: 1
shuffle: true
drop_last: true
collate_batch: false
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: -1}
batch_size: 1
shuffle: false
drop_last: false
TestReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: -1}
batch_size: 1
shuffle: false
drop_last: false
epoch: 12
LearningRate:
base_lr: 0.01
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.1
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
epoch: 80
LearningRate:
base_lr: 0.001
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 5
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
architecture: YOLOv3
norm_type: sync_bn
use_ema: true
use_cot: False
ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEHead
post_process: ~
CSPResNet:
layers: [3, 6, 6, 3]
channels: [64, 128, 256, 512, 1024]
return_idx: [1, 2, 3]
use_large_stem: True
use_alpha: True
CustomCSPPAN:
out_channels: [768, 384, 192]
stage_num: 1
block_num: 3
act: 'swish'
spp: true
PPYOLOEHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 30
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
worker_num: 4
eval_height: &eval_height 640
eval_width: &eval_width 640
eval_size: &eval_size [*eval_height, *eval_width]
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
- PadGT: {}
batch_size: 8
shuffle: true
drop_last: true
use_shared_memory: true
collate_batch: true
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 2
TestReader:
inputs_def:
image_shape: [3, *eval_height, *eval_width]
sample_transforms:
- Decode: {}
- Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/faster_rcnn_r50_fpn.yml',
'_base_/faster_fpn_reader.yml',
]
pretrain_weights: https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_1x_coco.pdparams
weights: output/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign/model_final
snapshot_epoch: 5
ResNet:
# index 0 stands for res2
depth: 50
variant: d
norm_type: bn
freeze_at: 0
return_idx: [0,1,2,3]
num_stages: 4
epoch: 30
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.1
steps: 1000
use_cot: True
BBoxHead:
head: TwoFCHead
roi_extractor:
resolution: 7
sampling_ratio: 0
aligned: True
bbox_assigner: BBoxAssigner
cot_classes: 80
loss_cot:
name: COTLoss
cot_lambda: 1
cot_scale: 1
num_classes: 4
metric: COCO
map_type: integral
TrainDataset:
!COCODataSet
image_dir: images
anno_path: annotations/train_shots10.json
dataset_dir: dataset/roadsign_coco
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
image_dir: images
anno_path: annotations/roadsign_valid.json
dataset_dir: dataset/roadsign_coco
TestDataset:
!ImageFolder
anno_path: annotations/roadsign_valid.json
dataset_dir: dataset/roadsign_coco
\ No newline at end of file
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'./_base_/optimizer_80e.yml',
'./_base_/ppyoloe_plus_crn.yml',
'./_base_/ppyoloe_plus_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_plus_crn_s_80e_contrast_pcb/model_final
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
depth_mult: 0.33
width_mult: 0.50
epoch: 80
LearningRate:
base_lr: 0.0001
schedulers:
- !CosineDecay
max_epochs: 96
- !LinearWarmup
start_factor: 0.
epochs: 5
YOLOv3:
backbone: CSPResNet
neck: CustomCSPPAN
yolo_head: PPYOLOEContrastHead
post_process: ~
PPYOLOEContrastHead:
fpn_strides: [32, 16, 8]
grid_cell_scale: 5.0
grid_cell_offset: 0.5
static_assigner_epoch: 100
use_varifocal_loss: True
loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5, contrast: 0.2}
static_assigner:
name: ATSSAssigner
topk: 9
assigner:
name: TaskAlignedAssigner
topk: 13
alpha: 1.0
beta: 6.0
contrast_loss:
name: SupContrast
temperature: 100
sample_num: 2048
thresh: 0.75
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 300
score_threshold: 0.01
nms_threshold: 0.7
num_classes: 6
metric: COCO
map_type: integral
TrainDataset:
!COCODataSet
image_dir: images
anno_path: pcb_cocoanno/train_shots30.json
dataset_dir: dataset/pcb
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
image_dir: images
anno_path: pcb_cocoanno/val.json
dataset_dir: dataset/pcb
TestDataset:
!ImageFolder
anno_path: pcb_cocoanno/val.json
dataset_dir: dataset/pcb
\ No newline at end of file
......@@ -15,6 +15,9 @@
from . import trainer
from .trainer import *
from . import trainer_cot
from .trainer_cot import *
from . import callbacks
from .callbacks import *
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ppdet.core.workspace import create
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

from . import Trainer
__all__ = ['TrainerCot']


class TrainerCot(Trainer):
    """
    Trainer for label co-tuning: estimates the relationship between the
    base classes of the pretrained model and the novel classes.
    """

    def __init__(self, cfg, mode='train'):
        super(TrainerCot, self).__init__(cfg, mode)
        self.cotuning_init()

    def cotuning_init(self):
        num_classes_novel = self.cfg['num_classes']
        # load the base-class pretrained weights and run the model in eval mode
        # over the training loader to estimate the base->novel class relationship
        self.load_weights(self.cfg.pretrain_weights)
        self.model.eval()
        relationship = self.model.relationship_learning(self.loader, num_classes_novel)
        # initialize the co-tuning head with the relationship, then rebuild the
        # optimizer so the newly created parameters are registered
        self.model.init_cot_head(relationship)
        self.optimizer = create('OptimizerBuilder')(self.lr, self.model)
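For orientation, a minimal usage sketch follows. It mirrors the `tools/train.py` change at the end of this diff; the config path is only illustrative.
```python
# Minimal sketch: launching co-tuning training programmatically.
# Assumes a config that sets `use_cot: True` and provides `pretrain_weights`
# for the base-class (e.g. COCO-pretrained) detector; the path is illustrative.
from ppdet.core.workspace import load_config
from ppdet.engine import TrainerCot

cfg = load_config('configs/few-shot/faster_rcnn_r50_vd_fpn_1x_coco_cotuning_roadsign.yml')
trainer = TrainerCot(cfg, mode='train')  # cotuning_init() runs here: load pretrained weights,
                                         # estimate the base->novel relationship, rebuild optimizer
trainer.train()
```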
......@@ -19,6 +19,7 @@ from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import numpy as np
__all__ = ['FasterRCNN']
......@@ -51,6 +52,9 @@ class FasterRCNN(BaseArch):
self.bbox_head = bbox_head
self.bbox_post_process = bbox_post_process
def init_cot_head(self, relationship):
self.bbox_head.init_cot_head(relationship)
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
......@@ -104,3 +108,38 @@ class FasterRCNN(BaseArch):
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
    def target_bbox_forward(self, data):
        # run the RoI head directly on the ground-truth boxes so that the
        # pretrained (base-class) classifier scores every annotated object
        body_feats = self.backbone(data)
        if self.neck is not None:
            body_feats = self.neck(body_feats)
        rois = [roi for roi in data['gt_bbox']]
        rois_num = paddle.concat([paddle.shape(roi)[0] for roi in rois])

        preds, _ = self.bbox_head(body_feats, rois, rois_num, None, cot=True)
        return preds

    def relationship_learning(self, loader, num_classes_novel):
        print('computing relationship')
        train_labels_list = []
        label_list = []

        for step_id, data in enumerate(loader):
            _, bbox_prob = self.target_bbox_forward(data)
            batch_size = data['im_id'].shape[0]
            for i in range(batch_size):
                num_bbox = data['gt_class'][i].shape[0]
                train_labels = data['gt_class'][i]
                train_labels_list.append(train_labels.numpy().squeeze(1))
            # drop the last (background) column of the base-class probabilities
            base_labels = bbox_prob.detach().numpy()[:, :-1]
            label_list.append(base_labels)

        labels = np.concatenate(train_labels_list, 0)
        probabilities = np.concatenate(label_list, 0)

        # for each novel class, average the base-class probability vectors of
        # its ground-truth boxes to obtain the class-relationship matrix
        N_t = np.max(labels) + 1
        conditional = []
        for i in range(N_t):
            this_class = probabilities[labels == i]
            average = np.mean(this_class, axis=0, keepdims=True)
            conditional.append(average)
        return np.concatenate(conditional)
\ No newline at end of file
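To make the output of `relationship_learning` concrete, here is a toy NumPy illustration with invented numbers: for each novel class, the resulting row is the average base-class probability vector of that class's ground-truth boxes.
```python
# Toy illustration of relationship_learning: 3 base classes, 2 novel classes.
import numpy as np

# base-class probabilities predicted on ground-truth boxes (one row per box)
probabilities = np.array([[0.7, 0.2, 0.1],   # box of novel class 0
                          [0.6, 0.3, 0.1],   # box of novel class 0
                          [0.1, 0.1, 0.8],   # box of novel class 1
                          [0.2, 0.2, 0.6]])  # box of novel class 1
labels = np.array([0, 0, 1, 1])              # novel-class label of each box

conditional = []
for i in range(labels.max() + 1):
    conditional.append(probabilities[labels == i].mean(axis=0, keepdims=True))
relationship = np.concatenate(conditional)   # shape: [num_novel_classes, num_base_classes]
print(relationship)
# [[0.65 0.25 0.1 ]
#  [0.15 0.15 0.7 ]]
```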
......@@ -37,6 +37,7 @@ from . import fcosr_head
from . import ppyoloe_r_head
from . import ld_gfl_head
from . import yolof_head
from . import ppyoloe_contrast_head
from .bbox_head import *
from .mask_head import *
......@@ -63,3 +64,4 @@ from .fcosr_head import *
from .ld_gfl_head import *
from .ppyoloe_r_head import *
from .yolof_head import *
from .ppyoloe_contrast_head import *
\ No newline at end of file
......@@ -160,8 +160,8 @@ class XConvNormHead(nn.Layer):
@register
class BBoxHead(nn.Layer):
__shared__ = ['num_classes']
__inject__ = ['bbox_assigner', 'bbox_loss']
__shared__ = ['num_classes', 'use_cot']
__inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot']
"""
RCNN bbox head
......@@ -173,7 +173,10 @@ class BBoxHead(nn.Layer):
box.
with_pool (bool): Whether to use pooling for the RoI feature.
num_classes (int): The number of classes
bbox_weight (List[float]): The weight to get the decode box
bbox_weight (List[float]): The weight to get the decode box
cot_classes (int): The number of base classes used for co-tuning
loss_cot (object): The label co-tuning loss module
use_cot (bool): Whether to use label co-tuning
"""
def __init__(self,
......@@ -185,7 +188,10 @@ class BBoxHead(nn.Layer):
num_classes=80,
bbox_weight=[10., 10., 5., 5.],
bbox_loss=None,
loss_normalize_pos=False):
loss_normalize_pos=False,
cot_classes=None,
loss_cot='COTLoss',
use_cot=False):
super(BBoxHead, self).__init__()
self.head = head
self.roi_extractor = roi_extractor
......@@ -199,11 +205,29 @@ class BBoxHead(nn.Layer):
self.bbox_loss = bbox_loss
self.loss_normalize_pos = loss_normalize_pos
self.bbox_score = nn.Linear(
in_channel,
self.num_classes + 1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.01)))
self.loss_cot = loss_cot
self.cot_relation = None
self.cot_classes = cot_classes
self.use_cot = use_cot
if use_cot:
self.cot_bbox_score = nn.Linear(
in_channel,
self.num_classes + 1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.01)))
self.bbox_score = nn.Linear(
in_channel,
self.cot_classes + 1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.01)))
self.cot_bbox_score.skip_quant = True
else:
self.bbox_score = nn.Linear(
in_channel,
self.num_classes + 1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.01)))
self.bbox_score.skip_quant = True
self.bbox_delta = nn.Linear(
......@@ -215,6 +239,9 @@ class BBoxHead(nn.Layer):
self.assigned_label = None
self.assigned_rois = None
def init_cot_head(self, relationship):
self.cot_relation = relationship
@classmethod
def from_config(cls, cfg, input_shape):
roi_pooler = cfg['roi_extractor']
......@@ -229,7 +256,7 @@ class BBoxHead(nn.Layer):
'in_channel': head.out_shape[0].channels
}
def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False):
"""
body_feats (list[Tensor]): Feature maps from backbone
rois (list[Tensor]): RoIs generated from RPN module
......@@ -248,7 +275,11 @@ class BBoxHead(nn.Layer):
feat = paddle.squeeze(feat, axis=[2, 3])
else:
feat = bbox_feat
scores = self.bbox_score(feat)
if self.use_cot:
scores = self.cot_bbox_score(feat)
cot_scores = self.bbox_score(feat)
else:
scores = self.bbox_score(feat)
deltas = self.bbox_delta(feat)
if self.training:
......@@ -259,11 +290,19 @@ class BBoxHead(nn.Layer):
rois,
self.bbox_weight,
loss_normalize_pos=self.loss_normalize_pos)
if self.cot_relation is not None:
loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation)
loss.update(loss_cot)
return loss, bbox_feat
else:
pred = self.get_prediction(scores, deltas)
if cot:
pred = self.get_prediction(cot_scores, deltas)
else:
pred = self.get_prediction(scores, deltas)
return pred, self.head
def get_loss(self,
scores,
deltas,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..bbox_utils import batch_distance2bbox
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead
__all__ = ['PPYOLOEContrastHead']
@register
class PPYOLOEContrastHead(PPYOLOEHead):
__shared__ = [
'num_classes', 'eval_size', 'trt', 'exclude_nms',
'exclude_post_process', 'use_shared_conv'
]
__inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss']
def __init__(self,
in_channels=[1024, 512, 256],
num_classes=80,
act='swish',
fpn_strides=(32, 16, 8),
grid_cell_scale=5.0,
grid_cell_offset=0.5,
reg_max=16,
reg_range=None,
static_assigner_epoch=4,
use_varifocal_loss=True,
static_assigner='ATSSAssigner',
assigner='TaskAlignedAssigner',
contrast_loss='SupContrast',
nms='MultiClassNMS',
eval_size=None,
loss_weight={
'class': 1.0,
'iou': 2.5,
'dfl': 0.5,
},
trt=False,
exclude_nms=False,
exclude_post_process=False,
use_shared_conv=True):
super().__init__(in_channels,
num_classes,
act,
fpn_strides,
grid_cell_scale,
grid_cell_offset,
reg_max,
reg_range,
static_assigner_epoch,
use_varifocal_loss,
static_assigner,
assigner,
nms,
eval_size,
loss_weight,
trt,
exclude_nms,
exclude_post_process,
use_shared_conv)
assert len(in_channels) > 0, "len(in_channels) should > 0"
self.contrast_loss = contrast_loss
self.contrast_encoder = nn.LayerList()
for in_c in self.in_channels:
self.contrast_encoder.append(
nn.Conv2D(
in_c, 128, 3, padding=1))
self._init_contrast_encoder()
def _init_contrast_encoder(self):
bias_en = bias_init_with_prob(0.01)
for en_ in self.contrast_encoder:
constant_(en_.weight)
constant_(en_.bias, bias_en)
def forward_train(self, feats, targets):
anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale,
self.grid_cell_offset)
cls_score_list, reg_distri_list = [], []
contrast_encoder_list = []
for i, feat in enumerate(feats):
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat)
reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
contrast_logit = self.contrast_encoder[i](self.stem_cls[i](feat, avg_feat) +
feat)
contrast_encoder_list.append(contrast_logit.flatten(2).transpose([0, 2, 1]))
# cls and reg
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
reg_distri_list.append(reg_distri.flatten(2).transpose([0, 2, 1]))
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_distri_list = paddle.concat(reg_distri_list, axis=1)
contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1)
return self.get_loss([
cls_score_list, reg_distri_list, contrast_encoder_list, anchors, anchor_points,
num_anchors_list, stride_tensor
], targets)
def get_loss(self, head_outs, gt_meta):
pred_scores, pred_distri, pred_contrast_encoder, anchors,\
anchor_points, num_anchors_list, stride_tensor = head_outs
anchor_points_s = anchor_points / stride_tensor
pred_bboxes = self._bbox_decode(anchor_points_s, pred_distri)
gt_labels = gt_meta['gt_class']
gt_bboxes = gt_meta['gt_bbox']
pad_gt_mask = gt_meta['pad_gt_mask']
# label assignment
if gt_meta['epoch_id'] < self.static_assigner_epoch:
assigned_labels, assigned_bboxes, assigned_scores = \
self.static_assigner(
anchors,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes,
pred_bboxes=pred_bboxes.detach() * stride_tensor)
alpha_l = 0.25
else:
if self.sm_use:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
stride_tensor,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
else:
assigned_labels, assigned_bboxes, assigned_scores = \
self.assigner(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor,
anchor_points,
num_anchors_list,
gt_labels,
gt_bboxes,
pad_gt_mask,
bg_index=self.num_classes)
alpha_l = -1
# rescale bbox
assigned_bboxes /= stride_tensor
# cls loss
if self.use_varifocal_loss:
one_hot_label = F.one_hot(assigned_labels,
self.num_classes + 1)[..., :-1]
loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
one_hot_label)
else:
loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
assigned_scores_sum = assigned_scores.sum()
if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(assigned_scores_sum)
assigned_scores_sum /= paddle.distributed.get_world_size()
assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
loss_cls /= assigned_scores_sum
loss_l1, loss_iou, loss_dfl = \
self._bbox_loss(pred_distri, pred_bboxes, anchor_points_s,
assigned_labels, assigned_bboxes, assigned_scores,
assigned_scores_sum)
# contrast loss
loss_contrast = self.contrast_loss(pred_contrast_encoder.reshape([-1, pred_contrast_encoder.shape[-1]]), \
assigned_labels.reshape([-1]), assigned_scores.max(-1).reshape([-1]))
loss = self.loss_weight['class'] * loss_cls + \
self.loss_weight['iou'] * loss_iou + \
self.loss_weight['dfl'] * loss_dfl + \
self.loss_weight['contrast'] * loss_contrast
out_dict = {
'loss': loss,
'loss_cls': loss_cls,
'loss_iou': loss_iou,
'loss_dfl': loss_dfl,
'loss_l1': loss_l1,
'loss_contrast': loss_contrast
}
return out_dict
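A brief shape walk-through of the contrast branch may help; it assumes a 640×640 input, batch size 8, and the strides [32, 16, 8] used above (an illustration, not repository code).
```python
# Illustrative shape check for the contrast branch.
# Each FPN level contributes H*W anchor-center embeddings of dimension 128.
strides = [32, 16, 8]
input_size = 640
anchors_per_level = [(input_size // s) ** 2 for s in strides]   # [400, 1600, 6400]
total_anchors = sum(anchors_per_level)                          # 8400

batch_size = 8
# contrast_encoder output, after flatten(2).transpose([0, 2, 1]) and concat over levels:
#   pred_contrast_encoder: [batch_size, total_anchors, 128]
# get_loss reshapes it to [batch_size * total_anchors, 128] and pairs it with the
# flattened assigned labels/scores before calling SupContrast.
print(anchors_per_level, total_anchors, (batch_size * total_anchors, 128))
```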
......@@ -28,6 +28,8 @@ from . import sparsercnn_loss
from . import focal_loss
from . import smooth_l1_loss
from . import probiou_loss
from . import cot_loss
from . import supcontrast
from .yolo_loss import *
from .iou_aware_loss import *
......@@ -46,3 +48,5 @@ from .focal_loss import *
from .smooth_l1_loss import *
from .pose3d_loss import *
from .probiou_loss import *
from .cot_loss import *
from .supcontrast import *
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from ppdet.core.workspace import register
__all__ = ['COTLoss']
@register
class COTLoss(nn.Layer):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 cot_scale=1,
                 cot_lambda=1):
        super(COTLoss, self).__init__()
        self.cot_scale = cot_scale
        self.cot_lambda = cot_lambda
        self.num_classes = num_classes

    def forward(self, scores, targets, cot_relation):
        cls_name = 'loss_bbox_cls_cot'
        loss_bbox = {}

        tgt_labels, tgt_bboxes, tgt_gt_inds = targets
        tgt_labels = paddle.concat(tgt_labels) if len(
            tgt_labels) > 1 else tgt_labels[0]
        mask = (tgt_labels < self.num_classes)
        valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
        if valid_inds.shape[0] == 0:
            loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
        else:
            tgt_labels = tgt_labels.cast('int64')
            # look up the soft base-class target for every foreground RoI
            valid_cot_targets = []
            for i in range(tgt_labels.shape[0]):
                train_label = tgt_labels[i]
                if train_label < self.num_classes:
                    valid_cot_targets.append(cot_relation[train_label])
            coco_targets = paddle.to_tensor(valid_cot_targets)
            coco_targets.stop_gradient = True
            # cross-entropy between the soft targets and the base-class logits
            # (the trailing background column is dropped)
            coco_loss = -coco_targets * F.log_softmax(scores[mask][:, :-1] * self.cot_scale)
            loss_bbox[cls_name] = self.cot_lambda * paddle.mean(paddle.sum(coco_loss, axis=-1))
        return loss_bbox
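As a toy illustration of the math above (invented numbers): the co-tuning target of an RoI labeled with novel class `c` is row `c` of the relationship matrix, and the loss is the cross-entropy between that soft target and the log-softmax of the base-class logits, with the background column dropped just as in `scores[mask][:, :-1]`.
```python
# Toy numbers only; mirrors the math of COTLoss.forward for a single RoI.
import numpy as np

def log_softmax(x):
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

relationship = np.array([[0.65, 0.25, 0.10],     # novel class 0 -> base-class distribution
                         [0.15, 0.15, 0.70]])    # novel class 1 -> base-class distribution
base_logits = np.array([2.0, 0.5, -1.0, 0.0])    # cot_classes + 1 outputs; last is background

novel_label = 0
target = relationship[novel_label]                       # soft target over base classes
loss = -(target * log_softmax(base_logits[:-1])).sum()   # cot_scale = cot_lambda = 1
print(round(loss, 3))                                    # -> 0.916
```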
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import random
from ppdet.core.workspace import register
__all__ = ['SupContrast']
@register
class SupContrast(nn.Layer):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=80, temperature=2.5, sample_num=4096, thresh=0.75):
        super(SupContrast, self).__init__()
        self.num_classes = num_classes
        self.temperature = temperature
        self.sample_num = sample_num
        self.thresh = thresh

    def forward(self, features, labels, scores):
        assert features.shape[0] == labels.shape[0] == scores.shape[0]

        # split anchor embeddings into foreground (assigned to a class) and background
        positive_mask = (labels < self.num_classes)
        positive_features, positive_labels, positive_scores = features[positive_mask], labels[positive_mask], \
            scores[positive_mask]

        negative_mask = (labels == self.num_classes)
        negative_features, negative_labels, negative_scores = features[negative_mask], labels[negative_mask], \
            scores[negative_mask]

        # keep all foreground embeddings plus a random subset of background ones,
        # so that sample_num embeddings in total enter the contrastive loss
        N = negative_features.shape[0]
        S = self.sample_num - positive_mask.sum()
        index = paddle.to_tensor(random.sample(range(N), int(S)), dtype='int32')

        negative_features = paddle.index_select(x=negative_features, index=index, axis=0)
        negative_labels = paddle.index_select(x=negative_labels, index=index, axis=0)
        negative_scores = paddle.index_select(x=negative_scores, index=index, axis=0)

        features = paddle.concat([positive_features, negative_features], 0)
        labels = paddle.concat([positive_labels, negative_labels], 0)
        scores = paddle.concat([positive_scores, negative_scores], 0)

        if len(labels.shape) == 1:
            labels = labels.reshape([-1, 1])
        label_mask = paddle.equal(labels, labels.T).detach()
        similarity = (paddle.matmul(features, features.T) / self.temperature)

        # subtract the row-wise max for numerical stability
        sim_row_max = paddle.max(similarity, axis=1, keepdim=True)
        similarity = similarity - sim_row_max

        # mask out self-similarity on the diagonal
        logits_mask = paddle.ones_like(similarity).detach()
        logits_mask.fill_diagonal_(0)

        exp_sim = paddle.exp(similarity) * logits_mask
        log_prob = similarity - paddle.log(exp_sim.sum(axis=1, keepdim=True))
        per_label_log_prob = (log_prob * logits_mask * label_mask).sum(1) / label_mask.sum(1)

        # only anchors whose assigned score exceeds thresh contribute to the loss
        keep = scores > self.thresh
        per_label_log_prob = per_label_log_prob[keep]
        loss = -per_label_log_prob
        return loss.mean()
\ No newline at end of file
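For intuition, here is a toy NumPy re-derivation of the supervised contrastive term (four embeddings, two classes). It only illustrates the formula; it is not the paddle module above.
```python
# Toy supervised-contrastive loss on 4 L2-normalized embeddings, 2 classes.
import numpy as np

def sup_contrast(features, labels, temperature=0.5):
    sim = features @ features.T / temperature
    sim = sim - sim.max(axis=1, keepdims=True)           # numerical stability
    logits_mask = 1.0 - np.eye(len(labels))               # exclude self-similarity
    label_mask = (labels[:, None] == labels[None, :]).astype(float)
    exp_sim = np.exp(sim) * logits_mask
    log_prob = sim - np.log(exp_sim.sum(axis=1, keepdims=True))
    per_sample = (log_prob * logits_mask * label_mask).sum(1) / label_mask.sum(1)
    return -per_sample.mean()

feats = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
feats /= np.linalg.norm(feats, axis=1, keepdims=True)
labels = np.array([0, 0, 1, 1])
print(sup_contrast(feats, labels))   # pulling same-class embeddings together lowers this value
```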
......@@ -30,8 +30,10 @@ warnings.filterwarnings('ignore')
import paddle
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Trainer, init_parallel_env, set_random_seed, init_fleet_env
from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env
from ppdet.engine.trainer_ssod import Trainer_DenseTeacher
from ppdet.slim import build_slim_model
from ppdet.utils.cli import ArgsParser, merge_args
......@@ -125,6 +127,7 @@ def run(FLAGS, cfg):
if FLAGS.enable_ce:
set_random_seed(0)
# build trainer
ssod_method = cfg.get('ssod_method', None)
if ssod_method is not None:
if ssod_method == 'DenseTeacher':
......@@ -133,8 +136,9 @@ def run(FLAGS, cfg):
raise ValueError(
"Semi-Supervised Object Detection only support DenseTeacher now."
)
elif cfg.get('use_cot', False):
trainer = TrainerCot(cfg, mode='train')
else:
# build trainer
trainer = Trainer(cfg, mode='train')
# load weights
......