Unverified commit c72fea6a, authored by Kaipeng Deng and committed by GitHub

add ppyolo mobilenetv3 (#1443)

* add ppyolo mobilenetv3
Parent 262dfa3e
......@@ -50,6 +50,9 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_2x.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
**Notes:**
......@@ -64,14 +67,22 @@ PP-YOLO improved performance and speed of YOLOv3 with following methods:
### PP-YOLO for mobile
| Model | GPU number | images/GPU | backbone | input shape | Box AP50<sup>val</sup> | Box AP50<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:----------:|:----------:|:----------:| :----------:| :--------------------: | :---------------------: | :------------: | :---------------------: | :------: | :-----: |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 416 | 47.0 | 47.7 | 401.6 | 724.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 320 | 43.7 | 44.4 | 478.5 | 791.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| Model | GPU number | images/GPU | Model Size | input shape | Box AP<sup>val</sup> | Kirin 990(FPS) | download | config |
|:----------------------------:|:----------:|:----------:| :--------: | :----------:| :------------------: | :------------: | :------: | :-----: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) |
- PP-YOLO_r18vd is trained on the COCO train2017 dataset and evaluated on the val2017 and test-dev2017 datasets; Box AP50<sup>val</sup> is the evaluation result of `mAP(IoU=0.5)`.
- PP-YOLO_r18vd is trained on 4 GPUs with a mini-batch size of 32 per GPU; if the GPU number or mini-batch size is changed, the learning rate and number of iterations should be adjusted according to the [FAQ](../../docs/FAQ.md).
- The PP-YOLO_r18vd inference speed testing environment and configuration are the same as for PP-YOLO above.
**Notes:**
- PP-YOLO_MobileNetV3 is trained on the COCO train2017 dataset and evaluated on the val2017 dataset; Box AP<sup>val</sup> is the evaluation result of `mAP(IoU=0.5:0.95)`.
- PP-YOLO_MobileNetV3 is trained on 4 GPUs with a mini-batch size of 32 per GPU; if the GPU number or mini-batch size is changed, the learning rate and number of iterations should be adjusted according to the [FAQ](../../docs/FAQ.md); a rough scaling sketch is given below.
- PP-YOLO_MobileNetV3 inference speed is tested on a Kirin 990 chip with a single thread.
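If the GPU count or per-GPU batch size differs from the reference setting, a common adjustment is the linear scaling rule. The sketch below is illustrative only (it assumes the linear-scaling convention described in the FAQ; `scale_schedule` is a hypothetical helper, not part of PaddleDetection), using values from the MobileNetV3 configs in this commit (`base_lr: 0.00666`, `max_iters: 250000`, milestones 150000/200000):

```python
def scale_schedule(base_lr, base_gpus, base_bs_per_gpu,
                   gpus, bs_per_gpu, max_iters, milestones):
    """Hypothetical linear-scaling helper: scale the learning rate with the
    total batch size and stretch the schedule so the model still sees the
    same total number of images."""
    ratio = (gpus * bs_per_gpu) / float(base_gpus * base_bs_per_gpu)
    scaled_lr = base_lr * ratio
    scaled_iters = int(round(max_iters / ratio))
    scaled_milestones = [int(round(m / ratio)) for m in milestones]
    return scaled_lr, scaled_iters, scaled_milestones


# Reference schedule: 4 GPUs x 32 images/GPU. Training on 2 GPUs with
# 16 images/GPU would use roughly:
print(scale_schedule(0.00666, 4, 32, 2, 16, 250000, [150000, 200000]))
# (0.001665, 1000000, [600000, 800000])
```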
### Slim PP-YOLO
| Model | GPU number | images/GPU | Prune Ratio | Teacher Model | Model Size | input shape | Box AP<sup>val</sup> | Kirin 990(FPS) | download | config |
|:----------------------------:|:----------:|:----------:| :---------: | :-----------------------: | :--------: | :----------:| :------------------: | :------------: | :------: | :-----: |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 75% | PP-YOLO_MobileNetV3_large | 4.1MB | 320 | 14.4 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) |
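The slim model above is a 75%-pruned PP-YOLO_MobileNetV3_small distilled from a PP-YOLO_MobileNetV3_large teacher. As a reading aid for the `split_distill` changes later in this commit (not an additional definition), the fine-grained distillation loss sums an objectness-weighted regression term, an objectness-weighted classification term, and an objectness term over the `target_number` YOLO output scales, scaled by the distillation weight `w` (1000 in the training script):

$$
L_{\text{distill}} \;=\; w \sum_{i=1}^{T}\Big(L^{\text{reg}}_{i} + L^{\text{cls}}_{i} + L^{\text{obj}}_{i}\Big),
\qquad T = \texttt{target\_number}
$$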
## Getting Started
......
......@@ -50,6 +50,9 @@ PP-YOLO improved performance and speed of YOLOv3 with the following methods:
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 512 | 44.4 | 45.0 | 89.9 | 188.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 416 | 42.7 | 43.2 | 109.1 | 215.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_2x | 8 | 24 | ResNet50vd | 320 | 39.5 | 40.1 | 132.2 | 242.2 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 512 | 29.3 | 29.5 | 357.1 | 657.9 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 416 | 28.6 | 28.9 | 409.8 | 719.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_ResNet18vd | 4 | 32 | ResNet18vd | 320 | 26.2 | 26.4 | 480.7 | 763.4 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
**Notes:**
......@@ -63,16 +66,16 @@ PP-YOLO improved performance and speed of YOLOv3 with the following methods:
- The `download` and `config` entries in the YOLOv4(AlexyAB) row refer to the YOLOv4 model reproduced in PaddleDetection. Its evaluation accuracy is already aligned and fine-tuning is supported, while training accuracy alignment is still in progress; see the [PaddleDetection YOLOv4 model](../yolov4/README.md).
- PP-YOLO is trained with `batch_size=24` per GPU, which requires GPUs with 32 GB of memory. We also provide a `batch_size=12` config, `ppyolo_2x_bs12.yml`, which can be trained on 16 GB GPUs; training with it yields `mAP(IoU=0.5:0.95) = 45.1%` on the COCO val2017 dataset, and the weights can be downloaded from the [ppyolo_2x_bs12 model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_2x_bs12.pdparams).
### PP-YOLO mobile models
### PP-YOLO lightweight models
| Model | GPU number | images/GPU | backbone | input shape | Box AP50<sup>val</sup> | Box AP50<sup>test</sup> | V100 FP32(FPS) | V100 TensorRT FP16(FPS) | download | config |
|:------------------------:|:----------:|:----------:|:----------:| :----------:| :--------------------: | :---------------------: | :------------: | :---------------------: | :------: | :------: |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 416 | 47.0 | 47.7 | 401.6 | 724.6 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| PP-YOLO_r18vd | 4 | 32 | ResNet18vd | 320 | 43.7 | 44.4 | 478.5 | 791.3 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_r18vd.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_r18vd.yml) |
| Model | GPU number | images/GPU | Model Size | input shape | Box AP<sup>val</sup> | Kirin 990 (FPS) | download | config |
|:----------------------------:|:----------:|:----------:| :--------: | :----------:| :------------------: | :-------------: | :------: | :------: |
| PP-YOLO_MobileNetV3_large | 4 | 32 | 18MB | 320 | 22.0 | 14.1 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_large.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_large.yml) |
| PP-YOLO_MobileNetV3_small | 4 | 32 | 11MB | 320 | 16.8 | 21.5 | [model](https://paddlemodels.bj.bcebos.com/object_detection/ppyolo_mobilenet_v3_small.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/ppyolo/ppyolo_mobilenet_v3_small.yml) |
- PP-YOLO_r18vd is trained on the COCO train2017 dataset and evaluated on the val2017 and test-dev2017 datasets; Box AP50<sup>val</sup> is the `mAP(IoU=0.5)` evaluation result.
- PP-YOLO_r18vd is trained on 4 GPUs with a batch size of 32 per GPU; if the number of GPUs or the batch size differs from this setting, the learning rate and number of iterations should be adjusted according to the [FAQ](../../docs/FAQ.md).
- The inference speed testing environment and method for PP-YOLO_r18vd are the same as for the PP-YOLO models above.
- PP-YOLO_MobileNetV3 is trained on the COCO train2017 dataset and evaluated on the val2017 dataset; Box AP<sup>val</sup> is the `mAP(IoU=0.5:0.95)` evaluation result.
- PP-YOLO_MobileNetV3 is trained on 4 GPUs with a batch size of 32 per GPU; if the number of GPUs or the batch size differs from this setting, the learning rate and number of iterations should be adjusted according to the [FAQ](../../docs/FAQ.md).
- PP-YOLO_MobileNetV3 inference speed is tested on a Kirin 990 chip with a single thread.
## Getting Started
......
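# New config for PP-YOLO with a MobileNetV3-large backbone; this appears to be the
# configs/ppyolo/ppyolo_mobilenet_v3_large.yml file linked from the tables above.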
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 20
log_iter: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
yolo_head: YOLOv3Head
use_fine_grained_loss: true
MobileNetV3:
norm_type: sync_bn
norm_decay: 0.
model_name: large
scale: 1.
extra_block_filters: []
feature_maps: [1, 2, 3, 4, 6]
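# The lightweight head below uses conv_block_num: 0 and only two anchor-mask groups,
# i.e. two output scales (downsample ratios 32 and 16 in the reader's Gt2YoloTarget).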
YOLOv3Head:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
norm_decay: 0.
conv_block_num: 0
scale_x_y: 1.05
yolo_loss: YOLOv3Loss
spp: true
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00666
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss is set to true;
# this operator will be removed automatically if use_fine_grained_loss
# is set to false
- !Gt2YoloTarget
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
downsample_ratios: [32, 16]
batch_size: 32
shuffle: true
mixup_epoch: 500
drop_last: true
worker_num: 8
bufsize: 4
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 2
bufsize: 4
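# New config for PP-YOLO with a MobileNetV3-small backbone; this appears to be the
# configs/ppyolo/ppyolo_mobilenet_v3_small.yml file linked from the tables above.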
architecture: YOLOv3
use_gpu: true
max_iters: 250000
log_smooth_window: 20
log_iter: 20
save_dir: output
snapshot_iter: 10000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_small_x1_0_ssld_pretrained.tar
weights: output/ppyolo_tiny/model_final
num_classes: 80
use_fine_grained_loss: true
use_ema: true
ema_decay: 0.9998
YOLOv3:
backbone: MobileNetV3
yolo_head: YOLOv3Head
use_fine_grained_loss: true
MobileNetV3:
norm_type: sync_bn
norm_decay: 0.
model_name: small
scale: 1.
extra_block_filters: []
feature_maps: [1, 2, 3, 4, 6]
YOLOv3Head:
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
norm_decay: 0.
conv_block_num: 0
scale_x_y: 1.05
yolo_loss: YOLOv3Loss
spp: true
nms:
background_label: -1
keep_top_k: 100
nms_threshold: 0.45
nms_top_k: 1000
normalized: false
score_threshold: 0.01
drop_block: true
YOLOv3Loss:
ignore_thresh: 0.7
scale_x_y: 1.05
label_smooth: false
use_fine_grained_loss: true
iou_loss: IouLoss
IouLoss:
loss_weight: 2.5
max_height: 608
max_width: 608
LearningRate:
base_lr: 0.00666
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 150000
- 200000
- !LinearWarmup
start_factor: 0.
steps: 4000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
_READER_: 'ppyolo_reader.yml'
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
with_mixup: True
- !MixupImage
alpha: 1.5
beta: 1.5
- !ColorDistort {}
- !RandomExpand
fill_value: [123.675, 116.28, 103.53]
- !RandomCrop {}
- !RandomFlipImage
is_normalized: false
- !NormalizeBox {}
- !PadBox
num_max_boxes: 50
- !BboxXYXY2XYWH {}
batch_transforms:
- !RandomShape
sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
random_inter: True
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !Permute
to_bgr: false
channel_first: True
# Gt2YoloTarget is only used when use_fine_grained_loss is set to true;
# this operator will be removed automatically if use_fine_grained_loss
# is set to false
- !Gt2YoloTarget
anchor_masks: [[3, 4, 5], [0, 1, 2]]
anchors: [[10, 14], [23, 27], [37, 58],
[81, 82], [135, 169], [344, 319]]
downsample_ratios: [32, 16]
batch_size: 32
shuffle: true
mixup_epoch: 500
drop_last: true
worker_num: 8
bufsize: 4
use_process: true
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
num_max_boxes: 50
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: True
- !ResizeImage
target_size: 320
interp: 2
- !NormalizeImage
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
is_scale: True
is_channel_first: false
- !PadBox
num_max_boxes: 50
- !Permute
to_bgr: false
channel_first: True
batch_size: 8
drop_empty: false
worker_num: 2
bufsize: 4
......@@ -263,7 +263,10 @@ class YOLOv3Head(object):
keep_prob=self.keep_prob,
is_test=is_test)
if self.drop_block and is_first:
if self.use_spp and conv_block_num == 0 and is_first:
conv = self._spp_module(conv, name="spp")
if self.drop_block and (is_first or conv_block_num == 0):
conv = DropBlock(
conv,
block_size=self.block_size,
......
......@@ -44,7 +44,7 @@ logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def split_distill(split_output_names, weight):
def split_distill(split_output_names, weight, target_number):
"""
Add fine-grained distillation losses.
Each loss is composed of distill_reg_loss, distill_cls_loss and
......@@ -54,16 +54,27 @@ def split_distill(split_output_names, weight):
for name in split_output_names:
student_var.append(fluid.default_main_program().global_block().var(
name))
s_x0, s_y0, s_w0, s_h0, s_obj0, s_cls0 = student_var[0:6]
s_x1, s_y1, s_w1, s_h1, s_obj1, s_cls1 = student_var[6:12]
s_x2, s_y2, s_w2, s_h2, s_obj2, s_cls2 = student_var[12:18]
s_x, s_y, s_w, s_h, s_obj, s_cls = [], [], [], [], [], []
for i in range(target_number):
s_x.append(student_var[i * 6])
s_y.append(student_var[i * 6 + 1])
s_w.append(student_var[i * 6 + 2])
s_h.append(student_var[i * 6 + 3])
s_obj.append(student_var[i * 6 + 4])
s_cls.append(student_var[i * 6 + 5])
teacher_var = []
for name in split_output_names:
teacher_var.append(fluid.default_main_program().global_block().var(
'teacher_' + name))
t_x0, t_y0, t_w0, t_h0, t_obj0, t_cls0 = teacher_var[0:6]
t_x1, t_y1, t_w1, t_h1, t_obj1, t_cls1 = teacher_var[6:12]
t_x2, t_y2, t_w2, t_h2, t_obj2, t_cls2 = teacher_var[12:18]
t_x, t_y, t_w, t_h, t_obj, t_cls = [], [], [], [], [], []
for i in range(target_number):
t_x.append(teacher_var[i * 6])
t_y.append(teacher_var[i * 6 + 1])
t_w.append(teacher_var[i * 6 + 2])
t_h.append(teacher_var[i * 6 + 3])
t_obj.append(teacher_var[i * 6 + 4])
t_cls.append(teacher_var[i * 6 + 5])
def obj_weighted_reg(sx, sy, sw, sh, tx, ty, tw, th, tobj):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
......@@ -92,26 +103,24 @@ def split_distill(split_output_names, weight):
fluid.layers.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
return loss
distill_reg_loss0 = obj_weighted_reg(s_x0, s_y0, s_w0, s_h0, t_x0, t_y0,
t_w0, t_h0, t_obj0)
distill_reg_loss1 = obj_weighted_reg(s_x1, s_y1, s_w1, s_h1, t_x1, t_y1,
t_w1, t_h1, t_obj1)
distill_reg_loss2 = obj_weighted_reg(s_x2, s_y2, s_w2, s_h2, t_x2, t_y2,
t_w2, t_h2, t_obj2)
distill_reg_loss = fluid.layers.sum(
[distill_reg_loss0, distill_reg_loss1, distill_reg_loss2])
distill_cls_loss0 = obj_weighted_cls(s_cls0, t_cls0, t_obj0)
distill_cls_loss1 = obj_weighted_cls(s_cls1, t_cls1, t_obj1)
distill_cls_loss2 = obj_weighted_cls(s_cls2, t_cls2, t_obj2)
distill_cls_loss = fluid.layers.sum(
[distill_cls_loss0, distill_cls_loss1, distill_cls_loss2])
distill_obj_loss0 = obj_loss(s_obj0, t_obj0)
distill_obj_loss1 = obj_loss(s_obj1, t_obj1)
distill_obj_loss2 = obj_loss(s_obj2, t_obj2)
distill_obj_loss = fluid.layers.sum(
[distill_obj_loss0, distill_obj_loss1, distill_obj_loss2])
distill_reg_losses = []
for i in range(target_number):
distill_reg_losses.append(
obj_weighted_reg(s_x[i], s_y[i], s_w[i], s_h[i], t_x[i], t_y[i],
t_w[i], t_h[i], t_obj[i]))
distill_reg_loss = fluid.layers.sum(distill_reg_losses)
distill_cls_losses = []
for i in range(target_number):
distill_cls_losses.append(
obj_weighted_cls(s_cls[i], t_cls[i], t_obj[i]))
distill_cls_loss = fluid.layers.sum(distill_cls_losses)
distill_obj_losses = []
for i in range(target_number):
distill_obj_losses.append(obj_loss(s_obj[i], t_obj[i]))
distill_obj_loss = fluid.layers.sum(distill_obj_losses)
loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss) * weight
return loss
......@@ -187,31 +196,44 @@ def main():
checkpoint.load_params(exe, teacher_program, FLAGS.teacher_pretrained)
teacher_program = teacher_program.clone(for_test=True)
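# One distillation target per YOLO output scale: the MobileNetV3 heads above predict
# on two scales, while the original ResNet-based PP-YOLO heads use three, so
# target_number lets split_distill handle both cases.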
target_number = len(model.yolo_head.anchor_masks)
data_name_map = {
'target0': 'target0',
'target1': 'target1',
'target2': 'target2',
'image': 'image',
'gt_bbox': 'gt_bbox',
'gt_class': 'gt_class',
'gt_score': 'gt_score'
}
for i in range(target_number):
data_name_map['target{}'.format(i)] = 'target{}'.format(i)
merge(teacher_program, fluid.default_main_program(), data_name_map, place)
yolo_output_names = [
'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',
'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',
'strided_slice_4.tmp_0', 'transpose_0.tmp_0', 'strided_slice_5.tmp_0',
'strided_slice_6.tmp_0', 'strided_slice_7.tmp_0',
'strided_slice_8.tmp_0', 'strided_slice_9.tmp_0', 'transpose_2.tmp_0',
'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',
'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',
'strided_slice_14.tmp_0', 'transpose_4.tmp_0'
output_names = [
[
'strided_slice_0.tmp_0', 'strided_slice_1.tmp_0',
'strided_slice_2.tmp_0', 'strided_slice_3.tmp_0',
'strided_slice_4.tmp_0', 'transpose_0.tmp_0'
],
[
'strided_slice_5.tmp_0', 'strided_slice_6.tmp_0',
'strided_slice_7.tmp_0', 'strided_slice_8.tmp_0',
'strided_slice_9.tmp_0', 'transpose_2.tmp_0'
],
[
'strided_slice_10.tmp_0', 'strided_slice_11.tmp_0',
'strided_slice_12.tmp_0', 'strided_slice_13.tmp_0',
'strided_slice_14.tmp_0', 'transpose_4.tmp_0'
],
]
yolo_output_names = []
for i in range(target_number):
yolo_output_names.extend(output_names[i])
assert cfg.use_fine_grained_loss, \
"Only support use_fine_grained_loss=True, Please set it in config file or '-o use_fine_grained_loss=true'"
distill_loss = split_distill(yolo_output_names, 1000)
distill_loss = split_distill(yolo_output_names, 1000, target_number)
loss = distill_loss + loss
lr_builder = create('LearningRate')
optim_builder = create('OptimizerBuilder')
......