From 1e211871aa95d8772f8b1990ffc4f3c84bb79a17 Mon Sep 17 00:00:00 2001 From: wjm <897383984@qq.com> Date: Sat, 1 Apr 2023 15:05:35 +0800 Subject: [PATCH] Support ARSL(CVPR2023) for semi-supervised object detection (#7980) * add SSOD_asrl * modify traniner name * add modelzoo * add config * add config * add config * modify cfg name * modify cfg * modify cfg * modify checkpoint * modify cfg * add voc and lsj * add voc and lsj * del export * modify * modify * refine codes * fix fcos_head get_loss * add export * fix bug * add export infer * change * retry * fix eval infer --------- Co-authored-by: nemonameless --- configs/semi_det/README.md | 22 +- .../semi_det/_base_/coco_detection_voc.yml | 31 + configs/semi_det/_base_/voc2coco.py | 213 +++++ configs/semi_det/arsl/README.md | 48 ++ .../arsl/_base_/arsl_fcos_r50_fpn.yml | 56 ++ .../semi_det/arsl/_base_/arsl_fcos_reader.yml | 55 ++ .../semi_det/arsl/_base_/optimizer_360k.yml | 29 + .../semi_det/arsl/_base_/optimizer_90k.yml | 30 + .../arsl/arsl_fcos_r50_fpn_coco_full.yml | 12 + .../arsl/arsl_fcos_r50_fpn_coco_semi001.yml | 12 + .../arsl/arsl_fcos_r50_fpn_coco_semi005.yml | 12 + .../arsl/arsl_fcos_r50_fpn_coco_semi010.yml | 12 + .../arsl_fcos_r50_fpn_coco_semi010_lsj.yml | 47 ++ ppdet/engine/trainer.py | 10 +- ppdet/engine/trainer_ssod.py | 401 ++++++++- ppdet/modeling/architectures/__init__.py | 2 +- ppdet/modeling/architectures/fcos.py | 130 ++- ppdet/modeling/heads/fcos_head.py | 155 +++- ppdet/modeling/losses/fcos_loss.py | 759 +++++++++++++++++- ppdet/utils/checkpoint.py | 182 +++-- tools/eval.py | 19 +- tools/export_model.py | 18 +- tools/infer.py | 15 +- tools/train.py | 6 +- 24 files changed, 2159 insertions(+), 117 deletions(-) create mode 100644 configs/semi_det/_base_/coco_detection_voc.yml create mode 100644 configs/semi_det/_base_/voc2coco.py create mode 100644 configs/semi_det/arsl/README.md create mode 100644 configs/semi_det/arsl/_base_/arsl_fcos_r50_fpn.yml create mode 100644 configs/semi_det/arsl/_base_/arsl_fcos_reader.yml create mode 100644 configs/semi_det/arsl/_base_/optimizer_360k.yml create mode 100644 configs/semi_det/arsl/_base_/optimizer_90k.yml create mode 100644 configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_full.yml create mode 100644 configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi001.yml create mode 100644 configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi005.yml create mode 100644 configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml create mode 100644 configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml diff --git a/configs/semi_det/README.md b/configs/semi_det/README.md index 996a1decf..5256026a3 100644 --- a/configs/semi_det/README.md +++ b/configs/semi_det/README.md @@ -7,6 +7,7 @@ - [模型库](#模型库) - [Baseline](#Baseline) - [DenseTeacher](#DenseTeacher) + - [ARSL](#ARSL) - [半监督数据集准备](#半监督数据集准备) - [半监督检测配置](#半监督检测配置) - [训练集配置](#训练集配置) @@ -23,7 +24,7 @@ - [引用](#引用) ## 简介 -半监督目标检测(Semi DET)是**同时使用有标注数据和无标注数据**进行训练的目标检测,既可以极大地节省标注成本,也可以充分利用无标注数据进一步提高检测精度。PaddleDetection团队复现了[DenseTeacher](denseteacher)半监督检测算法,用户可以下载使用。 +半监督目标检测(Semi DET)是**同时使用有标注数据和无标注数据**进行训练的目标检测,既可以极大地节省标注成本,也可以充分利用无标注数据进一步提高检测精度。PaddleDetection团队提供了[DenseTeacher](denseteacher/)和[ARSL](arsl/)等最前沿的半监督检测算法,用户可以下载使用。 ## 模型库 @@ -41,6 +42,25 @@ | DenseTeacher-FCOS(LSJ)| 10% | [sup_config](./baseline/fcos_r50_fpn_2x_coco_sup010.yml) | 24 (17424) | 26.3 | **37.1(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](denseteacher/denseteacher_fcos_r50_fpn_coco_semi010_lsj.yml) | | DenseTeacher-FCOS |100%(full)| [sup_config](./../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.ymll) | 24 (175896) | 42.6 | **44.2** | 24 (175896)| [download](https://paddledet.bj.bcebos.com/models/denseteacher_fcos_r50_fpn_coco_full.pdparams) | [config](denseteacher/denseteacher_fcos_r50_fpn_coco_full.yml) | +| 模型 | 监督数据比例 | Sup Baseline | Sup Epochs (Iters) | Sup mAPval
0.5:0.95 | Semi mAPval
0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 | +| :------------: | :---------: | :---------------------: | :---------------------: |:---------------------------: |:----------------------------: | :------------------: |:--------: |:----------: | +| DenseTeacher-PPYOLOE+_s | 5% | [sup_config](./baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml) | 80 (14480) | 32.8 | **34.0** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi005.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_s_coco_semi005.yml) | +| DenseTeacher-PPYOLOE+_s | 10% | [sup_config](./baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml) | 80 (14480) | 35.3 | **37.5** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_s_coco_semi010.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_s_coco_semi010.yml) | +| DenseTeacher-PPYOLOE+_l | 5% | [sup_config](./baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml) | 80 (14480) | 42.9 | **45.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi005.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_l_coco_semi005.yml) | +| DenseTeacher-PPYOLOE+_l | 10% | [sup_config](./baseline/ppyoloe_plus_crn_l_80e_coco_sup010.yml) | 80 (14480) | 45.7 | **47.4** | 200 (36200) | [download](https://paddledet.bj.bcebos.com/models/denseteacher_ppyoloe_plus_crn_l_coco_semi010.pdparams) | [config](denseteacher/denseteacher_ppyoloe_plus_crn_l_coco_semi010.yml) | + + +### [ARSL](arsl) + +| 模型 | COCO监督数据比例 | Semi mAPval
0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 | +| :------------: | :---------:|:----------------------------: | :------------------: |:--------: |:----------: | +| ARSL-FCOS | 1% | **22.8** | 240 (87120) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi001.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi001.yml) | +| ARSL-FCOS | 5% | **33.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi005.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi005.yml ) | +| ARSL-FCOS | 10% | **36.9** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi010.yml ) | +| ARSL-FCOS | 10% | **38.5(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml ) | +| ARSL-FCOS | full(100%) | **45.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_full.pdparams) | [config](arsl/arsl_fcos_r50_fpn_coco_full.yml ) | + + ## 半监督数据集准备 diff --git a/configs/semi_det/_base_/coco_detection_voc.yml b/configs/semi_det/_base_/coco_detection_voc.yml new file mode 100644 index 000000000..8548081cf --- /dev/null +++ b/configs/semi_det/_base_/coco_detection_voc.yml @@ -0,0 +1,31 @@ +metric: COCO +num_classes: 20 +# before training, change VOC to COCO format by 'python voc2coco.py' +# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +TrainDataset: + !SemiCOCODataSet + image_dir: VOC2007/JPEGImages + anno_path: PseudoAnnotations/VOC2007_trainval.json + dataset_dir: dataset/voc/VOCdevkit + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] + +# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +UnsupTrainDataset: + !SemiCOCODataSet + image_dir: VOC2012/JPEGImages + anno_path: PseudoAnnotations/VOC2012_trainval.json + dataset_dir: dataset/voc/VOCdevkit + data_fields: ['image'] + supervised: False + +EvalDataset: + !COCODataSet + image_dir: VOC2007/JPEGImages + anno_path: PseudoAnnotations/VOC2007_test.json + dataset_dir: dataset/voc/VOCdevkit/ + allow_empty: true + +TestDataset: + !ImageFolder + anno_path: PseudoAnnotations/VOC2007_test.json # also support txt (like VOC's label_list.txt) + dataset_dir: dataset/voc/VOCdevkit/ # if set, anno_path will be 'dataset_dir/anno_path' diff --git a/configs/semi_det/_base_/voc2coco.py b/configs/semi_det/_base_/voc2coco.py new file mode 100644 index 000000000..87bfe809d --- /dev/null +++ b/configs/semi_det/_base_/voc2coco.py @@ -0,0 +1,213 @@ +# convert VOC xml to COCO format json +import xml.etree.ElementTree as ET +import os +import json +import argparse + + +# create and init coco json, img set, and class set +def init_json(): + # create coco json + coco = dict() + coco['images'] = [] + coco['type'] = 'instances' + coco['annotations'] = [] + coco['categories'] = [] + # voc classes + voc_class = [ + 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', + 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', + 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' + ] + # init json categories + image_set = set() + class_set = dict() + for cat_id, cat_name in enumerate(voc_class): + cat_item = dict() + cat_item['supercategory'] = 'none' + cat_item['id'] = cat_id + cat_item['name'] = cat_name + coco['categories'].append(cat_item) + class_set[cat_name] = cat_id + return coco, class_set, image_set + + +def getImgItem(file_name, size, img_id): + if file_name is None: + raise Exception('Could not find filename tag in xml file.') + if size['width'] is None: + raise Exception('Could not find width tag in xml file.') + if size['height'] is None: + raise Exception('Could not find height tag in xml file.') + image_item = dict() + image_item['id'] = img_id + image_item['file_name'] = file_name + image_item['width'] = size['width'] + image_item['height'] = size['height'] + return image_item + + +def getAnnoItem(object_name, image_id, ann_id, category_id, bbox): + annotation_item = dict() + annotation_item['segmentation'] = [] + seg = [] + # bbox[] is x,y,w,h + # left_top + seg.append(bbox[0]) + seg.append(bbox[1]) + # left_bottom + seg.append(bbox[0]) + seg.append(bbox[1] + bbox[3]) + # right_bottom + seg.append(bbox[0] + bbox[2]) + seg.append(bbox[1] + bbox[3]) + # right_top + seg.append(bbox[0] + bbox[2]) + seg.append(bbox[1]) + + annotation_item['segmentation'].append(seg) + + annotation_item['area'] = bbox[2] * bbox[3] + annotation_item['iscrowd'] = 0 + annotation_item['ignore'] = 0 + annotation_item['image_id'] = image_id + annotation_item['bbox'] = bbox + annotation_item['category_id'] = category_id + annotation_item['id'] = ann_id + return annotation_item + + +def convert_voc_to_coco(txt_path, json_path, xml_path): + + # create and init coco json, img set, and class set + coco_json, class_set, image_set = init_json() + + ### collect img and ann info into coco json + # read img_name in txt, e.g., 000005 for voc2007, 2008_000002 for voc2012 + img_txt = open(txt_path, 'r') + img_line = img_txt.readline().strip() + + # loop xml + img_id = 0 + ann_id = 0 + while img_line: + print('img_id:', img_id) + + # find corresponding xml + xml_name = img_line.split('Annotations/', 1)[1] + xml_file = os.path.join(xml_path, xml_name) + if not os.path.exists(xml_file): + print('{} is not exists.'.format(xml_name)) + img_line = img_txt.readline().strip() + continue + + # decode xml + tree = ET.parse(xml_file) + root = tree.getroot() + if root.tag != 'annotation': + raise Exception( + 'xml {} root element should be annotation, rather than {}'. + format(xml_name, root.tag)) + + # init img and ann info + bndbox = dict() + size = dict() + size['width'] = None + size['height'] = None + size['depth'] = None + + # filename + fileNameNode = root.find('filename') + file_name = fileNameNode.text + + # img size + sizeNode = root.find('size') + if not sizeNode: + raise Exception('xml {} structure broken at size tag.'.format( + xml_name)) + for subNode in sizeNode: + size[subNode.tag] = int(subNode.text) + + # add img into json + if file_name not in image_set: + img_id += 1 + format_img_id = int("%04d" % img_id) + # print('line 120. format_img_id:', format_img_id) + image_item = getImgItem(file_name, size, img_id) + image_set.add(file_name) + coco_json['images'].append(image_item) + else: + raise Exception(' xml {} duplicated image: {}'.format(xml_name, + file_name)) + + ### add objAnn into json + objectAnns = root.findall('object') + for objectAnn in objectAnns: + bndbox['xmin'] = None + bndbox['xmax'] = None + bndbox['ymin'] = None + bndbox['ymax'] = None + + #add obj category + object_name = objectAnn.find('name').text + if object_name not in class_set: + raise Exception('xml {} Unrecognized category: {}'.format( + xml_name, object_name)) + else: + current_category_id = class_set[object_name] + + #add obj bbox ann + objectBboxNode = objectAnn.find('bndbox') + for coordinate in objectBboxNode: + if bndbox[coordinate.tag] is not None: + raise Exception('xml {} structure corrupted at bndbox tag.'. + format(xml_name)) + bndbox[coordinate.tag] = int(float(coordinate.text)) + bbox = [] + # x + bbox.append(bndbox['xmin']) + # y + bbox.append(bndbox['ymin']) + # w + bbox.append(bndbox['xmax'] - bndbox['xmin']) + # h + bbox.append(bndbox['ymax'] - bndbox['ymin']) + ann_id += 1 + ann_item = getAnnoItem(object_name, img_id, ann_id, + current_category_id, bbox) + coco_json['annotations'].append(ann_item) + + img_line = img_txt.readline().strip() + + print('Saving json.') + json.dump(coco_json, open(json_path, 'w')) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--type', type=str, default='VOC2007_test', help="data type") + parser.add_argument( + '--base_path', + type=str, + default='dataset/voc/VOCdevkit', + help="base VOC path.") + args = parser.parse_args() + + # image info path + txt_name = args.type + '.txt' + json_name = args.type + '.json' + txt_path = os.path.join(args.base_path, 'PseudoAnnotations', txt_name) + json_path = os.path.join(args.base_path, 'PseudoAnnotations', json_name) + + # xml path + xml_path = os.path.join(args.base_path, + args.type.split('_')[0], 'Annotations') + + print('txt_path:', txt_path) + print('json_path:', json_path) + print('xml_path:', xml_path) + + print('Converting {} to COCO json.'.format(args.type)) + convert_voc_to_coco(txt_path, json_path, xml_path) + print('Finished.') diff --git a/configs/semi_det/arsl/README.md b/configs/semi_det/arsl/README.md new file mode 100644 index 000000000..aee750ecd --- /dev/null +++ b/configs/semi_det/arsl/README.md @@ -0,0 +1,48 @@ +简体中文 | [English](README_en.md) + +# Ambiguity-Resistant Semi-Supervised Learning for Dense Object Detection (ARSL) + +## ARSL-FCOS 模型库 + +| 模型 | COCO监督数据比例 | Semi mAPval
0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 | +| :------------: | :---------:|:----------------------------: | :------------------: |:--------: |:----------: | +| ARSL-FCOS | 1% | **22.8** | 240 (87120) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi001.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi001.yml) | +| ARSL-FCOS | 5% | **33.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi005.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi005.yml ) | +| ARSL-FCOS | 10% | **36.9** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi010.yml ) | +| ARSL-FCOS | 10% | **38.5(LSJ)** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_semi010_lsj.pdparams) | [config](./arsl_fcos_r50_fpn_coco_semi010_lsj.yml ) | +| ARSL-FCOS | full(100%) | **45.1** | 240 (174240) | [download](https://paddledet.bj.bcebos.com/models/arsl_fcos_r50_fpn_coco_full.pdparams) | [config](./arsl_fcos_r50_fpn_coco_full.yml ) | + + + +## 使用说明 + +仅训练时必须使用半监督检测的配置文件去训练,评估、预测、部署也可以按基础检测器的配置文件去执行。 + +### 训练 + +```bash +# 单卡训练 (不推荐,需按线性比例相应地调整学习率) +CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml --eval + +# 多卡训练 +python -m paddle.distributed.launch --log_dir=arsl_fcos_r50_fpn_coco_semi010/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml --eval +``` + +### 评估 + +```bash +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml -o weights=output/arsl_fcos_r50_fpn_coco_semi010/model_final.pdparams +``` + +### 预测 + +```bash +CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml -o weights=output/arsl_fcos_r50_fpn_coco_semi010/model_final.pdparams --infer_img=demo/000000014439.jpg +``` + + +## 引用 + +``` + +``` diff --git a/configs/semi_det/arsl/_base_/arsl_fcos_r50_fpn.yml b/configs/semi_det/arsl/_base_/arsl_fcos_r50_fpn.yml new file mode 100644 index 000000000..95733bc87 --- /dev/null +++ b/configs/semi_det/arsl/_base_/arsl_fcos_r50_fpn.yml @@ -0,0 +1,56 @@ +architecture: ARSL_FCOS +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams + +ARSL_FCOS: + backbone: ResNet + neck: FPN + fcos_head: FCOSHead_ARSL + fcos_cr_loss: FCOSLossCR + + +ResNet: + depth: 50 + norm_type: bn + freeze_at: 0 + return_idx: [1,2,3] + num_stages: 4 + +FPN: + out_channel: 256 + spatial_scales: [0.125, 0.0625, 0.03125] + extra_stage: 2 + has_extra_convs: true + use_c5: false + +FCOSHead_ARSL: + fcos_feat: + name: FCOSFeat + feat_in: 256 + feat_out: 256 + num_convs: 4 + norm_type: "gn" + use_dcn: false + fpn_stride: [8, 16, 32, 64, 128] + prior_prob: 0.01 + norm_reg_targets: True + centerness_on_reg: True + fcos_loss: + name: FCOSLossMILC + loss_alpha: 0.25 + loss_gamma: 2.0 + iou_loss_type: "giou" + reg_weights: 1.0 + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.025 + nms_threshold: 0.6 + + +FCOSLossCR: + iou_loss_type: "giou" + cls_weight: 2.0 + reg_weight: 2.0 + iou_weight: 0.5 + hard_neg_mining_flag: true diff --git a/configs/semi_det/arsl/_base_/arsl_fcos_reader.yml b/configs/semi_det/arsl/_base_/arsl_fcos_reader.yml new file mode 100644 index 000000000..30dddffcb --- /dev/null +++ b/configs/semi_det/arsl/_base_/arsl_fcos_reader.yml @@ -0,0 +1,55 @@ +worker_num: 2 +SemiTrainReader: + sample_transforms: + - Decode: {} + - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1} + - RandomFlip: {} + weak_aug: + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + sup_batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + - Gt2FCOSTarget: + object_sizes_boundary: [64, 128, 256, 512] + center_sampling_radius: 1.5 + downsample_ratios: [8, 16, 32, 64, 128] + num_shift: 0. # default 0.5 + multiply_strides_reg_targets: False + norm_reg_targets: True + unsup_batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: True + drop_last: True + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 2, target_size: [800, 1333], keep_ratio: True} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 diff --git a/configs/semi_det/arsl/_base_/optimizer_360k.yml b/configs/semi_det/arsl/_base_/optimizer_360k.yml new file mode 100644 index 000000000..99072de55 --- /dev/null +++ b/configs/semi_det/arsl/_base_/optimizer_360k.yml @@ -0,0 +1,29 @@ +epoch: 120 # employ iter to control shedule +LearningRate: + base_lr: 0.02 # 0.02 for 8*(4+4) batch + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [3000] # do not decay lr + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 1000 + +max_iter: 360000 # 360k for 32 batch, 720k for 16 batch +epoch_iter: 1000 # set epoch_iter for saving checkpoint and eval +optimize_rate: 1 +SEMISUPNET: + BBOX_THRESHOLD: 0.5 # # not used + TEACHER_UPDATE_ITER: 1 + BURN_UP_STEP: 30000 + EMA_KEEP_RATE: 0.9996 + UNSUP_LOSS_WEIGHT: 1.0 # detailed weights for cls and loc task can be seen in cr_loss + PSEUDO_WARM_UP_STEPS: 2000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 diff --git a/configs/semi_det/arsl/_base_/optimizer_90k.yml b/configs/semi_det/arsl/_base_/optimizer_90k.yml new file mode 100644 index 000000000..623d7f33e --- /dev/null +++ b/configs/semi_det/arsl/_base_/optimizer_90k.yml @@ -0,0 +1,30 @@ +epoch: 30 # employ iter to control shedule +LearningRate: + base_lr: 0.02 # 0.02 for 8*(4+4) batch + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [300] # do not decay lr + - !LinearWarmup + start_factor: 0.3333333333333333 + steps: 1000 + +max_iter: 90000 # 90k for 32 batch, 180k for 16 batch +epoch_iter: 1000 # set epoch_iter for saving checkpoint and eval +# update student params according to loss_grad every X iter. +optimize_rate: 1 +SEMISUPNET: + BBOX_THRESHOLD: 0.5 # not used + TEACHER_UPDATE_ITER: 1 + BURN_UP_STEP: 9000 + EMA_KEEP_RATE: 0.9996 + UNSUP_LOSS_WEIGHT: 1.0 # detailed weights for cls and loc task can be seen in cr_loss + PSEUDO_WARM_UP_STEPS: 2000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0001 + type: L2 diff --git a/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_full.yml b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_full.yml new file mode 100644 index 000000000..a868aaf77 --- /dev/null +++ b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_full.yml @@ -0,0 +1,12 @@ +_BASE_: [ + '../_base_/coco_detection_full.yml', + '../../runtime.yml', + '_base_/arsl_fcos_r50_fpn.yml', + '_base_/optimizer_360k.yml', + '_base_/arsl_fcos_reader.yml', +] + +weights: output/fcos_r50_fpn_arsl_360k_coco_full/model_final + +#semi detector type +ssod_method: ARSL diff --git a/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi001.yml b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi001.yml new file mode 100644 index 000000000..136483e54 --- /dev/null +++ b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi001.yml @@ -0,0 +1,12 @@ +_BASE_: [ + '../_base_/coco_detection_percent_1.yml', + '../../runtime.yml', + '_base_/arsl_fcos_r50_fpn.yml', + '_base_/optimizer_90k.yml', + '_base_/arsl_fcos_reader.yml', +] + +weights: output/arsl_fcos_r50_fpn_coco_semi001/model_final + +#semi detector type +ssod_method: ARSL diff --git a/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi005.yml b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi005.yml new file mode 100644 index 000000000..7c2a77947 --- /dev/null +++ b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi005.yml @@ -0,0 +1,12 @@ +_BASE_: [ + '../_base_/coco_detection_percent_5.yml', + '../../runtime.yml', + '_base_/arsl_fcos_r50_fpn.yml', + '_base_/optimizer_90k.yml', + '_base_/arsl_fcos_reader.yml', +] + +weights: output/arsl_fcos_r50_fpn_coco_semi005/model_final + +#semi detector type +ssod_method: ARSL diff --git a/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml new file mode 100644 index 000000000..7abfa59d6 --- /dev/null +++ b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010.yml @@ -0,0 +1,12 @@ +_BASE_: [ + '../_base_/coco_detection_percent_10.yml', + '../../runtime.yml', + '_base_/arsl_fcos_r50_fpn.yml', + '_base_/optimizer_360k.yml', + '_base_/arsl_fcos_reader.yml', +] + +weights: output/arsl_fcos_r50_fpn_coco_semi010/model_final + +#semi detector type +ssod_method: ARSL diff --git a/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml new file mode 100644 index 000000000..258a7a8e4 --- /dev/null +++ b/configs/semi_det/arsl/arsl_fcos_r50_fpn_coco_semi010_lsj.yml @@ -0,0 +1,47 @@ +_BASE_: [ + '../_base_/coco_detection_percent_10.yml', + '../../runtime.yml', + '_base_/arsl_fcos_r50_fpn.yml', + '_base_/optimizer_360k.yml', + '_base_/arsl_fcos_reader.yml', +] + +weights: output/arsl_fcos_r50_fpn_coco_semi010/model_final + +#semi detector type +ssod_method: ARSL + +worker_num: 2 +SemiTrainReader: + sample_transforms: + - Decode: {} + # large-scale jittering + - RandomResize: {target_size: [[400, 1333], [1200, 1333]], keep_ratio: True, interp: 1, random_range: True} + - RandomFlip: {} + weak_aug: + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: true} + sup_batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + - Gt2FCOSTarget: + object_sizes_boundary: [64, 128, 256, 512] + center_sampling_radius: 1.5 + downsample_ratios: [8, 16, 32, 64, 128] + num_shift: 0. # default 0.5 + multiply_strides_reg_targets: False + norm_reg_targets: True + unsup_batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: True + drop_last: True diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 730b99f28..55890a979 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -394,11 +394,11 @@ class Trainer(object): "metrics shoule be instances of subclass of Metric" self._metrics.extend(metrics) - def load_weights(self, weights): + def load_weights(self, weights, ARSL_eval=False): if self.is_loaded_weights: return self.start_epoch = 0 - load_pretrain_weight(self.model, weights) + load_pretrain_weight(self.model, weights, ARSL_eval) logger.debug("Load weights {} to start training".format(weights)) def load_weights_sde(self, det_weights, reid_weights): @@ -985,8 +985,10 @@ class Trainer(object): for step_id, data in enumerate(tqdm(loader)): self.status['step_id'] = step_id # forward - outs = self.model(data) - + if hasattr(self.model, 'modelTeacher'): + outs = self.model.modelTeacher(data) + else: + outs = self.model(data) for _m in metrics: _m.update(data, outs) diff --git a/ppdet/engine/trainer_ssod.py b/ppdet/engine/trainer_ssod.py index ef2409b09..ac39c9a97 100644 --- a/ppdet/engine/trainer_ssod.py +++ b/ppdet/engine/trainer_ssod.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import copy import time import typing @@ -26,18 +27,20 @@ import paddle.nn as nn import paddle.distributed as dist from paddle.distributed import fleet from ppdet.optimizer import ModelEMA, SimpleModelEMA - from ppdet.core.workspace import create -from ppdet.utils.checkpoint import load_weight, load_pretrain_weight +from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model import ppdet.utils.stats as stats from ppdet.utils import profiler from ppdet.modeling.ssod.utils import align_weak_strong_shape from .trainer import Trainer - from ppdet.utils.logger import setup_logger +from paddle.static import InputSpec +from ppdet.engine.export_utils import _dump_infer_config, _prune_input_spec +MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack'] + logger = setup_logger('ppdet.engine') -__all__ = ['Trainer_DenseTeacher'] +__all__ = ['Trainer_DenseTeacher', 'Trainer_ARSL'] class Trainer_DenseTeacher(Trainer): @@ -199,11 +202,6 @@ class Trainer_DenseTeacher(Trainer): self.status['data_time'] = stats.SmoothedValue( self.cfg.log_iter, fmt='{avg:.4f}') self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) - - if self.cfg.get('print_flops', False): - flops_loader = create('{}Reader'.format(self.mode.capitalize()))( - self.dataset, self.cfg.worker_num) - self._flops(flops_loader) profiler_options = self.cfg.get('profiler_options', None) self._compose_callback.on_train_begin(self.status) @@ -466,6 +464,365 @@ class Trainer_DenseTeacher(Trainer): self.status['sample_num'] = sample_num self.status['cost_time'] = time.time() - tic + # accumulate metric to log out + for metric in self._metrics: + metric.accumulate() + metric.log() + self._compose_callback.on_epoch_end(self.status) + self._reset_metrics() + + +class Trainer_ARSL(Trainer): + def __init__(self, cfg, mode='train'): + self.cfg = cfg + assert mode.lower() in ['train', 'eval', 'test'], \ + "mode should be 'train', 'eval' or 'test'" + self.mode = mode.lower() + self.optimizer = None + self.is_loaded_weights = False + capital_mode = self.mode.capitalize() + self.use_ema = False + self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create( + '{}Dataset'.format(capital_mode))() + if self.mode == 'train': + self.dataset_unlabel = self.cfg['UnsupTrainDataset'] = create( + 'UnsupTrainDataset') + self.loader = create('SemiTrainReader')( + self.dataset, self.dataset_unlabel, cfg.worker_num) + + # build model + if 'model' not in self.cfg: + self.student_model = create(cfg.architecture) + self.teacher_model = create(cfg.architecture) + self.model = EnsembleTSModel(self.teacher_model, self.student_model) + else: + self.model = self.cfg.model + self.is_loaded_weights = True + # save path for burn-in model + self.base_path = cfg.get('weights') + self.base_path = os.path.dirname(self.base_path) + + # EvalDataset build with BatchSampler to evaluate in single device + # TODO: multi-device evaluate + if self.mode == 'eval': + self._eval_batch_sampler = paddle.io.BatchSampler( + self.dataset, batch_size=self.cfg.EvalReader['batch_size']) + self.loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, cfg.worker_num, self._eval_batch_sampler) + # TestDataset build after user set images, skip loader creation here + + self.start_epoch = 0 + self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch + self.epoch_iter = self.cfg.epoch_iter # set fixed iter in each epoch to control checkpoint + + # build optimizer in train mode + if self.mode == 'train': + steps_per_epoch = self.epoch_iter + self.lr = create('LearningRate')(steps_per_epoch) + self.optimizer = create('OptimizerBuilder')(self.lr, + self.model.modelStudent) + + self._nranks = dist.get_world_size() + self._local_rank = dist.get_rank() + + self.status = {} + + # initial default callbacks + self._init_callbacks() + + # initial default metrics + self._init_metrics() + self._reset_metrics() + self.iter = 0 + + def resume_weights(self, weights): + # support Distill resume weights + if hasattr(self.model, 'student_model'): + self.start_epoch = load_weight(self.model.student_model, weights, + self.optimizer) + else: + self.start_epoch = load_weight(self.model, weights, self.optimizer) + logger.debug("Resume weights of epoch {}".format(self.start_epoch)) + + def train(self, validate=False): + assert self.mode == 'train', "Model not in 'train' mode" + Init_mark = False + + # if validation in training is enabled, metrics should be re-init + if validate: + self._init_metrics(validate=validate) + self._reset_metrics() + + if self.cfg.get('fleet', False): + self.model.modelStudent = fleet.distributed_model( + self.model.modelStudent) + self.optimizer = fleet.distributed_optimizer(self.optimizer) + elif self._nranks > 1: + find_unused_parameters = self.cfg[ + 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False + self.model.modelStudent = paddle.DataParallel( + self.model.modelStudent, + find_unused_parameters=find_unused_parameters) + + # set fixed iter in each epoch to control checkpoint + self.status.update({ + 'epoch_id': self.start_epoch, + 'step_id': 0, + 'steps_per_epoch': self.epoch_iter + }) + print('338 Len of DataLoader: {}'.format(len(self.loader))) + + self.status['batch_time'] = stats.SmoothedValue( + self.cfg.log_iter, fmt='{avg:.4f}') + self.status['data_time'] = stats.SmoothedValue( + self.cfg.log_iter, fmt='{avg:.4f}') + self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) + + self._compose_callback.on_train_begin(self.status) + + epoch_id = self.start_epoch + self.iter = self.start_epoch * self.epoch_iter + # use iter rather than epoch to control training schedule + while self.iter < self.cfg.max_iter: + # epoch loop + self.status['mode'] = 'train' + self.status['epoch_id'] = epoch_id + self._compose_callback.on_epoch_begin(self.status) + self.loader.dataset_label.set_epoch(epoch_id) + self.loader.dataset_unlabel.set_epoch(epoch_id) + paddle.device.cuda.empty_cache() # clear GPU memory + # set model status + self.model.modelStudent.train() + self.model.modelTeacher.eval() + iter_tic = time.time() + + # iter loop in each eopch + for step_id in range(self.epoch_iter): + data = next(self.loader) + self.status['data_time'].update(time.time() - iter_tic) + self.status['step_id'] = step_id + # profiler.add_profiler_step(profiler_options) + self._compose_callback.on_step_begin(self.status) + + # model forward and calculate loss + loss_dict = self.run_step_full_semisup(data) + + if (step_id + 1) % self.cfg.optimize_rate == 0: + self.optimizer.step() + self.optimizer.clear_grad() + curr_lr = self.optimizer.get_lr() + self.lr.step() + + # update log status + self.status['learning_rate'] = curr_lr + if self._nranks < 2 or self._local_rank == 0: + self.status['training_staus'].update(loss_dict) + self.status['batch_time'].update(time.time() - iter_tic) + self._compose_callback.on_step_end(self.status) + self.iter += 1 + iter_tic = time.time() + + self._compose_callback.on_epoch_end(self.status) + + if validate and (self._nranks < 2 or self._local_rank == 0) \ + and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 \ + or epoch_id == self.end_epoch - 1): + if not hasattr(self, '_eval_loader'): + # build evaluation dataset and loader + self._eval_dataset = self.cfg.EvalDataset + self._eval_batch_sampler = \ + paddle.io.BatchSampler( + self._eval_dataset, + batch_size=self.cfg.EvalReader['batch_size']) + self._eval_loader = create('EvalReader')( + self._eval_dataset, + self.cfg.worker_num, + batch_sampler=self._eval_batch_sampler) + if validate and Init_mark == False: + Init_mark = True + self._init_metrics(validate=validate) + self._reset_metrics() + with paddle.no_grad(): + self.status['save_best_model'] = True + # before burn-in stage, eval student. after burn-in stage, eval teacher + if self.iter <= self.cfg.SEMISUPNET['BURN_UP_STEP']: + print("start eval student model") + self._eval_with_loader( + self._eval_loader, mode="student") + else: + print("start eval teacher model") + self._eval_with_loader( + self._eval_loader, mode="teacher") + + epoch_id += 1 + + self._compose_callback.on_train_end(self.status) + + def merge_data(self, data1, data2): + data = copy.deepcopy(data1) + for k, v in data1.items(): + if type(v) is paddle.Tensor: + data[k] = paddle.concat(x=[data[k], data2[k]], axis=0) + elif type(v) is list: + data[k].extend(data2[k]) + return data + + def run_step_full_semisup(self, data): + label_data_k, label_data_q, unlabel_data_k, unlabel_data_q = data + data_merge = self.merge_data(label_data_k, label_data_q) + loss_sup_dict = self.model.modelStudent(data_merge, branch="supervised") + loss_dict = {} + for key in loss_sup_dict.keys(): + if key[:4] == "loss": + loss_dict[key] = loss_sup_dict[key] * 1 + losses_sup = paddle.add_n(list(loss_dict.values())) + # norm loss when using gradient accumulation + losses_sup = losses_sup / self.cfg.optimize_rate + losses_sup.backward() + + for key in loss_sup_dict.keys(): + loss_dict[key + "_pseudo"] = paddle.to_tensor([0]) + loss_dict["loss_tot"] = losses_sup + """ + semi-supervised training after burn-in stage + """ + if self.iter >= self.cfg.SEMISUPNET['BURN_UP_STEP']: + # init teacher model with burn-up weight + if self.iter == self.cfg.SEMISUPNET['BURN_UP_STEP']: + print( + 'Starting semi-supervised learning and load the teacher model.' + ) + self._update_teacher_model(keep_rate=0.00) + # save burn-in model + if dist.get_world_size() < 2 or dist.get_rank() == 0: + print('saving burn-in model.') + save_name = 'burnIn' + epoch_id = self.iter // self.epoch_iter + save_model(self.model, self.optimizer, self.base_path, + save_name, epoch_id) + # Update teacher model with EMA + elif (self.iter + 1) % self.cfg.optimize_rate == 0: + self._update_teacher_model( + keep_rate=self.cfg.SEMISUPNET['EMA_KEEP_RATE']) + + #warm-up weight for pseudo loss + pseudo_weight = self.cfg.SEMISUPNET['UNSUP_LOSS_WEIGHT'] + pseudo_warmup_iter = self.cfg.SEMISUPNET['PSEUDO_WARM_UP_STEPS'] + temp = self.iter - self.cfg.SEMISUPNET['BURN_UP_STEP'] + if temp <= pseudo_warmup_iter: + pseudo_weight *= (temp / pseudo_warmup_iter) + + # get teacher predictions on weak-augmented unlabeled data + with paddle.no_grad(): + teacher_pred = self.model.modelTeacher( + unlabel_data_k, branch='semi_supervised') + + # calculate unsupervised loss on strong-augmented unlabeled data + loss_unsup_dict = self.model.modelStudent( + unlabel_data_q, + branch="semi_supervised", + teacher_prediction=teacher_pred, ) + + for key in loss_unsup_dict.keys(): + if key[-6:] == "pseudo": + loss_unsup_dict[key] = loss_unsup_dict[key] * pseudo_weight + losses_unsup = paddle.add_n(list(loss_unsup_dict.values())) + # norm loss when using gradient accumulation + losses_unsup = losses_unsup / self.cfg.optimize_rate + losses_unsup.backward() + + loss_dict.update(loss_unsup_dict) + loss_dict["loss_tot"] += losses_unsup + return loss_dict + + def export(self, output_dir='output_inference'): + self.model.eval() + model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0] + save_dir = os.path.join(output_dir, model_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + image_shape = None + if self.cfg.architecture in MOT_ARCH: + test_reader_name = 'TestMOTReader' + else: + test_reader_name = 'TestReader' + if 'inputs_def' in self.cfg[test_reader_name]: + inputs_def = self.cfg[test_reader_name]['inputs_def'] + image_shape = inputs_def.get('image_shape', None) + # set image_shape=[3, -1, -1] as default + if image_shape is None: + image_shape = [3, -1, -1] + + self.model.modelTeacher.eval() + if hasattr(self.model.modelTeacher, 'deploy'): + self.model.modelTeacher.deploy = True + + # Save infer cfg + _dump_infer_config(self.cfg, + os.path.join(save_dir, 'infer_cfg.yml'), image_shape, + self.model.modelTeacher) + + input_spec = [{ + "image": InputSpec( + shape=[None] + image_shape, name='image'), + "im_shape": InputSpec( + shape=[None, 2], name='im_shape'), + "scale_factor": InputSpec( + shape=[None, 2], name='scale_factor') + }] + if self.cfg.architecture == 'DeepSORT': + input_spec[0].update({ + "crops": InputSpec( + shape=[None, 3, 192, 64], name='crops') + }) + + static_model = paddle.jit.to_static( + self.model.modelTeacher, input_spec=input_spec) + # NOTE: dy2st do not pruned program, but jit.save will prune program + # input spec, prune input spec here and save with pruned input spec + pruned_input_spec = _prune_input_spec(input_spec, + static_model.forward.main_program, + static_model.forward.outputs) + + # dy2st and save model + if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT': + paddle.jit.save( + static_model, + os.path.join(save_dir, 'model'), + input_spec=pruned_input_spec) + else: + self.cfg.slim.save_quantized_model( + self.model.modelTeacher, + os.path.join(save_dir, 'model'), + input_spec=pruned_input_spec) + logger.info("Export model and saved in {}".format(save_dir)) + + def _eval_with_loader(self, loader, mode="teacher"): + sample_num = 0 + tic = time.time() + self._compose_callback.on_epoch_begin(self.status) + self.status['mode'] = 'eval' + # self.model.eval() + self.model.modelTeacher.eval() + self.model.modelStudent.eval() + for step_id, data in enumerate(loader): + self.status['step_id'] = step_id + self._compose_callback.on_step_begin(self.status) + if mode == "teacher": + outs = self.model.modelTeacher(data) + else: + outs = self.model.modelStudent(data) + + # update metrics + for metric in self._metrics: + metric.update(data, outs) + + sample_num += data['im_id'].numpy().shape[0] + self._compose_callback.on_step_end(self.status) + + self.status['sample_num'] = sample_num + self.status['cost_time'] = time.time() - tic + # accumulate metric to log out for metric in self._metrics: metric.accumulate() @@ -473,3 +830,29 @@ class Trainer_DenseTeacher(Trainer): self._compose_callback.on_epoch_end(self.status) # reset metric states for metric may performed multiple times self._reset_metrics() + + def evaluate(self): + with paddle.no_grad(): + self._eval_with_loader(self.loader) + + @paddle.no_grad() + def _update_teacher_model(self, keep_rate=0.996): + student_model_dict = copy.deepcopy(self.model.modelStudent.state_dict()) + new_teacher_dict = dict() + for key, value in self.model.modelTeacher.state_dict().items(): + if key in student_model_dict.keys(): + v = student_model_dict[key] * (1 - keep_rate + ) + value * keep_rate + v.stop_gradient = True + new_teacher_dict[key] = v + else: + raise Exception("{} is not found in student model".format(key)) + + self.model.modelTeacher.set_dict(new_teacher_dict) + + +class EnsembleTSModel(nn.Layer): + def __init__(self, modelTeacher, modelStudent): + super(EnsembleTSModel, self).__init__() + self.modelTeacher = modelTeacher + self.modelStudent = modelStudent diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index 4c6c5ed0a..eb5ff75c2 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -74,4 +74,4 @@ from .yolof import * from .pose3d_metro import * from .centertrack import * from .queryinst import * -from .keypoint_petr import * +from .keypoint_petr import * \ No newline at end of file diff --git a/ppdet/modeling/architectures/fcos.py b/ppdet/modeling/architectures/fcos.py index efebb6efb..8c338cabf 100644 --- a/ppdet/modeling/architectures/fcos.py +++ b/ppdet/modeling/architectures/fcos.py @@ -16,10 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import paddle from ppdet.core.workspace import register, create from .meta_arch import BaseArch -__all__ = ['FCOS'] +__all__ = ['FCOS', 'ARSL_FCOS'] @register @@ -31,7 +32,7 @@ class FCOS(BaseArch): backbone (object): backbone instance neck (object): 'FPN' instance fcos_head (object): 'FCOSHead' instance - ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det(ssod) + ssod_loss (object): 'SSODFCOSLoss' instance, only used for semi-det(ssod) by DenseTeacher """ __category__ = 'architecture' @@ -94,3 +95,128 @@ class FCOS(BaseArch): ssod_losses = self.ssod_loss(student_head_outs, teacher_head_outs, train_cfg) return ssod_losses + + +@register +class ARSL_FCOS(BaseArch): + """ + FCOS ARSL network, see https://arxiv.org/abs/ + + Args: + backbone (object): backbone instance + neck (object): 'FPN' instance + fcos_head (object): 'FCOSHead_ARSL' instance + fcos_cr_loss (object): 'FCOSLossCR' instance, only used for semi-det(ssod) by ARSL + """ + + __category__ = 'architecture' + __inject__ = ['fcos_cr_loss'] + + def __init__(self, + backbone, + neck, + fcos_head='FCOSHead_ARSL', + fcos_cr_loss='FCOSLossCR'): + super(ARSL_FCOS, self).__init__() + self.backbone = backbone + self.neck = neck + self.fcos_head = fcos_head + self.fcos_cr_loss = fcos_cr_loss + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + kwargs = {'input_shape': neck.out_shape} + fcos_head = create(cfg['fcos_head'], **kwargs) + + # consistency regularization loss + fcos_cr_loss = create(cfg['fcos_cr_loss']) + + return { + 'backbone': backbone, + 'neck': neck, + 'fcos_head': fcos_head, + 'fcos_cr_loss': fcos_cr_loss, + } + + def forward(self, inputs, branch="supervised", teacher_prediction=None): + assert branch in ['supervised', 'semi_supervised'], \ + print('In ARSL, type must be supervised or semi_supervised.') + + if self.data_format == 'NHWC': + image = inputs['image'] + inputs['image'] = paddle.transpose(image, [0, 2, 3, 1]) + self.inputs = inputs + + if self.training: + if branch == "supervised": + out = self.get_loss() + else: + out = self.get_pseudo_loss(teacher_prediction) + else: + # norm test + if branch == "supervised": + out = self.get_pred() + # predict pseudo labels + else: + out = self.get_pseudo_pred() + return out + + # model forward + def model_forward(self): + body_feats = self.backbone(self.inputs) + fpn_feats = self.neck(body_feats) + fcos_head_outs = self.fcos_head(fpn_feats) + return fcos_head_outs + + # supervised loss for labeled data + def get_loss(self): + loss = {} + tag_labels, tag_bboxes, tag_centerness = [], [], [] + for i in range(len(self.fcos_head.fpn_stride)): + # labels, reg_target, centerness + k_lbl = 'labels{}'.format(i) + if k_lbl in self.inputs: + tag_labels.append(self.inputs[k_lbl]) + k_box = 'reg_target{}'.format(i) + if k_box in self.inputs: + tag_bboxes.append(self.inputs[k_box]) + k_ctn = 'centerness{}'.format(i) + if k_ctn in self.inputs: + tag_centerness.append(self.inputs[k_ctn]) + fcos_head_outs = self.model_forward() + loss_fcos = self.fcos_head.get_loss(fcos_head_outs, tag_labels, + tag_bboxes, tag_centerness) + loss.update(loss_fcos) + return loss + + # unsupervised loss for unlabeled data + def get_pseudo_loss(self, teacher_prediction): + loss = {} + fcos_head_outs = self.model_forward() + unsup_loss = self.fcos_cr_loss(fcos_head_outs, teacher_prediction) + for k in unsup_loss.keys(): + loss[k + '_pseudo'] = unsup_loss[k] + return loss + + # get detection results for test, decode and rescale the results to original size + def get_pred(self): + fcos_head_outs = self.model_forward() + scale_factor = self.inputs['scale_factor'] + bbox_pred, bbox_num = self.fcos_head.post_process(fcos_head_outs, + scale_factor) + output = {'bbox': bbox_pred, 'bbox_num': bbox_num} + return output + + # generate pseudo labels to guide student + def get_pseudo_pred(self): + fcos_head_outs = self.model_forward() + pred_cls, pred_loc, pred_iou = fcos_head_outs[1:] # 0 is locations + for lvl, _ in enumerate(pred_loc): + pred_loc[lvl] = pred_loc[lvl] / self.fcos_head.fpn_stride[lvl] + + return [pred_cls, pred_loc, pred_iou, self.fcos_head.fpn_stride] diff --git a/ppdet/modeling/heads/fcos_head.py b/ppdet/modeling/heads/fcos_head.py index d6dab8c8d..89c933fe5 100644 --- a/ppdet/modeling/heads/fcos_head.py +++ b/ppdet/modeling/heads/fcos_head.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register from ppdet.modeling.layers import ConvNormLayer, MultiClassNMS -__all__ = ['FCOSFeat', 'FCOSHead'] +__all__ = ['FCOSFeat', 'FCOSHead', 'FCOSHead_ARSL'] class ScaleReg(nn.Layer): @@ -263,10 +263,23 @@ class FCOSHead(nn.Layer): centerness_list.append(centerness) if targets is not None: - self.is_teacher = targets.get('is_teacher', False) + self.is_teacher = targets.get('ARSL_teacher', False) if self.is_teacher: return [cls_logits_list, bboxes_reg_list, centerness_list] + if targets is not None: + self.is_student = targets.get('ARSL_student', False) + if self.is_student: + return [cls_logits_list, bboxes_reg_list, centerness_list] + + if targets is not None: + self.is_teacher = targets.get('is_teacher', False) + if self.is_teacher: + return [ + locations_list, cls_logits_list, bboxes_reg_list, + centerness_list + ] + if self.training and targets is not None: get_data = targets.get('get_data', False) if get_data: @@ -361,3 +374,139 @@ class FCOSHead(nn.Layer): pred_scores = pred_scores.transpose([0, 2, 1]) bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) return bbox_pred, bbox_num + + +@register +class FCOSHead_ARSL(FCOSHead): + """ + FCOSHead of ARSL for semi-det(ssod) + Args: + fcos_feat (object): Instance of 'FCOSFeat' + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + fcos_loss (object): Instance of 'FCOSLoss' + norm_reg_targets (bool): Normalization the regression target if true + centerness_on_reg (bool): The prediction of centerness on regression or clssification branch + nms (object): Instance of 'MultiClassNMS' + trt (bool): Whether to use trt in nms of deploy + """ + __inject__ = ['fcos_feat', 'fcos_loss', 'nms'] + __shared__ = ['num_classes', 'trt'] + + def __init__(self, + num_classes=80, + fcos_feat='FCOSFeat', + fpn_stride=[8, 16, 32, 64, 128], + prior_prob=0.01, + multiply_strides_reg_targets=False, + norm_reg_targets=True, + centerness_on_reg=True, + num_shift=0.5, + sqrt_score=False, + fcos_loss='FCOSLossMILC', + nms='MultiClassNMS', + trt=False): + super(FCOSHead_ARSL, self).__init__() + self.fcos_feat = fcos_feat + self.num_classes = num_classes + self.fpn_stride = fpn_stride + self.prior_prob = prior_prob + self.fcos_loss = fcos_loss + self.norm_reg_targets = norm_reg_targets + self.centerness_on_reg = centerness_on_reg + self.multiply_strides_reg_targets = multiply_strides_reg_targets + self.num_shift = num_shift + self.nms = nms + if isinstance(self.nms, MultiClassNMS) and trt: + self.nms.trt = trt + self.sqrt_score = sqrt_score + + conv_cls_name = "fcos_head_cls" + bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) + self.fcos_head_cls = self.add_sublayer( + conv_cls_name, + nn.Conv2D( + in_channels=256, + out_channels=self.num_classes, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr( + initializer=Constant(value=bias_init_value)))) + + conv_reg_name = "fcos_head_reg" + self.fcos_head_reg = self.add_sublayer( + conv_reg_name, + nn.Conv2D( + in_channels=256, + out_channels=4, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + + conv_centerness_name = "fcos_head_centerness" + self.fcos_head_centerness = self.add_sublayer( + conv_centerness_name, + nn.Conv2D( + in_channels=256, + out_channels=1, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(initializer=Normal( + mean=0., std=0.01)), + bias_attr=ParamAttr(initializer=Constant(value=0)))) + + self.scales_regs = [] + for i in range(len(self.fpn_stride)): + lvl = int(math.log(int(self.fpn_stride[i]), 2)) + feat_name = 'p{}_feat'.format(lvl) + scale_reg = self.add_sublayer(feat_name, ScaleReg()) + self.scales_regs.append(scale_reg) + + def forward(self, fpn_feats, targets=None): + assert len(fpn_feats) == len( + self.fpn_stride + ), "The size of fpn_feats is not equal to size of fpn_stride" + cls_logits_list = [] + bboxes_reg_list = [] + centerness_list = [] + for scale_reg, fpn_stride, fpn_feat in zip(self.scales_regs, + self.fpn_stride, fpn_feats): + fcos_cls_feat, fcos_reg_feat = self.fcos_feat(fpn_feat) + cls_logits = self.fcos_head_cls(fcos_cls_feat) + bbox_reg = scale_reg(self.fcos_head_reg(fcos_reg_feat)) + if self.centerness_on_reg: + centerness = self.fcos_head_centerness(fcos_reg_feat) + else: + centerness = self.fcos_head_centerness(fcos_cls_feat) + if self.norm_reg_targets: + bbox_reg = F.relu(bbox_reg) + if not self.training: + bbox_reg = bbox_reg * fpn_stride + else: + bbox_reg = paddle.exp(bbox_reg) + cls_logits_list.append(cls_logits) + bboxes_reg_list.append(bbox_reg) + centerness_list.append(centerness) + + if not self.training: + locations_list = [] + for fpn_stride, feature in zip(self.fpn_stride, fpn_feats): + location = self._compute_locations_by_level(fpn_stride, feature) + locations_list.append(location) + + return locations_list, cls_logits_list, bboxes_reg_list, centerness_list + else: + return cls_logits_list, bboxes_reg_list, centerness_list + + def get_loss(self, fcos_head_outs, tag_labels, tag_bboxes, tag_centerness): + cls_logits, bboxes_reg, centerness = fcos_head_outs + return self.fcos_loss(cls_logits, bboxes_reg, centerness, tag_labels, + tag_bboxes, tag_centerness) diff --git a/ppdet/modeling/losses/fcos_loss.py b/ppdet/modeling/losses/fcos_loss.py index b3eac7b4e..e9bbc27aa 100644 --- a/ppdet/modeling/losses/fcos_loss.py +++ b/ppdet/modeling/losses/fcos_loss.py @@ -21,8 +21,9 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register from ppdet.modeling import ops +from functools import partial -__all__ = ['FCOSLoss'] +__all__ = ['FCOSLoss', 'FCOSLossMILC', 'FCOSLossCR'] def flatten_tensor(inputs, channel_first=False): @@ -261,3 +262,759 @@ class FCOSLoss(nn.Layer): "loss_quality": paddle.sum(quality_loss), } return loss_all + + +@register +class FCOSLossMILC(FCOSLoss): + """ + FCOSLossMILC for ARSL in semi-det(ssod) + Args: + loss_alpha (float): alpha in focal loss + loss_gamma (float): gamma in focal loss + iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU + reg_weights (float): weight for location loss + """ + + def __init__(self, + loss_alpha=0.25, + loss_gamma=2.0, + iou_loss_type="giou", + reg_weights=1.0): + super(FCOSLossMILC, self).__init__() + self.loss_alpha = loss_alpha + self.loss_gamma = loss_gamma + self.iou_loss_type = iou_loss_type + self.reg_weights = reg_weights + + def iou_loss(self, pred, targets, weights=None, avg_factor=None): + """ + Calculate the loss for location prediction + Args: + pred (Tensor): bounding boxes prediction + targets (Tensor): targets for positive samples + weights (Tensor): weights for each positive samples + Return: + loss (Tensor): location loss + """ + plw = pred[:, 0] + pth = pred[:, 1] + prw = pred[:, 2] + pbh = pred[:, 3] + + tlw = targets[:, 0] + tth = targets[:, 1] + trw = targets[:, 2] + tbh = targets[:, 3] + tlw.stop_gradient = True + trw.stop_gradient = True + tth.stop_gradient = True + tbh.stop_gradient = True + + ilw = paddle.minimum(plw, tlw) + irw = paddle.minimum(prw, trw) + ith = paddle.minimum(pth, tth) + ibh = paddle.minimum(pbh, tbh) + + clw = paddle.maximum(plw, tlw) + crw = paddle.maximum(prw, trw) + cth = paddle.maximum(pth, tth) + cbh = paddle.maximum(pbh, tbh) + + area_predict = (plw + prw) * (pth + pbh) + area_target = (tlw + trw) * (tth + tbh) + area_inter = (ilw + irw) * (ith + ibh) + ious = (area_inter + 1.0) / ( + area_predict + area_target - area_inter + 1.0) + ious = ious + + if self.iou_loss_type.lower() == "linear_iou": + loss = 1.0 - ious + elif self.iou_loss_type.lower() == "giou": + area_uniou = area_predict + area_target - area_inter + area_circum = (clw + crw) * (cth + cbh) + 1e-7 + giou = ious - (area_circum - area_uniou) / area_circum + loss = 1.0 - giou + elif self.iou_loss_type.lower() == "iou": + loss = 0.0 - paddle.log(ious) + else: + raise KeyError + if weights is not None: + loss = loss * weights + loss = paddle.sum(loss) + if avg_factor is not None: + loss = loss / avg_factor + return loss + + # temp function: calcualate iou between bbox and target + def _bbox_overlap_align(self, pred, targets): + assert pred.shape[0] == targets.shape[0], \ + 'the pred should be aligned with target.' + + plw = pred[:, 0] + pth = pred[:, 1] + prw = pred[:, 2] + pbh = pred[:, 3] + + tlw = targets[:, 0] + tth = targets[:, 1] + trw = targets[:, 2] + tbh = targets[:, 3] + + ilw = paddle.minimum(plw, tlw) + irw = paddle.minimum(prw, trw) + ith = paddle.minimum(pth, tth) + ibh = paddle.minimum(pbh, tbh) + + area_predict = (plw + prw) * (pth + pbh) + area_target = (tlw + trw) * (tth + tbh) + area_inter = (ilw + irw) * (ith + ibh) + ious = (area_inter + 1.0) / ( + area_predict + area_target - area_inter + 1.0) + + return ious + + def iou_based_soft_label_loss(self, + pred, + target, + alpha=0.75, + gamma=2.0, + iou_weighted=False, + implicit_iou=None, + avg_factor=None): + assert pred.shape == target.shape + pred = F.sigmoid(pred) + target = target.cast(pred.dtype) + + if implicit_iou is not None: + pred = pred * implicit_iou + + if iou_weighted: + focal_weight = (pred - target).abs().pow(gamma) * target * (target > 0.0).cast('float32') + \ + alpha * (pred - target).abs().pow(gamma) * \ + (target <= 0.0).cast('float32') + else: + focal_weight = (pred - target).abs().pow(gamma) * (target > 0.0).cast('float32') + \ + alpha * (pred - target).abs().pow(gamma) * \ + (target <= 0.0).cast('float32') + + # focal loss + loss = F.binary_cross_entropy( + pred, target, reduction='none') * focal_weight + if avg_factor is not None: + loss = loss / avg_factor + return loss + + def forward(self, cls_logits, bboxes_reg, centerness, tag_labels, + tag_bboxes, tag_center): + """ + Calculate the loss for classification, location and centerness + Args: + cls_logits (list): list of Tensor, which is predicted + score for all anchor points with shape [N, M, C] + bboxes_reg (list): list of Tensor, which is predicted + offsets for all anchor points with shape [N, M, 4] + centerness (list): list of Tensor, which is predicted + centerness for all anchor points with shape [N, M, 1] + tag_labels (list): list of Tensor, which is category + targets for each anchor point + tag_bboxes (list): list of Tensor, which is bounding + boxes targets for positive samples + tag_center (list): list of Tensor, which is centerness + targets for positive samples + Return: + loss (dict): loss composed by classification loss, bounding box + """ + cls_logits_flatten_list = [] + bboxes_reg_flatten_list = [] + centerness_flatten_list = [] + tag_labels_flatten_list = [] + tag_bboxes_flatten_list = [] + tag_center_flatten_list = [] + num_lvl = len(cls_logits) + for lvl in range(num_lvl): + cls_logits_flatten_list.append( + flatten_tensor(cls_logits[lvl], True)) + bboxes_reg_flatten_list.append( + flatten_tensor(bboxes_reg[lvl], True)) + centerness_flatten_list.append( + flatten_tensor(centerness[lvl], True)) + + tag_labels_flatten_list.append( + flatten_tensor(tag_labels[lvl], False)) + tag_bboxes_flatten_list.append( + flatten_tensor(tag_bboxes[lvl], False)) + tag_center_flatten_list.append( + flatten_tensor(tag_center[lvl], False)) + + cls_logits_flatten = paddle.concat(cls_logits_flatten_list, axis=0) + bboxes_reg_flatten = paddle.concat(bboxes_reg_flatten_list, axis=0) + centerness_flatten = paddle.concat(centerness_flatten_list, axis=0) + + tag_labels_flatten = paddle.concat(tag_labels_flatten_list, axis=0) + tag_bboxes_flatten = paddle.concat(tag_bboxes_flatten_list, axis=0) + tag_center_flatten = paddle.concat(tag_center_flatten_list, axis=0) + tag_labels_flatten.stop_gradient = True + tag_bboxes_flatten.stop_gradient = True + tag_center_flatten.stop_gradient = True + + # find positive index + mask_positive_bool = tag_labels_flatten > 0 + mask_positive_bool.stop_gradient = True + mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32") + mask_positive_float.stop_gradient = True + + num_positive_fp32 = paddle.sum(mask_positive_float) + num_positive_fp32.stop_gradient = True + num_positive_int32 = paddle.cast(num_positive_fp32, dtype="int32") + num_positive_int32 = num_positive_int32 * 0 + 1 + num_positive_int32.stop_gradient = True + + # centerness target is used as reg weight + normalize_sum = paddle.sum(tag_center_flatten * mask_positive_float) + normalize_sum.stop_gradient = True + + # 1. IoU-Based soft label loss + # calculate iou + with paddle.no_grad(): + pos_ind = paddle.nonzero( + tag_labels_flatten.reshape([-1]) > 0).reshape([-1]) + pos_pred = bboxes_reg_flatten[pos_ind] + pos_target = tag_bboxes_flatten[pos_ind] + bbox_iou = self._bbox_overlap_align(pos_pred, pos_target) + # pos labels + pos_labels = tag_labels_flatten[pos_ind].squeeze(1) + cls_target = paddle.zeros(cls_logits_flatten.shape) + cls_target[pos_ind, pos_labels - 1] = bbox_iou + cls_loss = self.iou_based_soft_label_loss( + cls_logits_flatten, + cls_target, + implicit_iou=F.sigmoid(centerness_flatten), + avg_factor=num_positive_fp32) + + # 2. bboxes_reg: giou_loss + mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1) + tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1) + reg_loss = self._iou_loss( + bboxes_reg_flatten, + tag_bboxes_flatten, + mask_positive_float, + weights=tag_center_flatten) + reg_loss = reg_loss * mask_positive_float / normalize_sum + + # 3. iou loss + pos_iou_pred = paddle.squeeze(centerness_flatten, axis=-1)[pos_ind] + loss_iou = ops.sigmoid_cross_entropy_with_logits(pos_iou_pred, bbox_iou) + loss_iou = loss_iou / num_positive_fp32 * 0.5 + + loss_all = { + "loss_cls": paddle.sum(cls_loss), + "loss_box": paddle.sum(reg_loss), + 'loss_iou': paddle.sum(loss_iou), + } + + return loss_all + + +# Concat multi-level feature maps by image +def levels_to_images(mlvl_tensor): + batch_size = mlvl_tensor[0].shape[0] + batch_list = [[] for _ in range(batch_size)] + channels = mlvl_tensor[0].shape[1] + for t in mlvl_tensor: + t = t.transpose([0, 2, 3, 1]) + t = t.reshape([batch_size, -1, channels]) + for img in range(batch_size): + batch_list[img].append(t[img]) + return [paddle.concat(item, axis=0) for item in batch_list] + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +@register +class FCOSLossCR(FCOSLossMILC): + """ + FCOSLoss of Consistency Regularization + """ + + def __init__(self, + iou_loss_type="giou", + cls_weight=2.0, + reg_weight=2.0, + iou_weight=0.5, + hard_neg_mining_flag=True): + super(FCOSLossCR, self).__init__() + self.iou_loss_type = iou_loss_type + self.cls_weight = cls_weight + self.reg_weight = reg_weight + self.iou_weight = iou_weight + self.hard_neg_mining_flag = hard_neg_mining_flag + + def iou_loss(self, pred, targets, weights=None, avg_factor=None): + """ + Calculate the loss for location prediction + Args: + pred (Tensor): bounding boxes prediction + targets (Tensor): targets for positive samples + weights (Tensor): weights for each positive samples + Return: + loss (Tensor): location loss + """ + plw = pred[:, 0] + pth = pred[:, 1] + prw = pred[:, 2] + pbh = pred[:, 3] + + tlw = targets[:, 0] + tth = targets[:, 1] + trw = targets[:, 2] + tbh = targets[:, 3] + tlw.stop_gradient = True + trw.stop_gradient = True + tth.stop_gradient = True + tbh.stop_gradient = True + + ilw = paddle.minimum(plw, tlw) + irw = paddle.minimum(prw, trw) + ith = paddle.minimum(pth, tth) + ibh = paddle.minimum(pbh, tbh) + + clw = paddle.maximum(plw, tlw) + crw = paddle.maximum(prw, trw) + cth = paddle.maximum(pth, tth) + cbh = paddle.maximum(pbh, tbh) + + area_predict = (plw + prw) * (pth + pbh) + area_target = (tlw + trw) * (tth + tbh) + area_inter = (ilw + irw) * (ith + ibh) + ious = (area_inter + 1.0) / ( + area_predict + area_target - area_inter + 1.0) + ious = ious + + if self.iou_loss_type.lower() == "linear_iou": + loss = 1.0 - ious + elif self.iou_loss_type.lower() == "giou": + area_uniou = area_predict + area_target - area_inter + area_circum = (clw + crw) * (cth + cbh) + 1e-7 + giou = ious - (area_circum - area_uniou) / area_circum + loss = 1.0 - giou + elif self.iou_loss_type.lower() == "iou": + loss = 0.0 - paddle.log(ious) + else: + raise KeyError + if weights is not None: + loss = loss * weights + loss = paddle.sum(loss) + if avg_factor is not None: + loss = loss / avg_factor + return loss + + # calcualate iou between bbox and target + def bbox_overlap_align(self, pred, targets): + assert pred.shape[0] == targets.shape[0], \ + 'the pred should be aligned with target.' + + plw = pred[:, 0] + pth = pred[:, 1] + prw = pred[:, 2] + pbh = pred[:, 3] + + tlw = targets[:, 0] + tth = targets[:, 1] + trw = targets[:, 2] + tbh = targets[:, 3] + + ilw = paddle.minimum(plw, tlw) + irw = paddle.minimum(prw, trw) + ith = paddle.minimum(pth, tth) + ibh = paddle.minimum(pbh, tbh) + + area_predict = (plw + prw) * (pth + pbh) + area_target = (tlw + trw) * (tth + tbh) + area_inter = (ilw + irw) * (ith + ibh) + ious = (area_inter + 1.0) / ( + area_predict + area_target - area_inter + 1.0) + return ious + + # cls loss: iou-based soft lable with joint iou + def quality_focal_loss(self, + stu_cls, + targets, + quality=None, + weights=None, + alpha=0.75, + gamma=2.0, + avg_factor='sum'): + stu_cls = F.sigmoid(stu_cls) + if quality is not None: + stu_cls = stu_cls * F.sigmoid(quality) + + focal_weight = (stu_cls - targets).abs().pow(gamma) * (targets > 0.0).cast('float32') + \ + alpha * (stu_cls - targets).abs().pow(gamma) * \ + (targets <= 0.0).cast('float32') + + loss = F.binary_cross_entropy( + stu_cls, targets, reduction='none') * focal_weight + + if weights is not None: + loss = loss * weights.reshape([-1, 1]) + loss = paddle.sum(loss) + if avg_factor is not None: + loss = loss / avg_factor + return loss + + # generate points according to feature maps + def compute_locations_by_level(self, fpn_stride, h, w): + """ + Compute locations of anchor points of each FPN layer + Return: + Anchor points locations of current FPN feature map + """ + shift_x = paddle.arange(0, w * fpn_stride, fpn_stride) + shift_y = paddle.arange(0, h * fpn_stride, fpn_stride) + shift_x = paddle.unsqueeze(shift_x, axis=0) + shift_y = paddle.unsqueeze(shift_y, axis=1) + shift_x = paddle.expand(shift_x, shape=[h, w]) + shift_y = paddle.expand(shift_y, shape=[h, w]) + shift_x = paddle.reshape(shift_x, shape=[-1]) + shift_y = paddle.reshape(shift_y, shape=[-1]) + location = paddle.stack( + [shift_x, shift_y], axis=-1) + float(fpn_stride) / 2 + return location + + # decode bbox from ltrb to x1y1x2y2 + def decode_bbox(self, ltrb, points): + assert ltrb.shape[0] == points.shape[0], \ + "When decoding bbox in one image, the num of loc should be same with points." + bbox_decoding = paddle.stack( + [ + points[:, 0] - ltrb[:, 0], points[:, 1] - ltrb[:, 1], + points[:, 0] + ltrb[:, 2], points[:, 1] + ltrb[:, 3] + ], + axis=1) + return bbox_decoding + + # encode bbox from x1y1x2y2 to ltrb + def encode_bbox(self, bbox, points): + assert bbox.shape[0] == points.shape[0], \ + "When encoding bbox in one image, the num of bbox should be same with points." + bbox_encoding = paddle.stack( + [ + points[:, 0] - bbox[:, 0], points[:, 1] - bbox[:, 1], + bbox[:, 2] - points[:, 0], bbox[:, 3] - points[:, 1] + ], + axis=1) + return bbox_encoding + + def calcualate_iou(self, gt_bbox, predict_bbox): + # bbox area + gt_area = (gt_bbox[:, 2] - gt_bbox[:, 0]) * \ + (gt_bbox[:, 3] - gt_bbox[:, 1]) + predict_area = (predict_bbox[:, 2] - predict_bbox[:, 0]) * \ + (predict_bbox[:, 3] - predict_bbox[:, 1]) + # overlop area + lt = paddle.fmax(gt_bbox[:, None, :2], predict_bbox[None, :, :2]) + rb = paddle.fmin(gt_bbox[:, None, 2:], predict_bbox[None, :, 2:]) + wh = paddle.clip(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + # iou + iou = overlap / (gt_area[:, None] + predict_area[None, :] - overlap) + return iou + + # select potential positives from hard negatives + def hard_neg_mining(self, + cls_score, + loc_ltrb, + quality, + pos_ind, + hard_neg_ind, + loc_mask, + loc_targets, + iou_thresh=0.6): + # get points locations and strides + points_list = [] + strides_list = [] + scale_list = [] + scale = [0, 1, 2, 3, 4] + for fpn_scale, fpn_stride, HW in zip(scale, self.fpn_stride, + self.lvl_hw): + h, w = HW + lvl_points = self.compute_locations_by_level(fpn_stride, h, w) + points_list.append(lvl_points) + lvl_strides = paddle.full([h * w, 1], fpn_stride) + strides_list.append(lvl_strides) + lvl_scales = paddle.full([h * w, 1], fpn_scale) + scale_list.append(lvl_scales) + points = paddle.concat(points_list, axis=0) + strides = paddle.concat(strides_list, axis=0) + scales = paddle.concat(scale_list, axis=0) + + # cls scores + cls_vals = F.sigmoid(cls_score) * F.sigmoid(quality) + max_vals = paddle.max(cls_vals, axis=-1) + class_ind = paddle.argmax(cls_vals, axis=-1) + + ### calculate iou between positive and hard negative + # decode pos bbox + pos_cls = max_vals[pos_ind] + pos_loc = loc_ltrb[pos_ind].reshape([-1, 4]) + pos_strides = strides[pos_ind] + pos_points = points[pos_ind].reshape([-1, 2]) + pos_loc = pos_loc * pos_strides + pos_bbox = self.decode_bbox(pos_loc, pos_points) + pos_scales = scales[pos_ind] + # decode hard negative bbox + hard_neg_loc = loc_ltrb[hard_neg_ind].reshape([-1, 4]) + hard_neg_strides = strides[hard_neg_ind] + hard_neg_points = points[hard_neg_ind].reshape([-1, 2]) + hard_neg_loc = hard_neg_loc * hard_neg_strides + hard_neg_bbox = self.decode_bbox(hard_neg_loc, hard_neg_points) + hard_neg_scales = scales[hard_neg_ind] + # iou between pos bbox and hard negative bbox + hard_neg_pos_iou = self.calcualate_iou(hard_neg_bbox, pos_bbox) + + ### select potential positives from hard negatives + # scale flag + scale_temp = paddle.abs( + pos_scales.reshape([-1])[None, :] - hard_neg_scales.reshape([-1]) + [:, None]) + scale_flag = (scale_temp <= 1.) + # iou flag + iou_flag = (hard_neg_pos_iou >= iou_thresh) + # same class flag + pos_class = class_ind[pos_ind] + hard_neg_class = class_ind[hard_neg_ind] + class_flag = pos_class[None, :] - hard_neg_class[:, None] + class_flag = (class_flag == 0) + # hard negative point inside positive bbox flag + ltrb_temp = paddle.stack( + [ + hard_neg_points[:, None, 0] - pos_bbox[None, :, 0], + hard_neg_points[:, None, 1] - pos_bbox[None, :, 1], + pos_bbox[None, :, 2] - hard_neg_points[:, None, 0], + pos_bbox[None, :, 3] - hard_neg_points[:, None, 1] + ], + axis=-1) + inside_flag = ltrb_temp.min(axis=-1) > 0 + # reset iou + valid_flag = (iou_flag & class_flag & inside_flag & scale_flag) + invalid_iou = paddle.zeros_like(hard_neg_pos_iou) + hard_neg_pos_iou = paddle.where(valid_flag, hard_neg_pos_iou, + invalid_iou) + pos_hard_neg_max_iou = hard_neg_pos_iou.max(axis=-1) + # selece potential pos + potential_pos_ind = (pos_hard_neg_max_iou > 0.) + num_potential_pos = paddle.nonzero(potential_pos_ind).shape[0] + if num_potential_pos == 0: + return None + + ### calculate loc target:aggregate all matching bboxes as the bbox targets of potential pos + # prepare data + potential_points = hard_neg_points[potential_pos_ind].reshape([-1, 2]) + potential_strides = hard_neg_strides[potential_pos_ind] + potential_valid_flag = valid_flag[potential_pos_ind] + potential_pos_ind = hard_neg_ind[potential_pos_ind] + + # get cls and box of matching positives + pos_cls = max_vals[pos_ind] + expand_pos_bbox = paddle.expand( + pos_bbox, + shape=[num_potential_pos, pos_bbox.shape[0], pos_bbox.shape[1]]) + expand_pos_cls = paddle.expand( + pos_cls, shape=[num_potential_pos, pos_cls.shape[0]]) + invalid_cls = paddle.zeros_like(expand_pos_cls) + expand_pos_cls = paddle.where(potential_valid_flag, expand_pos_cls, + invalid_cls) + expand_pos_cls = paddle.unsqueeze(expand_pos_cls, axis=-1) + # aggregate box based on cls_score + agg_bbox = (expand_pos_bbox * expand_pos_cls).sum(axis=1) \ + / expand_pos_cls.sum(axis=1) + agg_ltrb = self.encode_bbox(agg_bbox, potential_points) + agg_ltrb = agg_ltrb / potential_strides + + # loc target for all pos + loc_targets[potential_pos_ind] = agg_ltrb + loc_mask[potential_pos_ind] = 1. + + return loc_mask, loc_targets + + # get training targets + def get_targets_per_img(self, tea_cls, tea_loc, tea_iou, stu_cls, stu_loc, + stu_iou): + + ### sample selection + # prepare datas + tea_cls_scores = F.sigmoid(tea_cls) * F.sigmoid(tea_iou) + class_ind = paddle.argmax(tea_cls_scores, axis=-1) + max_vals = paddle.max(tea_cls_scores, axis=-1) + cls_mask = paddle.zeros_like( + max_vals + ) # set cls valid mask: pos is 1, hard_negative and negative are 0. + num_pos, num_hard_neg = 0, 0 + + # mean-std selection + # using nonzero to turn index from bool to int, because the index will be used to compose two-dim index in following. + # using squeeze rather than reshape to avoid errors when no score is larger than thresh. + candidate_ind = paddle.nonzero(max_vals >= 0.1).squeeze(axis=-1) + num_candidate = candidate_ind.shape[0] + if num_candidate > 0: + # pos thresh = mean + std to select pos samples + candidate_score = max_vals[candidate_ind] + candidate_score_mean = candidate_score.mean() + candidate_score_std = candidate_score.std() + pos_thresh = (candidate_score_mean + candidate_score_std).clip( + max=0.4) + # select pos + pos_ind = paddle.nonzero(max_vals >= pos_thresh).squeeze(axis=-1) + num_pos = pos_ind.shape[0] + # select hard negatives as potential pos + hard_neg_ind = (max_vals >= 0.1) & (max_vals < pos_thresh) + hard_neg_ind = paddle.nonzero(hard_neg_ind).squeeze(axis=-1) + num_hard_neg = hard_neg_ind.shape[0] + # if not positive, directly select top-10 as pos. + if (num_pos == 0): + num_pos = 10 + _, pos_ind = paddle.topk(max_vals, k=num_pos) + cls_mask[pos_ind] = 1. + + ### Consistency Regularization Training targets + # cls targets + pos_class_ind = class_ind[pos_ind] + cls_targets = paddle.zeros_like(tea_cls) + cls_targets[pos_ind, pos_class_ind] = tea_cls_scores[pos_ind, + pos_class_ind] + # hard negative cls target + if num_hard_neg != 0: + cls_targets[hard_neg_ind] = tea_cls_scores[hard_neg_ind] + # loc targets + loc_targets = paddle.zeros_like(tea_loc) + loc_targets[pos_ind] = tea_loc[pos_ind] + # iou targets + iou_targets = paddle.zeros( + shape=[tea_iou.shape[0]], dtype=tea_iou.dtype) + iou_targets[pos_ind] = F.sigmoid( + paddle.squeeze( + tea_iou, axis=-1)[pos_ind]) + + loc_mask = cls_mask.clone() + # select potential positive from hard negatives for loc_task training + if (num_hard_neg > 0) and self.hard_neg_mining_flag: + results = self.hard_neg_mining(tea_cls, tea_loc, tea_iou, pos_ind, + hard_neg_ind, loc_mask, loc_targets) + if results is not None: + loc_mask, loc_targets = results + loc_pos_ind = paddle.nonzero(loc_mask > 0.).squeeze(axis=-1) + iou_targets[loc_pos_ind] = F.sigmoid( + paddle.squeeze( + tea_iou, axis=-1)[loc_pos_ind]) + + return cls_mask, loc_mask, \ + cls_targets, loc_targets, iou_targets + + def forward(self, student_prediction, teacher_prediction): + stu_cls_lvl, stu_loc_lvl, stu_iou_lvl = student_prediction + tea_cls_lvl, tea_loc_lvl, tea_iou_lvl, self.fpn_stride = teacher_prediction + + # H and W of level (used for aggregating targets) + self.lvl_hw = [] + for t in tea_cls_lvl: + _, _, H, W = t.shape + self.lvl_hw.append([H, W]) + + # levels to images + stu_cls_img = levels_to_images(stu_cls_lvl) + stu_loc_img = levels_to_images(stu_loc_lvl) + stu_iou_img = levels_to_images(stu_iou_lvl) + tea_cls_img = levels_to_images(tea_cls_lvl) + tea_loc_img = levels_to_images(tea_loc_lvl) + tea_iou_img = levels_to_images(tea_iou_lvl) + + with paddle.no_grad(): + cls_mask, loc_mask, \ + cls_targets, loc_targets, iou_targets = multi_apply( + self.get_targets_per_img, + tea_cls_img, + tea_loc_img, + tea_iou_img, + stu_cls_img, + stu_loc_img, + stu_iou_img + ) + + # flatten preditction + stu_cls = paddle.concat(stu_cls_img, axis=0) + stu_loc = paddle.concat(stu_loc_img, axis=0) + stu_iou = paddle.concat(stu_iou_img, axis=0) + # flatten targets + cls_mask = paddle.concat(cls_mask, axis=0) + loc_mask = paddle.concat(loc_mask, axis=0) + cls_targets = paddle.concat(cls_targets, axis=0) + loc_targets = paddle.concat(loc_targets, axis=0) + iou_targets = paddle.concat(iou_targets, axis=0) + + ### Training Weights and avg factor + # find positives + cls_pos_ind = paddle.nonzero(cls_mask > 0.).squeeze(axis=-1) + loc_pos_ind = paddle.nonzero(loc_mask > 0.).squeeze(axis=-1) + # cls weight + cls_sample_weights = paddle.ones([cls_targets.shape[0]]) + cls_avg_factor = paddle.max(cls_targets[cls_pos_ind], + axis=-1).sum().item() + # loc weight + loc_sample_weights = paddle.max(cls_targets[loc_pos_ind], axis=-1) + loc_avg_factor = loc_sample_weights.sum().item() + # iou weight + iou_sample_weights = paddle.ones([loc_pos_ind.shape[0]]) + iou_avg_factor = loc_pos_ind.shape[0] + + ### unsupervised loss + # cls loss + loss_cls = self.quality_focal_loss( + stu_cls, + cls_targets, + quality=stu_iou, + weights=cls_sample_weights, + avg_factor=cls_avg_factor) * self.cls_weight + # iou loss + pos_stu_iou = paddle.squeeze(stu_iou, axis=-1)[loc_pos_ind] + pos_iou_targets = iou_targets[loc_pos_ind] + loss_iou = F.binary_cross_entropy( + F.sigmoid(pos_stu_iou), pos_iou_targets, + reduction='none') * iou_sample_weights + loss_iou = loss_iou.sum() / iou_avg_factor * self.iou_weight + # box loss + pos_stu_loc = stu_loc[loc_pos_ind] + pos_loc_targets = loc_targets[loc_pos_ind] + + loss_box = self.iou_loss( + pos_stu_loc, + pos_loc_targets, + weights=loc_sample_weights, + avg_factor=loc_avg_factor) + loss_box = loss_box * self.reg_weight + + loss_all = { + "loss_cls": loss_cls, + "loss_box": loss_box, + "loss_iou": loss_iou, + } + return loss_all diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index f57ef0227..ed0433764 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -17,9 +17,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import errno import os -import time import numpy as np import paddle import paddle.nn as nn @@ -40,21 +38,6 @@ def is_url(path): or path.startswith('ppdet://') -def _get_unique_endpoints(trainer_endpoints): - # Sorting is to avoid different environmental variables for each card - trainer_endpoints.sort() - ips = set() - unique_endpoints = set() - for endpoint in trainer_endpoints: - ip = endpoint.split(":")[0] - if ip in ips: - continue - ips.add(ip) - unique_endpoints.add(endpoint) - logger.info("unique_endpoints {}".format(unique_endpoints)) - return unique_endpoints - - def _strip_postfix(path): path, ext = os.path.splitext(path) assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ @@ -92,28 +75,35 @@ def load_weight(model, weight, optimizer=None, ema=None, exchange=True): ema_state_dict = None param_state_dict = paddle.load(pdparam_path) - model_dict = model.state_dict() - model_weight = {} - incorrect_keys = 0 + if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'): + print('Loading pretrain weights for Teacher-Student framework.') + print('Loading pretrain weights for Student model.') + student_model_dict = model.modelStudent.state_dict() + student_param_state_dict = match_state_dict( + student_model_dict, param_state_dict, mode='student') + model.modelStudent.set_dict(student_param_state_dict) + print('Loading pretrain weights for Teacher model.') + teacher_model_dict = model.modelTeacher.state_dict() - for key, value in model_dict.items(): - if key in param_state_dict.keys(): - if isinstance(param_state_dict[key], np.ndarray): - param_state_dict[key] = paddle.to_tensor(param_state_dict[key]) - if value.dtype == param_state_dict[key].dtype: + teacher_param_state_dict = match_state_dict( + teacher_model_dict, param_state_dict, mode='teacher') + model.modelTeacher.set_dict(teacher_param_state_dict) + + else: + model_dict = model.state_dict() + model_weight = {} + incorrect_keys = 0 + for key in model_dict.keys(): + if key in param_state_dict.keys(): model_weight[key] = param_state_dict[key] else: - model_weight[key] = param_state_dict[key].astype(value.dtype) - else: - logger.info('Unmatched key: {}'.format(key)) - incorrect_keys += 1 - - assert incorrect_keys == 0, "Load weight {} incorrectly, \ - {} keys unmatched, please check again.".format(weight, - incorrect_keys) - logger.info('Finish resuming model weights: {}'.format(pdparam_path)) - - model.set_dict(model_weight) + logger.info('Unmatched key: {}'.format(key)) + incorrect_keys += 1 + assert incorrect_keys == 0, "Load weight {} incorrectly, \ + {} keys unmatched, please check again.".format(weight, + incorrect_keys) + logger.info('Finish resuming model weights: {}'.format(pdparam_path)) + model.set_dict(model_weight) last_epoch = 0 if optimizer is not None and os.path.exists(path + '.pdopt'): @@ -134,7 +124,7 @@ def load_weight(model, weight, optimizer=None, ema=None, exchange=True): return last_epoch -def match_state_dict(model_state_dict, weight_state_dict): +def match_state_dict(model_state_dict, weight_state_dict, mode='default'): """ Match between the model state dict and pretrained weight state dict. Return the matched state dict. @@ -152,33 +142,47 @@ def match_state_dict(model_state_dict, weight_state_dict): model_keys = sorted(model_state_dict.keys()) weight_keys = sorted(weight_state_dict.keys()) + def teacher_match(a, b): + # skip student params + if b.startswith('modelStudent'): + return False + return a == b or a.endswith("." + b) or b.endswith("." + a) + + def student_match(a, b): + # skip teacher params + if b.startswith('modelTeacher'): + return False + return a == b or a.endswith("." + b) or b.endswith("." + a) + def match(a, b): - if b.startswith('backbone.res5'): - # In Faster RCNN, res5 pretrained weights have prefix of backbone, - # however, the corresponding model weights have difficult prefix, - # bbox_head. + if a.startswith('backbone.res5'): b = b[9:] return a == b or a.endswith("." + b) + if mode == 'student': + match_op = student_match + elif mode == 'teacher': + match_op = teacher_match + else: + match_op = match + match_matrix = np.zeros([len(model_keys), len(weight_keys)]) for i, m_k in enumerate(model_keys): for j, w_k in enumerate(weight_keys): - if match(m_k, w_k): + if match_op(m_k, w_k): match_matrix[i, j] = len(w_k) max_id = match_matrix.argmax(1) max_len = match_matrix.max(1) max_id[max_len == 0] = -1 - - load_id = set(max_id) - load_id.discard(-1) not_load_weight_name = [] - for idx in range(len(weight_keys)): - if idx not in load_id: - not_load_weight_name.append(weight_keys[idx]) + for match_idx in range(len(max_id)): + if max_id[match_idx] == -1: + not_load_weight_name.append(model_keys[match_idx]) if len(not_load_weight_name) > 0: - logger.info('{} in pretrained weight is not used in the model, ' - 'and its will not be loaded'.format(not_load_weight_name)) + logger.info('{} in model is not matched with pretrained weights, ' + 'and its will be trained from scratch'.format( + not_load_weight_name)) matched_keys = {} result_state_dict = {} for model_id, weight_id in enumerate(max_id): @@ -208,7 +212,7 @@ def match_state_dict(model_state_dict, weight_state_dict): return result_state_dict -def load_pretrain_weight(model, pretrain_weight): +def load_pretrain_weight(model, pretrain_weight, ARSL_eval=False): if is_url(pretrain_weight): pretrain_weight = get_weights_path(pretrain_weight) @@ -219,21 +223,48 @@ def load_pretrain_weight(model, pretrain_weight): "If you don't want to load pretrain model, " "please delete `pretrain_weights` field in " "config file.".format(path)) + teacher_student_flag = False + if not ARSL_eval: + if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'): + print('Loading pretrain weights for Teacher-Student framework.') + print( + 'Assert Teacher model has the same structure with Student model.' + ) + model_dict = model.modelStudent.state_dict() + teacher_student_flag = True + else: + model_dict = model.state_dict() + + weights_path = path + '.pdparams' + param_state_dict = paddle.load(weights_path) + param_state_dict = match_state_dict(model_dict, param_state_dict) + for k, v in param_state_dict.items(): + if isinstance(v, np.ndarray): + v = paddle.to_tensor(v) + if model_dict[k].dtype != v.dtype: + param_state_dict[k] = v.astype(model_dict[k].dtype) + + if teacher_student_flag: + model.modelStudent.set_dict(param_state_dict) + model.modelTeacher.set_dict(param_state_dict) + else: + model.set_dict(param_state_dict) + logger.info('Finish loading model weights: {}'.format(weights_path)) - model_dict = model.state_dict() - - weights_path = path + '.pdparams' - param_state_dict = paddle.load(weights_path) - param_state_dict = match_state_dict(model_dict, param_state_dict) - - for k, v in param_state_dict.items(): - if isinstance(v, np.ndarray): - v = paddle.to_tensor(v) - if model_dict[k].dtype != v.dtype: - param_state_dict[k] = v.astype(model_dict[k].dtype) + else: + weights_path = path + '.pdparams' + param_state_dict = paddle.load(weights_path) + student_model_dict = model.modelStudent.state_dict() + student_param_state_dict = match_state_dict( + student_model_dict, param_state_dict, mode='student') + model.modelStudent.set_dict(student_param_state_dict) + print('Loading pretrain weights for Teacher model.') + teacher_model_dict = model.modelTeacher.state_dict() - model.set_dict(param_state_dict) - logger.info('Finish loading model weights: {}'.format(weights_path)) + teacher_param_state_dict = match_state_dict( + teacher_model_dict, param_state_dict, mode='teacher') + model.modelTeacher.set_dict(teacher_param_state_dict) + logger.info('Finish loading model weights: {}'.format(weights_path)) def save_model(model, @@ -256,21 +287,24 @@ def save_model(model, """ if paddle.distributed.get_rank() != 0: return - assert isinstance(model, dict), ("model is not a instance of dict, " - "please call model.state_dict() to get.") if not os.path.exists(save_dir): os.makedirs(save_dir) save_path = os.path.join(save_dir, save_name) # save model - if ema_model is None: - paddle.save(model, save_path + ".pdparams") + if isinstance(model, nn.Layer): + paddle.save(model.state_dict(), save_path + ".pdparams") else: - assert isinstance(ema_model, - dict), ("ema_model is not a instance of dict, " - "please call model.state_dict() to get.") - # Exchange model and ema_model to save - paddle.save(ema_model, save_path + ".pdparams") - paddle.save(model, save_path + ".pdema") + assert isinstance(model, + dict), 'model is not a instance of nn.layer or dict' + if ema_model is None: + paddle.save(model, save_path + ".pdparams") + else: + assert isinstance(ema_model, + dict), ("ema_model is not a instance of dict, " + "please call model.state_dict() to get.") + # Exchange model and ema_model to save + paddle.save(ema_model, save_path + ".pdparams") + paddle.save(model, save_path + ".pdema") # save optimizer state_dict = optimizer.state_dict() state_dict['last_epoch'] = last_epoch diff --git a/tools/eval.py b/tools/eval.py index 40cbbecd8..fc34686f0 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -32,7 +32,7 @@ import paddle from ppdet.core.workspace import create, load_config, merge_config from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config from ppdet.utils.cli import ArgsParser, merge_args -from ppdet.engine import Trainer, init_parallel_env +from ppdet.engine import Trainer, Trainer_ARSL, init_parallel_env from ppdet.metrics.coco_utils import json_eval_results from ppdet.slim import build_slim_model @@ -135,12 +135,17 @@ def run(FLAGS, cfg): # init parallel environment if nranks > 1 init_parallel_env() - - # build trainer - trainer = Trainer(cfg, mode='eval') - - # load weights - trainer.load_weights(cfg.weights) + ssod_method = cfg.get('ssod_method', None) + if ssod_method == 'ARSL': + # build ARSL_trainer + trainer = Trainer_ARSL(cfg, mode='eval') + # load ARSL_weights + trainer.load_weights(cfg.weights, ARSL_eval=True) + else: + # build trainer + trainer = Trainer(cfg, mode='eval') + #load weights + trainer.load_weights(cfg.weights) # training if FLAGS.slice_infer: diff --git a/tools/export_model.py b/tools/export_model.py index 20cfcfaa5..f4ffcb500 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -32,6 +32,7 @@ from ppdet.core.workspace import load_config, merge_config from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.cli import ArgsParser from ppdet.engine import Trainer +from ppdet.engine.trainer_ssod import Trainer_ARSL from ppdet.slim import build_slim_model from ppdet.utils.logger import setup_logger @@ -60,14 +61,19 @@ def parse_args(): def run(FLAGS, cfg): + ssod_method = cfg.get('ssod_method', None) + if ssod_method is not None and ssod_method == 'ARSL': + trainer = Trainer_ARSL(cfg, mode='test') + trainer.load_weights(cfg.weights, ARSL_eval=True) # build detector - trainer = Trainer(cfg, mode='test') - - # load weights - if cfg.architecture in ['DeepSORT', 'ByteTrack']: - trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights) else: - trainer.load_weights(cfg.weights) + trainer = Trainer(cfg, mode='test') + + # load weights + if cfg.architecture in ['DeepSORT', 'ByteTrack']: + trainer.load_weights_sde(cfg.det_weights, cfg.reid_weights) + else: + trainer.load_weights(cfg.weights) # export model trainer.export(FLAGS.output_dir) diff --git a/tools/infer.py b/tools/infer.py index 65fb3b725..9d99237a1 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -31,7 +31,7 @@ import ast import paddle from ppdet.core.workspace import load_config, merge_config -from ppdet.engine import Trainer +from ppdet.engine import Trainer, Trainer_ARSL from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config from ppdet.utils.cli import ArgsParser, merge_args from ppdet.slim import build_slim_model @@ -156,12 +156,13 @@ def get_test_images(infer_dir, infer_img): def run(FLAGS, cfg): - # build trainer - trainer = Trainer(cfg, mode='test') - - # load weights - trainer.load_weights(cfg.weights) - + ssod_method = cfg.get('ssod_method', None) + if ssod_method == 'ARSL': + trainer = Trainer_ARSL(cfg, mode='test') + trainer.load_weights(cfg.weights, ARSL_eval=True) + else: + trainer = Trainer(cfg, mode='test') + trainer.load_weights(cfg.weights) # get inference images images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) diff --git a/tools/train.py b/tools/train.py index ec846519e..3aa0a21a7 100755 --- a/tools/train.py +++ b/tools/train.py @@ -32,7 +32,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env -from ppdet.engine.trainer_ssod import Trainer_DenseTeacher +from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL from ppdet.slim import build_slim_model @@ -132,9 +132,11 @@ def run(FLAGS, cfg): if ssod_method is not None: if ssod_method == 'DenseTeacher': trainer = Trainer_DenseTeacher(cfg, mode='train') + elif ssod_method == 'ARSL': + trainer = Trainer_ARSL(cfg, mode='train') else: raise ValueError( - "Semi-Supervised Object Detection only support DenseTeacher now." + "Semi-Supervised Object Detection only support DenseTeacher and ARSL now." ) elif cfg.get('use_cot', False): trainer = TrainerCot(cfg, mode='train') -- GitLab