Unverified · Commit 1e211871 · authored by W wjm, committed by GitHub

Support ARSL(CVPR2023) for semi-supervised object detection (#7980)

* add SSOD_asrl

* modify traniner name

* add modelzoo

* add config

* add config

* add config

* modify cfg name

* modify cfg

* modify cfg

* modify checkpoint

* modify cfg

* add voc and lsj

* add voc and lsj

* del export

* modify

* modify

* refine codes

* fix fcos_head get_loss

* add export

* fix bug

* add export infer

* change

* retry

* fix eval infer

---------
Co-authored-by: nemonameless <nemonameless@qq.com>
Parent a18c8113
...@@ -7,6 +7,7 @@
- [Model Zoo](#模型库)
  - [Baseline](#Baseline)
  - [DenseTeacher](#DenseTeacher)
  - [ARSL](#ARSL)
- [Semi-Supervised Dataset Preparation](#半监督数据集准备)
- [Semi-Supervised Detection Configuration](#半监督检测配置)
  - [Training Set Configuration](#训练集配置)
...@@ -23,7 +24,7 @@
- [Citation](#引用)
## Introduction
-Semi-supervised object detection (Semi DET) trains a detector **with both labeled and unlabeled data**, which greatly reduces annotation cost and exploits unlabeled data to further improve detection accuracy. The PaddleDetection team has reproduced the [DenseTeacher](denseteacher) semi-supervised detection algorithm, which users can download and use.
Semi-supervised object detection (Semi DET) trains a detector **with both labeled and unlabeled data**, which greatly reduces annotation cost and exploits unlabeled data to further improve detection accuracy. The PaddleDetection team provides state-of-the-art semi-supervised detection algorithms such as [DenseTeacher](denseteacher/) and [ARSL](arsl/), which users can download and use.
## Model Zoo
...@@ -41,6 +42,25 @@
| DenseTeacher-FCOS(LSJ)| 10% | [sup_config](./baseline/fcos_r50_fpn_2x_coco_sup010.yml) | 24 (17424) | 26.3 | **37.1(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](denseteacher/denseteacher_fcos_r50_fpn_coco_semi010_lsj.yml) |
| DenseTeacher-FCOS |100%(full)| [sup_config](./../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml) | 24 (175896) | 42.6 | **44.2** | 24 (175896)| [download](https://paddledet.bj.bcebos.com/models/denseteacher_fcos_r50_fpn_coco_full.pdparams) | [config](denseteacher/denseteacher_fcos_r50_fpn_coco_full.yml) |
| Model | Supervised Data Ratio | Sup Baseline | Sup Epochs (Iters) | Sup mAP<sup>val<br>0.5:0.95 | Semi mAP<sup>val<br>0.5:0.95 | Semi Epochs (Iters) | Download | Config |
| :------------: | :---------: | :---------------------: | :---------------------: |:---------------------------: |:----------------------------: | :------------------: |:--------: |:----------: |
| DenseTeacher-PPYOLOE+_s | 5% | [sup_config](./baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml) | 80 (14480) | 32.8 | **34.0** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi005.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_s_coco_semi005.yml) |
| DenseTeacher-PPYOLOE+_s | 10% | [sup_config](./baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml) | 80 (14480) | 35.3 | **37.5** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi010.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_s_coco_semi010.yml) |
| DenseTeacher-PPYOLOE+_l | 5% | [sup_config](./baseline/ppyoloe_plus_crn_l_80e_coco_sup005.yml) | 80 (14480) | 42.9 | **45.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi005.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_l_coco_semi005.yml) |
| DenseTeacher-PPYOLOE+_l | 10% | [sup_config](./baseline/ppyoloe_plus_crn_l_80e_coco_sup010.yml) | 80 (14480) | 45.7 | **47.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi010.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_l_coco_semi010.yml) |
### [ARSL](arsl)
| Model | COCO Supervised Data Ratio | Semi mAP<sup>val<br>0.5:0.95 | Semi Epochs (Iters) | Download | Config |
| :------------: | :---------:|:----------------------------: | :------------------: |:--------: |:----------: |
| ARSL-FCOS | 1% | **22.8** | 240 (87120) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi001.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi001.yml) |
| ARSL-FCOS | 5% | **33.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi005.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi005.yml ) |
| ARSL-FCOS | 10% | **36.9** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi010.yml ) |
| ARSL-FCOS | 10% | **38.5(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml ) |
| ARSL-FCOS | full(100%) | **45.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_full.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_full.yml ) |
## Semi-Supervised Dataset Preparation
......
metric: COCO
num_classes: 20
# before training, change VOC to COCO format by 'python voc2coco.py'
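# e.g. a minimal sketch of the conversion commands; --type/--base_path follow the argparse
# arguments defined in voc2coco.py below (the script's location in the repo is assumed):
#   python voc2coco.py --type VOC2007_trainval --base_path dataset/voc/VOCdevkit
#   python voc2coco.py --type VOC2012_trainval --base_path dataset/voc/VOCdevkit
#   python voc2coco.py --type VOC2007_test     --base_path dataset/voc/VOCdevkit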
# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet`
TrainDataset:
!SemiCOCODataSet
image_dir: VOC2007/JPEGImages
anno_path: PseudoAnnotations/VOC2007_trainval.json
dataset_dir: dataset/voc/VOCdevkit
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet`
UnsupTrainDataset:
!SemiCOCODataSet
image_dir: VOC2012/JPEGImages
anno_path: PseudoAnnotations/VOC2012_trainval.json
dataset_dir: dataset/voc/VOCdevkit
data_fields: ['image']
supervised: False
EvalDataset:
!COCODataSet
image_dir: VOC2007/JPEGImages
anno_path: PseudoAnnotations/VOC2007_test.json
dataset_dir: dataset/voc/VOCdevkit/
allow_empty: true
TestDataset:
!ImageFolder
anno_path: PseudoAnnotations/VOC2007_test.json # also support txt (like VOC's label_list.txt)
dataset_dir: dataset/voc/VOCdevkit/ # if set, anno_path will be 'dataset_dir/anno_path'
# convert VOC xml to COCO format json
import xml.etree.ElementTree as ET
import os
import json
import argparse
# create and init coco json, img set, and class set
def init_json():
# create coco json
coco = dict()
coco['images'] = []
coco['type'] = 'instances'
coco['annotations'] = []
coco['categories'] = []
# voc classes
voc_class = [
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
]
# init json categories
image_set = set()
class_set = dict()
for cat_id, cat_name in enumerate(voc_class):
cat_item = dict()
cat_item['supercategory'] = 'none'
cat_item['id'] = cat_id
cat_item['name'] = cat_name
coco['categories'].append(cat_item)
class_set[cat_name] = cat_id
return coco, class_set, image_set
def getImgItem(file_name, size, img_id):
if file_name is None:
raise Exception('Could not find filename tag in xml file.')
if size['width'] is None:
raise Exception('Could not find width tag in xml file.')
if size['height'] is None:
raise Exception('Could not find height tag in xml file.')
image_item = dict()
image_item['id'] = img_id
image_item['file_name'] = file_name
image_item['width'] = size['width']
image_item['height'] = size['height']
return image_item
def getAnnoItem(object_name, image_id, ann_id, category_id, bbox):
annotation_item = dict()
annotation_item['segmentation'] = []
seg = []
# bbox[] is x,y,w,h
# left_top
seg.append(bbox[0])
seg.append(bbox[1])
# left_bottom
seg.append(bbox[0])
seg.append(bbox[1] + bbox[3])
# right_bottom
seg.append(bbox[0] + bbox[2])
seg.append(bbox[1] + bbox[3])
# right_top
seg.append(bbox[0] + bbox[2])
seg.append(bbox[1])
annotation_item['segmentation'].append(seg)
annotation_item['area'] = bbox[2] * bbox[3]
annotation_item['iscrowd'] = 0
annotation_item['ignore'] = 0
annotation_item['image_id'] = image_id
annotation_item['bbox'] = bbox
annotation_item['category_id'] = category_id
annotation_item['id'] = ann_id
return annotation_item
def convert_voc_to_coco(txt_path, json_path, xml_path):
# create and init coco json, img set, and class set
coco_json, class_set, image_set = init_json()
### collect img and ann info into coco json
    # each line of the txt points to an annotation xml, e.g. .../VOC2007/Annotations/000005.xml
img_txt = open(txt_path, 'r')
img_line = img_txt.readline().strip()
# loop xml
img_id = 0
ann_id = 0
while img_line:
print('img_id:', img_id)
# find corresponding xml
xml_name = img_line.split('Annotations/', 1)[1]
xml_file = os.path.join(xml_path, xml_name)
if not os.path.exists(xml_file):
            print('{} does not exist.'.format(xml_name))
img_line = img_txt.readline().strip()
continue
# decode xml
tree = ET.parse(xml_file)
root = tree.getroot()
if root.tag != 'annotation':
raise Exception(
'xml {} root element should be annotation, rather than {}'.
format(xml_name, root.tag))
# init img and ann info
bndbox = dict()
size = dict()
size['width'] = None
size['height'] = None
size['depth'] = None
# filename
fileNameNode = root.find('filename')
file_name = fileNameNode.text
# img size
sizeNode = root.find('size')
        # ElementTree elements are falsy when they have no children, so compare against None explicitly
        if sizeNode is None:
raise Exception('xml {} structure broken at size tag.'.format(
xml_name))
for subNode in sizeNode:
size[subNode.tag] = int(subNode.text)
# add img into json
if file_name not in image_set:
img_id += 1
format_img_id = int("%04d" % img_id)
# print('line 120. format_img_id:', format_img_id)
image_item = getImgItem(file_name, size, img_id)
image_set.add(file_name)
coco_json['images'].append(image_item)
else:
raise Exception(' xml {} duplicated image: {}'.format(xml_name,
file_name))
### add objAnn into json
objectAnns = root.findall('object')
for objectAnn in objectAnns:
bndbox['xmin'] = None
bndbox['xmax'] = None
bndbox['ymin'] = None
bndbox['ymax'] = None
#add obj category
object_name = objectAnn.find('name').text
if object_name not in class_set:
raise Exception('xml {} Unrecognized category: {}'.format(
xml_name, object_name))
else:
current_category_id = class_set[object_name]
#add obj bbox ann
objectBboxNode = objectAnn.find('bndbox')
for coordinate in objectBboxNode:
if bndbox[coordinate.tag] is not None:
raise Exception('xml {} structure corrupted at bndbox tag.'.
format(xml_name))
bndbox[coordinate.tag] = int(float(coordinate.text))
bbox = []
# x
bbox.append(bndbox['xmin'])
# y
bbox.append(bndbox['ymin'])
# w
bbox.append(bndbox['xmax'] - bndbox['xmin'])
# h
bbox.append(bndbox['ymax'] - bndbox['ymin'])
ann_id += 1
ann_item = getAnnoItem(object_name, img_id, ann_id,
current_category_id, bbox)
coco_json['annotations'].append(ann_item)
img_line = img_txt.readline().strip()
print('Saving json.')
json.dump(coco_json, open(json_path, 'w'))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--type', type=str, default='VOC2007_test', help="data type")
parser.add_argument(
'--base_path',
type=str,
default='dataset/voc/VOCdevkit',
help="base VOC path.")
args = parser.parse_args()
# image info path
txt_name = args.type + '.txt'
json_name = args.type + '.json'
txt_path = os.path.join(args.base_path, 'PseudoAnnotations', txt_name)
json_path = os.path.join(args.base_path, 'PseudoAnnotations', json_name)
# xml path
xml_path = os.path.join(args.base_path,
args.type.split('_')[0], 'Annotations')
print('txt_path:', txt_path)
print('json_path:', json_path)
print('xml_path:', xml_path)
print('Converting {} to COCO json.'.format(args.type))
convert_voc_to_coco(txt_path, json_path, xml_path)
print('Finished.')
简体中文 | [English](README_en.md)
# Ambiguity-Resistant Semi-Supervised Learning for Dense Object Detection (ARSL)
## ARSL-FCOS Model Zoo
| Model | COCO Supervised Data Ratio | Semi mAP<sup>val<br>0.5:0.95 | Semi Epochs (Iters) | Download | Config |
| :------------: | :---------:|:----------------------------: | :------------------: |:--------: |:----------: |
| ARSL-FCOS | 1% | **22.8** | 240 (87120) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi001.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi001.yml) |
| ARSL-FCOS | 5% | **33.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi005.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi005.yml ) |
| ARSL-FCOS | 10% | **36.9** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi010.yml ) |
| ARSL-FCOS | 10% | **38.5(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi010_lsj.yml ) |
| ARSL-FCOS | full(100%) | **45.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_full.pdparams) | [config](./arsl_fcos_r50_fpn_coco_full.yml ) |
## Usage
The semi-supervised detection config file is only required for training; evaluation, inference, and deployment can also be run with the base detector's config file.
### Training
```bash
# single-GPU training (not recommended; the learning rate must be scaled linearly with the total batch size)
CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml --eval
# multi-GPU training
python -m paddle.distributed.launch --log_dir=arsl_fcos_r50_fpn_coco_semi010/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml --eval
```
### Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml -o weights=output/arsl_fcos_r50_fpn_coco_semi010/model_final.pdparams
```
### Inference
```bash
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml -o weights=output/arsl_fcos_r50_fpn_coco_semi010/model_final.pdparams --infer_img=demo/000000014439.jpg
```
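### Export and Deployment
This commit also adds a teacher-model export path (see `Trainer_ARSL.export` in `ppdet/engine/trainer_ssod.py`). The following is a minimal sketch of a typical export-and-deploy flow; it assumes PaddleDetection's standard `tools/export_model.py` and `deploy/python/infer.py` entry points and is not itself part of this patch:
```bash
# export the teacher model for inference (hypothetical invocation)
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml -o weights=output/arsl_fcos_r50_fpn_coco_semi010/model_final.pdparams

# run deployment inference with the exported model; the output directory name follows the config file name
CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/arsl_fcos_r50_fpn_coco_semi010 --image_file=demo/000000014439.jpg --device=GPU
```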
## Citation
```
```
architecture: ARSL_FCOS
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
ARSL_FCOS:
backbone: ResNet
neck: FPN
fcos_head: FCOSHead_ARSL
fcos_cr_loss: FCOSLossCR
ResNet:
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1,2,3]
num_stages: 4
FPN:
out_channel: 256
spatial_scales: [0.125, 0.0625, 0.03125]
extra_stage: 2
has_extra_convs: true
use_c5: false
FCOSHead_ARSL:
fcos_feat:
name: FCOSFeat
feat_in: 256
feat_out: 256
num_convs: 4
norm_type: "gn"
use_dcn: false
fpn_stride: [8, 16, 32, 64, 128]
prior_prob: 0.01
norm_reg_targets: True
centerness_on_reg: True
fcos_loss:
name: FCOSLossMILC
loss_alpha: 0.25
loss_gamma: 2.0
iou_loss_type: "giou"
reg_weights: 1.0
nms:
name: MultiClassNMS
nms_top_k: 1000
keep_top_k: 100
score_threshold: 0.025
nms_threshold: 0.6
FCOSLossCR:
iou_loss_type: "giou"
cls_weight: 2.0
reg_weight: 2.0
iou_weight: 0.5
hard_neg_mining_flag: true
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
- RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
- RandomFlip: {}
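  # Note: weak_aug is applied to the view consumed by the teacher, strong_aug to the view the
  # student is trained on for unlabeled images (see run_step_full_semisup in trainer_ssod.py:
  # the teacher predicts on the weakly augmented unlabel_data_k, while the student computes the
  # unsupervised loss on the strongly augmented unlabel_data_q)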
weak_aug:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
sup_batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 32}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
num_shift: 0. # default 0.5
multiply_strides_reg_targets: False
norm_reg_targets: True
unsup_batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 32}
sup_batch_size: 2
unsup_batch_size: 2
shuffle: True
drop_last: True
EvalReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
TestReader:
sample_transforms:
- Decode: {}
- Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
- Permute: {}
batch_transforms:
- PadBatch: {pad_to_stride: 32}
batch_size: 1
epoch: 120 # employ iter to control schedule
LearningRate:
base_lr: 0.02 # 0.02 for 8*(4+4) batch
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [3000] # do not decay lr
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 1000
max_iter: 360000 # 360k for 32 batch, 720k for 16 batch
epoch_iter: 1000 # set epoch_iter for saving checkpoint and eval
optimize_rate: 1
SEMISUPNET:
  BBOX_THRESHOLD: 0.5 # not used
TEACHER_UPDATE_ITER: 1
BURN_UP_STEP: 30000
EMA_KEEP_RATE: 0.9996
UNSUP_LOSS_WEIGHT: 1.0 # detailed weights for cls and loc task can be seen in cr_loss
PSEUDO_WARM_UP_STEPS: 2000
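# Note (summarizing how Trainer_ARSL in trainer_ssod.py consumes these fields):
#   - before BURN_UP_STEP iterations only the supervised loss is used; at BURN_UP_STEP the
#     teacher is initialized from the student (keep_rate=0) and a 'burnIn' checkpoint is saved
#   - after burn-in the teacher is EMA-updated with EMA_KEEP_RATE every optimize_rate iterations,
#     and the pseudo-loss weight ramps linearly from 0 to UNSUP_LOSS_WEIGHT over
#     PSEUDO_WARM_UP_STEPS iterations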
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
epoch: 30 # employ iter to control schedule
LearningRate:
base_lr: 0.02 # 0.02 for 8*(4+4) batch
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [300] # do not decay lr
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 1000
max_iter: 90000 # 90k for 32 batch, 180k for 16 batch
epoch_iter: 1000 # set epoch_iter for saving checkpoint and eval
# update student params according to loss_grad every X iter.
optimize_rate: 1
SEMISUPNET:
BBOX_THRESHOLD: 0.5 # not used
TEACHER_UPDATE_ITER: 1
BURN_UP_STEP: 9000
EMA_KEEP_RATE: 0.9996
UNSUP_LOSS_WEIGHT: 1.0 # detailed weights for cls and loc task can be seen in cr_loss
PSEUDO_WARM_UP_STEPS: 2000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0001
type: L2
_BASE_: [
'../_base_/coco_detection_full.yml',
'../../runtime.yml',
'_base_/arsl_fcos_r50_fpn.yml',
'_base_/optimizer_360k.yml',
'_base_/arsl_fcos_reader.yml',
]
weights: output/fcos_r50_fpn_arsl_360k_coco_full/model_final
#semi detector type
ssod_method: ARSL
_BASE_: [
'../_base_/coco_detection_percent_1.yml',
'../../runtime.yml',
'_base_/arsl_fcos_r50_fpn.yml',
'_base_/optimizer_90k.yml',
'_base_/arsl_fcos_reader.yml',
]
weights: output/arsl_fcos_r50_fpn_coco_semi001/model_final
#semi detector type
ssod_method: ARSL
_BASE_: [
'../_base_/coco_detection_percent_5.yml',
'../../runtime.yml',
'_base_/arsl_fcos_r50_fpn.yml',
'_base_/optimizer_90k.yml',
'_base_/arsl_fcos_reader.yml',
]
weights: output/arsl_fcos_r50_fpn_coco_semi005/model_final
#semi detector type
ssod_method: ARSL
_BASE_: [
'../_base_/coco_detection_percent_10.yml',
'../../runtime.yml',
'_base_/arsl_fcos_r50_fpn.yml',
'_base_/optimizer_360k.yml',
'_base_/arsl_fcos_reader.yml',
]
weights: output/arsl_fcos_r50_fpn_coco_semi010/model_final
#semi detector type
ssod_method: ARSL
_BASE_: [
'../_base_/coco_detection_percent_10.yml',
'../../runtime.yml',
'_base_/arsl_fcos_r50_fpn.yml',
'_base_/optimizer_360k.yml',
'_base_/arsl_fcos_reader.yml',
]
weights: output/arsl_fcos_r50_fpn_coco_semi010/model_final
#semi detector type
ssod_method: ARSL
worker_num: 2
SemiTrainReader:
sample_transforms:
- Decode: {}
# large-scale jittering
- RandomResize: {target_size: [[400, 1333], [1200, 1333]], keep_ratio: True, interp: 1, random_range: True}
- RandomFlip: {}
weak_aug:
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
strong_aug:
- StrongAugImage: {transforms: [
RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1},
RandomErasingCrop: {},
RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]},
RandomGrayscale: {prob: 0.2},
]}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true}
sup_batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 32}
- Gt2FCOSTarget:
object_sizes_boundary: [64, 128, 256, 512]
center_sampling_radius: 1.5
downsample_ratios: [8, 16, 32, 64, 128]
num_shift: 0. # default 0.5
multiply_strides_reg_targets: False
norm_reg_targets: True
unsup_batch_transforms:
- Permute: {}
- PadBatch: {pad_to_stride: 32}
sup_batch_size: 2
unsup_batch_size: 2
shuffle: True
drop_last: True
...@@ -394,11 +394,11 @@ class Trainer(object):
            "metrics shoule be instances of subclass of Metric"
        self._metrics.extend(metrics)
-    def load_weights(self, weights):
    def load_weights(self, weights, ARSL_eval=False):
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
-        load_pretrain_weight(self.model, weights)
        load_pretrain_weight(self.model, weights, ARSL_eval)
        logger.debug("Load weights {} to start training".format(weights))
    def load_weights_sde(self, det_weights, reid_weights):
...@@ -985,8 +985,10 @@ class Trainer(object):
        for step_id, data in enumerate(tqdm(loader)):
            self.status['step_id'] = step_id
            # forward
-            outs = self.model(data)
            if hasattr(self.model, 'modelTeacher'):
outs = self.model.modelTeacher(data)
else:
outs = self.model(data)
            for _m in metrics:
                _m.update(data, outs)
......
...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import time
import typing
...@@ -26,18 +27,20 @@ import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed import fleet
from ppdet.optimizer import ModelEMA, SimpleModelEMA
from ppdet.core.workspace import create
-from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
import ppdet.utils.stats as stats
from ppdet.utils import profiler
from ppdet.modeling.ssod.utils import align_weak_strong_shape
from .trainer import Trainer
from ppdet.utils.logger import setup_logger
from paddle.static import InputSpec
from ppdet.engine.export_utils import _dump_infer_config, _prune_input_spec
MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']
logger = setup_logger('ppdet.engine')
-__all__ = ['Trainer_DenseTeacher']
__all__ = ['Trainer_DenseTeacher', 'Trainer_ARSL']
class Trainer_DenseTeacher(Trainer):
...@@ -199,11 +202,6 @@ class Trainer_DenseTeacher(Trainer):
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
-        if self.cfg.get('print_flops', False):
-            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
-                self.dataset, self.cfg.worker_num)
-            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)
        self._compose_callback.on_train_begin(self.status)
...@@ -466,6 +464,365 @@ class Trainer_DenseTeacher(Trainer):
        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic
# accumulate metric to log out
for metric in self._metrics:
metric.accumulate()
metric.log()
self._compose_callback.on_epoch_end(self.status)
self._reset_metrics()
class Trainer_ARSL(Trainer):
def __init__(self, cfg, mode='train'):
self.cfg = cfg
assert mode.lower() in ['train', 'eval', 'test'], \
"mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower()
self.optimizer = None
self.is_loaded_weights = False
capital_mode = self.mode.capitalize()
self.use_ema = False
self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
'{}Dataset'.format(capital_mode))()
if self.mode == 'train':
self.dataset_unlabel = self.cfg['UnsupTrainDataset'] = create(
'UnsupTrainDataset')
self.loader = create('SemiTrainReader')(
self.dataset, self.dataset_unlabel, cfg.worker_num)
# build model
if 'model' not in self.cfg:
self.student_model = create(cfg.architecture)
self.teacher_model = create(cfg.architecture)
self.model = EnsembleTSModel(self.teacher_model, self.student_model)
else:
self.model = self.cfg.model
self.is_loaded_weights = True
# save path for burn-in model
self.base_path = cfg.get('weights')
self.base_path = os.path.dirname(self.base_path)
# EvalDataset build with BatchSampler to evaluate in single device
# TODO: multi-device evaluate
if self.mode == 'eval':
self._eval_batch_sampler = paddle.io.BatchSampler(
self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num, self._eval_batch_sampler)
# TestDataset build after user set images, skip loader creation here
self.start_epoch = 0
self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch
self.epoch_iter = self.cfg.epoch_iter # set fixed iter in each epoch to control checkpoint
# build optimizer in train mode
if self.mode == 'train':
steps_per_epoch = self.epoch_iter
self.lr = create('LearningRate')(steps_per_epoch)
self.optimizer = create('OptimizerBuilder')(self.lr,
self.model.modelStudent)
self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank()
self.status = {}
# initial default callbacks
self._init_callbacks()
# initial default metrics
self._init_metrics()
self._reset_metrics()
self.iter = 0
def resume_weights(self, weights):
# support Distill resume weights
if hasattr(self.model, 'student_model'):
self.start_epoch = load_weight(self.model.student_model, weights,
self.optimizer)
else:
self.start_epoch = load_weight(self.model, weights, self.optimizer)
logger.debug("Resume weights of epoch {}".format(self.start_epoch))
def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode"
Init_mark = False
# if validation in training is enabled, metrics should be re-init
if validate:
self._init_metrics(validate=validate)
self._reset_metrics()
if self.cfg.get('fleet', False):
self.model.modelStudent = fleet.distributed_model(
self.model.modelStudent)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
elif self._nranks > 1:
find_unused_parameters = self.cfg[
'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
self.model.modelStudent = paddle.DataParallel(
self.model.modelStudent,
find_unused_parameters=find_unused_parameters)
# set fixed iter in each epoch to control checkpoint
self.status.update({
'epoch_id': self.start_epoch,
'step_id': 0,
'steps_per_epoch': self.epoch_iter
})
print('338 Len of DataLoader: {}'.format(len(self.loader)))
self.status['batch_time'] = stats.SmoothedValue(
self.cfg.log_iter, fmt='{avg:.4f}')
self.status['data_time'] = stats.SmoothedValue(
self.cfg.log_iter, fmt='{avg:.4f}')
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
self._compose_callback.on_train_begin(self.status)
epoch_id = self.start_epoch
self.iter = self.start_epoch * self.epoch_iter
# use iter rather than epoch to control training schedule
while self.iter < self.cfg.max_iter:
# epoch loop
self.status['mode'] = 'train'
self.status['epoch_id'] = epoch_id
self._compose_callback.on_epoch_begin(self.status)
self.loader.dataset_label.set_epoch(epoch_id)
self.loader.dataset_unlabel.set_epoch(epoch_id)
paddle.device.cuda.empty_cache() # clear GPU memory
# set model status
self.model.modelStudent.train()
self.model.modelTeacher.eval()
iter_tic = time.time()
            # iter loop in each epoch
for step_id in range(self.epoch_iter):
data = next(self.loader)
self.status['data_time'].update(time.time() - iter_tic)
self.status['step_id'] = step_id
# profiler.add_profiler_step(profiler_options)
self._compose_callback.on_step_begin(self.status)
# model forward and calculate loss
loss_dict = self.run_step_full_semisup(data)
if (step_id + 1) % self.cfg.optimize_rate == 0:
self.optimizer.step()
self.optimizer.clear_grad()
curr_lr = self.optimizer.get_lr()
self.lr.step()
# update log status
self.status['learning_rate'] = curr_lr
if self._nranks < 2 or self._local_rank == 0:
self.status['training_staus'].update(loss_dict)
self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status)
self.iter += 1
iter_tic = time.time()
self._compose_callback.on_epoch_end(self.status)
if validate and (self._nranks < 2 or self._local_rank == 0) \
and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \
or epoch_id == self.end_epoch - 1):
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
self._eval_batch_sampler = \
paddle.io.BatchSampler(
self._eval_dataset,
batch_size=self.cfg.EvalReader['batch_size'])
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
if validate and Init_mark == False:
Init_mark = True
self._init_metrics(validate=validate)
self._reset_metrics()
with paddle.no_grad():
self.status['save_best_model'] = True
# before burn-in stage, eval student. after burn-in stage, eval teacher
if self.iter <= self.cfg.SEMISUPNET['BURN_UP_STEP']:
print("start eval student model")
self._eval_with_loader(
self._eval_loader, mode="student")
else:
print("start eval teacher model")
self._eval_with_loader(
self._eval_loader, mode="teacher")
epoch_id += 1
self._compose_callback.on_train_end(self.status)
def merge_data(self, data1, data2):
data = copy.deepcopy(data1)
for k, v in data1.items():
if type(v) is paddle.Tensor:
data[k] = paddle.concat(x=[data[k], data2[k]], axis=0)
elif type(v) is list:
data[k].extend(data2[k])
return data
def run_step_full_semisup(self, data):
label_data_k, label_data_q, unlabel_data_k, unlabel_data_q = data
data_merge = self.merge_data(label_data_k, label_data_q)
loss_sup_dict = self.model.modelStudent(data_merge, branch="supervised")
loss_dict = {}
for key in loss_sup_dict.keys():
if key[:4] == "loss":
loss_dict[key] = loss_sup_dict[key] * 1
losses_sup = paddle.add_n(list(loss_dict.values()))
# norm loss when using gradient accumulation
losses_sup = losses_sup / self.cfg.optimize_rate
losses_sup.backward()
for key in loss_sup_dict.keys():
loss_dict[key + "_pseudo"] = paddle.to_tensor([0])
loss_dict["loss_tot"] = losses_sup
"""
semi-supervised training after burn-in stage
"""
if self.iter >= self.cfg.SEMISUPNET['BURN_UP_STEP']:
# init teacher model with burn-up weight
if self.iter == self.cfg.SEMISUPNET['BURN_UP_STEP']:
print(
'Starting semi-supervised learning and load the teacher model.'
)
self._update_teacher_model(keep_rate=0.00)
# save burn-in model
if dist.get_world_size() < 2 or dist.get_rank() == 0:
print('saving burn-in model.')
save_name = 'burnIn'
epoch_id = self.iter // self.epoch_iter
save_model(self.model, self.optimizer, self.base_path,
save_name, epoch_id)
# Update teacher model with EMA
elif (self.iter + 1) % self.cfg.optimize_rate == 0:
self._update_teacher_model(
keep_rate=self.cfg.SEMISUPNET['EMA_KEEP_RATE'])
#warm-up weight for pseudo loss
pseudo_weight = self.cfg.SEMISUPNET['UNSUP_LOSS_WEIGHT']
pseudo_warmup_iter = self.cfg.SEMISUPNET['PSEUDO_WARM_UP_STEPS']
temp = self.iter - self.cfg.SEMISUPNET['BURN_UP_STEP']
if temp <= pseudo_warmup_iter:
pseudo_weight *= (temp / pseudo_warmup_iter)
# get teacher predictions on weak-augmented unlabeled data
with paddle.no_grad():
teacher_pred = self.model.modelTeacher(
unlabel_data_k, branch='semi_supervised')
# calculate unsupervised loss on strong-augmented unlabeled data
loss_unsup_dict = self.model.modelStudent(
unlabel_data_q,
branch="semi_supervised",
teacher_prediction=teacher_pred, )
for key in loss_unsup_dict.keys():
if key[-6:] == "pseudo":
loss_unsup_dict[key] = loss_unsup_dict[key] * pseudo_weight
losses_unsup = paddle.add_n(list(loss_unsup_dict.values()))
# norm loss when using gradient accumulation
losses_unsup = losses_unsup / self.cfg.optimize_rate
losses_unsup.backward()
loss_dict.update(loss_unsup_dict)
loss_dict["loss_tot"] += losses_unsup
return loss_dict
def export(self, output_dir='output_inference'):
self.model.eval()
model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
save_dir = os.path.join(output_dir, model_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
image_shape = None
if self.cfg.architecture in MOT_ARCH:
test_reader_name = 'TestMOTReader'
else:
test_reader_name = 'TestReader'
if 'inputs_def' in self.cfg[test_reader_name]:
inputs_def = self.cfg[test_reader_name]['inputs_def']
image_shape = inputs_def.get('image_shape', None)
# set image_shape=[3, -1, -1] as default
if image_shape is None:
image_shape = [3, -1, -1]
self.model.modelTeacher.eval()
if hasattr(self.model.modelTeacher, 'deploy'):
self.model.modelTeacher.deploy = True
# Save infer cfg
_dump_infer_config(self.cfg,
os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
self.model.modelTeacher)
input_spec = [{
"image": InputSpec(
shape=[None] + image_shape, name='image'),
"im_shape": InputSpec(
shape=[None, 2], name='im_shape'),
"scale_factor": InputSpec(
shape=[None, 2], name='scale_factor')
}]
if self.cfg.architecture == 'DeepSORT':
input_spec[0].update({
"crops": InputSpec(
shape=[None, 3, 192, 64], name='crops')
})
static_model = paddle.jit.to_static(
self.model.modelTeacher, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = _prune_input_spec(input_spec,
static_model.forward.main_program,
static_model.forward.outputs)
# dy2st and save model
if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
else:
self.cfg.slim.save_quantized_model(
self.model.modelTeacher,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
def _eval_with_loader(self, loader, mode="teacher"):
sample_num = 0
tic = time.time()
self._compose_callback.on_epoch_begin(self.status)
self.status['mode'] = 'eval'
# self.model.eval()
self.model.modelTeacher.eval()
self.model.modelStudent.eval()
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
if mode == "teacher":
outs = self.model.modelTeacher(data)
else:
outs = self.model.modelStudent(data)
# update metrics
for metric in self._metrics:
metric.update(data, outs)
sample_num += data['im_id'].numpy().shape[0]
self._compose_callback.on_step_end(self.status)
self.status['sample_num'] = sample_num
self.status['cost_time'] = time.time() - tic
        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
...@@ -473,3 +830,29 @@ class Trainer_DenseTeacher(Trainer):
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states for metric may performed multiple times
        self._reset_metrics()
def evaluate(self):
with paddle.no_grad():
self._eval_with_loader(self.loader)
@paddle.no_grad()
def _update_teacher_model(self, keep_rate=0.996):
student_model_dict = copy.deepcopy(self.model.modelStudent.state_dict())
new_teacher_dict = dict()
for key, value in self.model.modelTeacher.state_dict().items():
if key in student_model_dict.keys():
v = student_model_dict[key] * (1 - keep_rate
) + value * keep_rate
v.stop_gradient = True
new_teacher_dict[key] = v
else:
raise Exception("{} is not found in student model".format(key))
self.model.modelTeacher.set_dict(new_teacher_dict)
class EnsembleTSModel(nn.Layer):
def __init__(self, modelTeacher, modelStudent):
super(EnsembleTSModel, self).__init__()
self.modelTeacher = modelTeacher
self.modelStudent = modelStudent
...@@ -74,4 +74,4 @@ from .yolof import *
from .pose3d_metro import *
from .centertrack import *
from .queryinst import *
from .keypoint_petr import *
\ No newline at end of file
...@@ -16,10 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
-__all__ = ['FCOS']
__all__ = ['FCOS', 'ARSL_FCOS']
@register
...@@ -31,7 +32,7 @@ class FCOS(BaseArch):
        backbone (object): backbone instance
        neck (object): 'FPN' instance
        fcos_head (object): 'FCOSHead' instance
-        ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det(ssod)
        ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det(ssod) by DenseTeacher
    """
    __category__ = 'architecture'
...@@ -94,3 +95,128 @@ class FCOS(BaseArch):
        ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs,
                                     train_cfg)
        return ssod_losses
@register
class ARSL_FCOS(BaseArch):
"""
FCOS ARSL network, see https://arxiv.org/abs/
Args:
backbone (object): backbone instance
neck (object): 'FPN' instance
fcos_head (object): 'FCOSHead_ARSL' instance
fcos_cr_loss (object): 'FCOSLossCR' instance, only used for semi-det(ssod) by ARSL
"""
__category__ = 'architecture'
__inject__ = ['fcos_cr_loss']
def __init__(self,
backbone,
neck,
fcos_head='FCOSHead_ARSL',
fcos_cr_loss='FCOSLossCR'):
super(ARSL_FCOS, self).__init__()
self.backbone = backbone
self.neck = neck
self.fcos_head = fcos_head
self.fcos_cr_loss = fcos_cr_loss
@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
kwargs = {'input_shape': neck.out_shape}
fcos_head = create(cfg['fcos_head'], **kwargs)
# consistency regularization loss
fcos_cr_loss = create(cfg['fcos_cr_loss'])
return {
'backbone': backbone,
'neck': neck,
'fcos_head': fcos_head,
'fcos_cr_loss': fcos_cr_loss,
}
def forward(self, inputs, branch="supervised", teacher_prediction=None):
assert branch in ['supervised', 'semi_supervised'], \
print('In ARSL, type must be supervised or semi_supervised.')
if self.data_format == 'NHWC':
image = inputs['image']
inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])
self.inputs = inputs
if self.training:
if branch == "supervised":
out = self.get_loss()
else:
out = self.get_pseudo_loss(teacher_prediction)
else:
# norm test
if branch == "supervised":
out = self.get_pred()
# predict pseudo labels
else:
out = self.get_pseudo_pred()
return out
# model forward
def model_forward(self):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
fcos_head_outs = self.fcos_head(fpn_feats)
return fcos_head_outs
# supervised loss for labeled data
def get_loss(self):
loss = {}
tag_labels, tag_bboxes, tag_centerness = [], [], []
for i in range(len(self.fcos_head.fpn_stride)):
# labels, reg_target, centerness
k_lbl = 'labels{}'.format(i)
if k_lbl in self.inputs:
tag_labels.append(self.inputs[k_lbl])
k_box = 'reg_target{}'.format(i)
if k_box in self.inputs:
tag_bboxes.append(self.inputs[k_box])
k_ctn = 'centerness{}'.format(i)
if k_ctn in self.inputs:
tag_centerness.append(self.inputs[k_ctn])
fcos_head_outs = self.model_forward()
loss_fcos = self.fcos_head.get_loss(fcos_head_outs, tag_labels,
tag_bboxes, tag_centerness)
loss.update(loss_fcos)
return loss
# unsupervised loss for unlabeled data
def get_pseudo_loss(self, teacher_prediction):
loss = {}
fcos_head_outs = self.model_forward()
unsup_loss = self.fcos_cr_loss(fcos_head_outs, teacher_prediction)
for k in unsup_loss.keys():
loss[k + '_pseudo'] = unsup_loss[k]
return loss
# get detection results for test, decode and rescale the results to original size
def get_pred(self):
fcos_head_outs = self.model_forward()
scale_factor = self.inputs['scale_factor']
bbox_pred, bbox_num = self.fcos_head.post_process(fcos_head_outs,
scale_factor)
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
# generate pseudo labels to guide student
def get_pseudo_pred(self):
fcos_head_outs = self.model_forward()
pred_cls, pred_loc, pred_iou = fcos_head_outs[1:] # 0 is locations
for lvl, _ in enumerate(pred_loc):
pred_loc[lvl] = pred_loc[lvl] / self.fcos_head.fpn_stride[lvl]
return [pred_cls, pred_loc, pred_iou, self.fcos_head.fpn_stride]
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -26,7 +26,7 @@ from paddle.nn.initializer import Normal, Constant
from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer, MultiClassNMS
-__all__ = ['FCOSFeat', 'FCOSHead']
__all__ = ['FCOSFeat', 'FCOSHead', 'FCOSHead_ARSL']
class ScaleReg(nn.Layer):
...@@ -263,10 +263,23 @@ class FCOSHead(nn.Layer):
            centerness_list.append(centerness)
        if targets is not None:
-            self.is_teacher = targets.get('is_teacher', False)
            self.is_teacher = targets.get('ARSL_teacher', False)
            if self.is_teacher:
                return [cls_logits_list, bboxes_reg_list, centerness_list]
if targets is not None:
self.is_student = targets.get('ARSL_student', False)
if self.is_student:
return [cls_logits_list, bboxes_reg_list, centerness_list]
if targets is not None:
self.is_teacher = targets.get('is_teacher', False)
if self.is_teacher:
return [
locations_list, cls_logits_list, bboxes_reg_list,
centerness_list
]
        if self.training and targets is not None:
            get_data = targets.get('get_data', False)
            if get_data:
...@@ -361,3 +374,139 @@ class FCOSHead(nn.Layer):
        pred_scores = pred_scores.transpose([0, 2, 1])
        bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
        return bbox_pred, bbox_num
@register
class FCOSHead_ARSL(FCOSHead):
"""
FCOSHead of ARSL for semi-det(ssod)
Args:
fcos_feat (object): Instance of 'FCOSFeat'
num_classes (int): Number of classes
fpn_stride (list): The stride of each FPN Layer
prior_prob (float): Used to set the bias init for the class prediction layer
fcos_loss (object): Instance of 'FCOSLoss'
norm_reg_targets (bool): Normalization the regression target if true
        centerness_on_reg (bool): Whether centerness is predicted on the regression or the classification branch
nms (object): Instance of 'MultiClassNMS'
trt (bool): Whether to use trt in nms of deploy
"""
__inject__ = ['fcos_feat', 'fcos_loss', 'nms']
__shared__ = ['num_classes', 'trt']
def __init__(self,
num_classes=80,
fcos_feat='FCOSFeat',
fpn_stride=[8, 16, 32, 64, 128],
prior_prob=0.01,
multiply_strides_reg_targets=False,
norm_reg_targets=True,
centerness_on_reg=True,
num_shift=0.5,
sqrt_score=False,
fcos_loss='FCOSLossMILC',
nms='MultiClassNMS',
trt=False):
super(FCOSHead_ARSL, self).__init__()
self.fcos_feat = fcos_feat
self.num_classes = num_classes
self.fpn_stride = fpn_stride
self.prior_prob = prior_prob
self.fcos_loss = fcos_loss
self.norm_reg_targets = norm_reg_targets
self.centerness_on_reg = centerness_on_reg
self.multiply_strides_reg_targets = multiply_strides_reg_targets
self.num_shift = num_shift
self.nms = nms
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
self.sqrt_score = sqrt_score
conv_cls_name = "fcos_head_cls"
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
self.fcos_head_cls = self.add_sublayer(
conv_cls_name,
nn.Conv2D(
in_channels=256,
out_channels=self.num_classes,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(
initializer=Constant(value=bias_init_value))))
conv_reg_name = "fcos_head_reg"
self.fcos_head_reg = self.add_sublayer(
conv_reg_name,
nn.Conv2D(
in_channels=256,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=0))))
conv_centerness_name = "fcos_head_centerness"
self.fcos_head_centerness = self.add_sublayer(
conv_centerness_name,
nn.Conv2D(
in_channels=256,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
weight_attr=ParamAttr(initializer=Normal(
mean=0., std=0.01)),
bias_attr=ParamAttr(initializer=Constant(value=0))))
self.scales_regs = []
for i in range(len(self.fpn_stride)):
lvl = int(math.log(int(self.fpn_stride[i]), 2))
feat_name = 'p{}_feat'.format(lvl)
scale_reg = self.add_sublayer(feat_name, ScaleReg())
self.scales_regs.append(scale_reg)
def forward(self, fpn_feats, targets=None):
assert len(fpn_feats) == len(
self.fpn_stride
), "The size of fpn_feats is not equal to size of fpn_stride"
cls_logits_list = []
bboxes_reg_list = []
centerness_list = []
for scale_reg, fpn_stride, fpn_feat in zip(self.scales_regs,
self.fpn_stride, fpn_feats):
fcos_cls_feat, fcos_reg_feat = self.fcos_feat(fpn_feat)
cls_logits = self.fcos_head_cls(fcos_cls_feat)
bbox_reg = scale_reg(self.fcos_head_reg(fcos_reg_feat))
if self.centerness_on_reg:
centerness = self.fcos_head_centerness(fcos_reg_feat)
else:
centerness = self.fcos_head_centerness(fcos_cls_feat)
if self.norm_reg_targets:
bbox_reg = F.relu(bbox_reg)
if not self.training:
bbox_reg = bbox_reg * fpn_stride
else:
bbox_reg = paddle.exp(bbox_reg)
cls_logits_list.append(cls_logits)
bboxes_reg_list.append(bbox_reg)
centerness_list.append(centerness)
if not self.training:
locations_list = []
for fpn_stride, feature in zip(self.fpn_stride, fpn_feats):
location = self._compute_locations_by_level(fpn_stride, feature)
locations_list.append(location)
return locations_list, cls_logits_list, bboxes_reg_list, centerness_list
else:
return cls_logits_list, bboxes_reg_list, centerness_list
def get_loss(self, fcos_head_outs, tag_labels, tag_bboxes, tag_centerness):
cls_logits, bboxes_reg, centerness = fcos_head_outs
return self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels,
tag_bboxes, tag_centerness)
This diff is collapsed.
...@@ -17,9 +17,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-import errno
import os
-import time
import numpy as np
import paddle
import paddle.nn as nn
...@@ -40,21 +38,6 @@ def is_url(path):
        or path.startswith('ppdet://')
-def _get_unique_endpoints(trainer_endpoints):
-    # Sorting is to avoid different environmental variables for each card
-    trainer_endpoints.sort()
-    ips = set()
-    unique_endpoints = set()
-    for endpoint in trainer_endpoints:
-        ip = endpoint.split(":")[0]
-        if ip in ips:
-            continue
-        ips.add(ip)
-        unique_endpoints.add(endpoint)
-    logger.info("unique_endpoints {}".format(unique_endpoints))
-    return unique_endpoints
def _strip_postfix(path):
    path, ext = os.path.splitext(path)
    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
...@@ -92,28 +75,35 @@ def load_weight(model, weight, optimizer=None, ema=None, exchange=True):
        ema_state_dict = None
    param_state_dict = paddle.load(pdparam_path)
-    model_dict = model.state_dict()
-    model_weight = {}
-    incorrect_keys = 0
-    for key, value in model_dict.items():
-        if key in param_state_dict.keys():
-            if isinstance(param_state_dict[key], np.ndarray):
-                param_state_dict[key] = paddle.to_tensor(param_state_dict[key])
-            if value.dtype == param_state_dict[key].dtype:
-                model_weight[key] = param_state_dict[key]
-            else:
-                model_weight[key] = param_state_dict[key].astype(value.dtype)
-        else:
-            logger.info('Unmatched key: {}'.format(key))
-            incorrect_keys += 1
-    assert incorrect_keys == 0, "Load weight {} incorrectly, \
-            {} keys unmatched, please check again.".format(weight,
-                                                           incorrect_keys)
-    logger.info('Finish resuming model weights: {}'.format(pdparam_path))
-    model.set_dict(model_weight)
    if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'):
        print('Loading pretrain weights for Teacher-Student framework.')
        print('Loading pretrain weights for Student model.')
        student_model_dict = model.modelStudent.state_dict()
        student_param_state_dict = match_state_dict(
            student_model_dict, param_state_dict, mode='student')
        model.modelStudent.set_dict(student_param_state_dict)
        print('Loading pretrain weights for Teacher model.')
        teacher_model_dict = model.modelTeacher.state_dict()
        teacher_param_state_dict = match_state_dict(
            teacher_model_dict, param_state_dict, mode='teacher')
        model.modelTeacher.set_dict(teacher_param_state_dict)
    else:
        model_dict = model.state_dict()
        model_weight = {}
        incorrect_keys = 0
        for key in model_dict.keys():
            if key in param_state_dict.keys():
                model_weight[key] = param_state_dict[key]
            else:
                logger.info('Unmatched key: {}'.format(key))
                incorrect_keys += 1
        assert incorrect_keys == 0, "Load weight {} incorrectly, \
                {} keys unmatched, please check again.".format(weight,
                                                               incorrect_keys)
        logger.info('Finish resuming model weights: {}'.format(pdparam_path))
        model.set_dict(model_weight)
    last_epoch = 0
    if optimizer is not None and os.path.exists(path + '.pdopt'):
...@@ -134,7 +124,7 @@ def load_weight(model, weight, optimizer=None, ema=None, exchange=True):
    return last_epoch
-def match_state_dict(model_state_dict, weight_state_dict):
def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
    """
    Match between the model state dict and pretrained weight state dict.
    Return the matched state dict.
...@@ -152,33 +142,47 @@
    model_keys = sorted(model_state_dict.keys())
    weight_keys = sorted(weight_state_dict.keys())
def teacher_match(a, b):
# skip student params
if b.startswith('modelStudent'):
return False
return a == b or a.endswith("." + b) or b.endswith("." + a)
def student_match(a, b):
# skip teacher params
if b.startswith('modelTeacher'):
return False
return a == b or a.endswith("." + b) or b.endswith("." + a)
    def match(a, b):
-        if b.startswith('backbone.res5'):
-            # In Faster RCNN, res5 pretrained weights have prefix of backbone,
-            # however, the corresponding model weights have difficult prefix,
-            # bbox_head.
        if a.startswith('backbone.res5'):
            b = b[9:]
        return a == b or a.endswith("." + b)
if mode == 'student':
match_op = student_match
elif mode == 'teacher':
match_op = teacher_match
else:
match_op = match
    match_matrix = np.zeros([len(model_keys), len(weight_keys)])
    for i, m_k in enumerate(model_keys):
        for j, w_k in enumerate(weight_keys):
-            if match(m_k, w_k):
            if match_op(m_k, w_k):
                match_matrix[i, j] = len(w_k)
    max_id = match_matrix.argmax(1)
    max_len = match_matrix.max(1)
    max_id[max_len == 0] = -1
-    load_id = set(max_id)
-    load_id.discard(-1)
    not_load_weight_name = []
-    for idx in range(len(weight_keys)):
-        if idx not in load_id:
-            not_load_weight_name.append(weight_keys[idx])
for match_idx in range(len(max_id)):
if max_id[match_idx] == -1:
not_load_weight_name.append(model_keys[match_idx])
    if len(not_load_weight_name) > 0:
-        logger.info('{} in pretrained weight is not used in the model, '
-                    'and its will not be loaded'.format(not_load_weight_name))
        logger.info('{} in model is not matched with pretrained weights, '
                    'and its will be trained from scratch'.format(
not_load_weight_name))
    matched_keys = {}
    result_state_dict = {}
    for model_id, weight_id in enumerate(max_id):
...@@ -208,7 +212,7 @@ def match_state_dict(model_state_dict, weight_state_dict):
    return result_state_dict
-def load_pretrain_weight(model, pretrain_weight):
def load_pretrain_weight(model, pretrain_weight, ARSL_eval=False):
    if is_url(pretrain_weight):
        pretrain_weight = get_weights_path(pretrain_weight)
...@@ -219,21 +223,48 @@ def load_pretrain_weight(model, pretrain_weight):
                         "If you don't want to load pretrain model, "
                         "please delete `pretrain_weights` field in "
                         "config file.".format(path))
teacher_student_flag = False
if not ARSL_eval:
if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'):
print('Loading pretrain weights for Teacher-Student framework.')
print(
'Assert Teacher model has the same structure with Student model.'
)
model_dict = model.modelStudent.state_dict()
teacher_student_flag = True
else:
model_dict = model.state_dict()
weights_path = path + '.pdparams'
param_state_dict = paddle.load(weights_path)
param_state_dict = match_state_dict(model_dict, param_state_dict)
for k, v in param_state_dict.items():
if isinstance(v, np.ndarray):
v = paddle.to_tensor(v)
if model_dict[k].dtype != v.dtype:
param_state_dict[k] = v.astype(model_dict[k].dtype)
if teacher_student_flag:
model.modelStudent.set_dict(param_state_dict)
model.modelTeacher.set_dict(param_state_dict)
else:
model.set_dict(param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path))
model_dict = model.state_dict() else:
weights_path = path + '.pdparams'
weights_path = path + '.pdparams' param_state_dict = paddle.load(weights_path)
param_state_dict = paddle.load(weights_path) student_model_dict = model.modelStudent.state_dict()
param_state_dict = match_state_dict(model_dict, param_state_dict) student_param_state_dict = match_state_dict(
student_model_dict, param_state_dict, mode='student')
for k, v in param_state_dict.items(): model.modelStudent.set_dict(student_param_state_dict)
if isinstance(v, np.ndarray): print('Loading pretrain weights for Teacher model.')
v = paddle.to_tensor(v) teacher_model_dict = model.modelTeacher.state_dict()
if model_dict[k].dtype != v.dtype:
param_state_dict[k] = v.astype(model_dict[k].dtype)
model.set_dict(param_state_dict) teacher_param_state_dict = match_state_dict(
logger.info('Finish loading model weights: {}'.format(weights_path)) teacher_model_dict, param_state_dict, mode='teacher')
model.modelTeacher.set_dict(teacher_param_state_dict)
logger.info('Finish loading model weights: {}'.format(weights_path))
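The practical effect of the new `ARSL_eval` switch: during training-time initialization a plain single-model checkpoint is copied into both branches, while at eval/export time the checkpoint already carries `modelStudent.*` / `modelTeacher.*` keys and each branch only takes its own. A toy sketch of that idea with a stand-in module (illustrative code, not the PaddleDetection implementation):

```python
import paddle.nn as nn

class ToySSOD(nn.Layer):
    """Stand-in for an ARSL-style model with student and teacher branches."""
    def __init__(self):
        super().__init__()
        self.modelStudent = nn.Linear(4, 2)
        self.modelTeacher = nn.Linear(4, 2)

model = ToySSOD()

# Training init (ARSL_eval=False): one plain checkpoint initializes both branches.
plain_ckpt = nn.Linear(4, 2).state_dict()
model.modelStudent.set_dict(plain_ckpt)
model.modelTeacher.set_dict(plain_ckpt)

# Eval/export (ARSL_eval=True): the saved dict has per-branch prefixes,
# so each branch only takes the keys that belong to it.
full_ckpt = model.state_dict()  # keys like 'modelStudent.weight', 'modelTeacher.bias', ...
student_only = {k[len('modelStudent.'):]: v
                for k, v in full_ckpt.items() if k.startswith('modelStudent.')}
model.modelStudent.set_dict(student_only)
```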
 def save_model(model,
@@ -256,21 +287,24 @@ def save_model(model,
     """
     if paddle.distributed.get_rank() != 0:
         return
-    assert isinstance(model, dict), ("model is not a instance of dict, "
-                                     "please call model.state_dict() to get.")
     if not os.path.exists(save_dir):
         os.makedirs(save_dir)
     save_path = os.path.join(save_dir, save_name)
     # save model
-    if ema_model is None:
-        paddle.save(model, save_path + ".pdparams")
+    if isinstance(model, nn.Layer):
+        paddle.save(model.state_dict(), save_path + ".pdparams")
     else:
-        assert isinstance(ema_model,
-                          dict), ("ema_model is not a instance of dict, "
-                                  "please call model.state_dict() to get.")
-        # Exchange model and ema_model to save
-        paddle.save(ema_model, save_path + ".pdparams")
-        paddle.save(model, save_path + ".pdema")
+        assert isinstance(model,
+                          dict), 'model is not a instance of nn.layer or dict'
+        if ema_model is None:
+            paddle.save(model, save_path + ".pdparams")
+        else:
+            assert isinstance(ema_model,
+                              dict), ("ema_model is not a instance of dict, "
+                                      "please call model.state_dict() to get.")
+            # Exchange model and ema_model to save
+            paddle.save(ema_model, save_path + ".pdparams")
+            paddle.save(model, save_path + ".pdema")
     # save optimizer
     state_dict = optimizer.state_dict()
     state_dict['last_epoch'] = last_epoch
...
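One detail worth calling out in the kept EMA branch: whatever is passed as `ema_model` ends up in `<name>.pdparams` (what eval/infer later load), while the dict passed as `model` goes to `<name>.pdema` — hence the "Exchange" comment. A minimal sketch of that swap with toy dicts (the output path is a placeholder):

```python
import os
import paddle

model_dict = {'w': paddle.zeros([1])}  # dict passed as `model`
ema_dict = {'w': paddle.ones([1])}     # dict passed as `ema_model`

os.makedirs('output_toy', exist_ok=True)
# as in the branch above: ema_model -> .pdparams, model -> .pdema
paddle.save(ema_dict, os.path.join('output_toy', 'model.pdparams'))
paddle.save(model_dict, os.path.join('output_toy', 'model.pdema'))
```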
@@ -32,7 +32,7 @@ import paddle
 from ppdet.core.workspace import create, load_config, merge_config
 from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config
 from ppdet.utils.cli import ArgsParser, merge_args
-from ppdet.engine import Trainer, init_parallel_env
+from ppdet.engine import Trainer, Trainer_ARSL, init_parallel_env
 from ppdet.metrics.coco_utils import json_eval_results
 from ppdet.slim import build_slim_model
@@ -135,12 +135,17 @@ def run(FLAGS, cfg):
     # init parallel environment if nranks > 1
     init_parallel_env()
-    # build trainer
-    trainer = Trainer(cfg, mode='eval')
-
-    # load weights
-    trainer.load_weights(cfg.weights)
+    ssod_method = cfg.get('ssod_method', None)
+    if ssod_method == 'ARSL':
+        # build ARSL trainer
+        trainer = Trainer_ARSL(cfg, mode='eval')
+        # load ARSL weights
+        trainer.load_weights(cfg.weights, ARSL_eval=True)
+    else:
+        # build trainer
+        trainer = Trainer(cfg, mode='eval')
+        # load weights
+        trainer.load_weights(cfg.weights)
     # training
     if FLAGS.slice_infer:
...
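This hunk (apparently tools/eval.py) and the analogous ones in the export, infer, and train entry scripts below all follow the same pattern: read `ssod_method` from the YAML config and pick the trainer class accordingly, so an ARSL config (such as those added under the semi_det/arsl directory) is routed to Trainer_ARSL without any new command-line flags. A stripped-down sketch of the dispatch; the config dict and stand-in classes are illustrative only:

```python
# Stand-ins for ppdet.engine.Trainer and Trainer_ARSL, just to show the routing.
class Trainer:
    def __init__(self, cfg, mode):
        self.name = 'Trainer'

class Trainer_ARSL:
    def __init__(self, cfg, mode):
        self.name = 'Trainer_ARSL'

cfg = {'ssod_method': 'ARSL', 'weights': 'output/model_final'}  # assumed keys

ssod_method = cfg.get('ssod_method', None)
trainer_cls = Trainer_ARSL if ssod_method == 'ARSL' else Trainer
trainer = trainer_cls(cfg, mode='eval')
print(trainer.name)  # -> Trainer_ARSL
```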
@@ -32,6 +32,7 @@ from ppdet.core.workspace import load_config, merge_config
 from ppdet.utils.check import check_gpu, check_version, check_config
 from ppdet.utils.cli import ArgsParser
 from ppdet.engine import Trainer
+from ppdet.engine.trainer_ssod import Trainer_ARSL
 from ppdet.slim import build_slim_model
 from ppdet.utils.logger import setup_logger
@@ -60,14 +61,19 @@ def parse_args():
 def run(FLAGS, cfg):
+    ssod_method = cfg.get('ssod_method', None)
+    if ssod_method is not None and ssod_method == 'ARSL':
+        trainer = Trainer_ARSL(cfg, mode='test')
+        trainer.load_weights(cfg.weights, ARSL_eval=True)
     # build detector
-    trainer = Trainer(cfg, mode='test')
-
-    # load weights
-    if cfg.architecture in ['DeepSORT', 'ByteTrack']:
-        trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
     else:
-        trainer.load_weights(cfg.weights)
+        trainer = Trainer(cfg, mode='test')
+        # load weights
+        if cfg.architecture in ['DeepSORT', 'ByteTrack']:
+            trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights)
+        else:
+            trainer.load_weights(cfg.weights)

     # export model
     trainer.export(FLAGS.output_dir)
...
@@ -31,7 +31,7 @@ import ast
 import paddle
 from ppdet.core.workspace import load_config, merge_config
-from ppdet.engine import Trainer
+from ppdet.engine import Trainer, Trainer_ARSL
 from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config
 from ppdet.utils.cli import ArgsParser, merge_args
 from ppdet.slim import build_slim_model
@@ -156,12 +156,13 @@ def get_test_images(infer_dir, infer_img):
 def run(FLAGS, cfg):
-    # build trainer
-    trainer = Trainer(cfg, mode='test')
-
-    # load weights
-    trainer.load_weights(cfg.weights)
+    ssod_method = cfg.get('ssod_method', None)
+    if ssod_method == 'ARSL':
+        trainer = Trainer_ARSL(cfg, mode='test')
+        trainer.load_weights(cfg.weights, ARSL_eval=True)
+    else:
+        trainer = Trainer(cfg, mode='test')
+        trainer.load_weights(cfg.weights)
     # get inference images
     images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
...
@@ -32,7 +32,7 @@ import paddle
 from ppdet.core.workspace import load_config, merge_config
 from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env
-from ppdet.engine.trainer_ssod import Trainer_DenseTeacher
+from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL
 from ppdet.slim import build_slim_model
@@ -132,9 +132,11 @@ def run(FLAGS, cfg):
     if ssod_method is not None:
         if ssod_method == 'DenseTeacher':
             trainer = Trainer_DenseTeacher(cfg, mode='train')
+        elif ssod_method == 'ARSL':
+            trainer = Trainer_ARSL(cfg, mode='train')
         else:
             raise ValueError(
-                "Semi-Supervised Object Detection only support DenseTeacher now."
+                "Semi-Supervised Object Detection only support DenseTeacher and ARSL now."
             )
     elif cfg.get('use_cot', False):
         trainer = TrainerCot(cfg, mode='train')
...