Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ff4a2108
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ff4a2108
编写于
6月 02, 2022
作者:
F
Feng Ni
提交者:
GitHub
6月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick] add exclude_nms and trt silu for YOLOX (#6034)
* add exclude_nms and trt silu for yolox * fix silu act, fix readme
上级
0f424dcb
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
167 addition
and
68 deletion
+167
-68
configs/yolox/README.md
configs/yolox/README.md
+104
-28
ppdet/modeling/architectures/yolox.py
ppdet/modeling/architectures/yolox.py
+1
-6
ppdet/modeling/backbones/csp_darknet.py
ppdet/modeling/backbones/csp_darknet.py
+7
-6
ppdet/modeling/heads/yolo_head.py
ppdet/modeling/heads/yolo_head.py
+19
-7
ppdet/modeling/necks/yolo_fpn.py
ppdet/modeling/necks/yolo_fpn.py
+14
-3
ppdet/modeling/ops.py
ppdet/modeling/ops.py
+22
-18
未找到文件。
configs/yolox/README.md
浏览文件 @
ff4a2108
# YOLOX (YOLOX: Exceeding YOLO Series in 2021)
## Model Zoo
## 内容
-
[
模型库
](
#模型库
)
-
[
使用说明
](
#使用说明
)
-
[
速度测试
](
#速度测试
)
-
[
引用
](
#引用
)
## 模型库
### YOLOX on COCO
| 网络网络 | 输入尺寸 | 图片数/GPU | 学习率策略 |
推理时间(fp
s) | Box AP | 下载链接 | 配置文件 |
| 网络网络 | 输入尺寸 | 图片数/GPU | 学习率策略 |
模型推理耗时(m
s) | Box AP | 下载链接 | 配置文件 |
| :------------- | :------- | :-------: | :------: | :---------: | :-----: | :-------------: | :-----: |
| YOLOX-nano | 416 | 8 | 300e |
----
| 26.1 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_nano_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_nano_300e_coco.yml
)
|
| YOLOX-tiny | 416 | 8 | 300e |
----
| 32.9 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_tiny_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_tiny_300e_coco.yml
)
|
| YOLOX-s | 640 | 8 | 300e |
----
| 40.4 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_s_300e_coco.yml
)
|
| YOLOX-m | 640 | 8 | 300e |
----
| 46.9 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_m_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_m_300e_coco.yml
)
|
| YOLOX-l | 640 | 8 | 300e |
----
| 50.1 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_l_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_l_300e_coco.yml
)
|
| YOLOX-x | 640 | 8 | 300e |
----
| 51.8 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_x_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_x_300e_coco.yml
)
|
| YOLOX-nano | 416 | 8 | 300e |
2.3
| 26.1 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_nano_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_nano_300e_coco.yml
)
|
| YOLOX-tiny | 416 | 8 | 300e |
2.8
| 32.9 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_tiny_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_tiny_300e_coco.yml
)
|
| YOLOX-s | 640 | 8 | 300e |
3.0
| 40.4 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_s_300e_coco.yml
)
|
| YOLOX-m | 640 | 8 | 300e |
5.8
| 46.9 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_m_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_m_300e_coco.yml
)
|
| YOLOX-l | 640 | 8 | 300e |
9.3
| 50.1 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_l_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_l_300e_coco.yml
)
|
| YOLOX-x | 640 | 8 | 300e |
16.6
| 51.8 |
[
下载链接
](
https://paddledet.bj.bcebos.com/models/yolox_x_300e_coco.pdparams
)
|
[
配置文件
](
./yolox_x_300e_coco.yml
)
|
**注意:**
-
YOLOX模型训练使用COCO train2017作为训练集,Box AP为在COCO val2017上的
`mAP(IoU=0.5:0.95)`
结果;
-
YOLOX模型训练过程中默认使用8 GPUs进行混合精度训练,默认
单卡batch_size为8,如果
**GPU卡数**
或者
**batch size**
发生了改变,你需要按照公式
**lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)**
调整学习率;
-
为保持高mAP的同时提高推理速度,可以将
[
yolox_cspdarknet.yml
](
_base_/yolox_cspdarknet.yml
)
中的
`nms_top_k`
修改为
`1000`
,将
`keep_top_k`
修改为
`100`
,mAP会下降约0.1~0.2%;
-
YOLOX模型训练过程中默认使用8 GPUs进行混合精度训练,默认
每卡batch_size为8,默认lr为0.01为8卡总batch_size=64的设置,如果
**GPU卡数**
或者每卡
**batch size**
发生了改变,你需要按照公式
**lr<sub>new</sub> = lr<sub>default</sub> * (batch_size<sub>new</sub> * GPU_number<sub>new</sub>) / (batch_size<sub>default</sub> * GPU_number<sub>default</sub>)**
调整学习率;
-
为保持高mAP的同时提高推理速度,可以将
[
yolox_cspdarknet.yml
](
_base_/yolox_cspdarknet.yml
)
中的
`nms_top_k`
修改为
`1000`
,将
`keep_top_k`
修改为
`100`
,
将
`score_threshold`
修改为
`0.01`
,
mAP会下降约0.1~0.2%;
-
为快速的demo演示效果,可以将
[
yolox_cspdarknet.yml
](
_base_/yolox_cspdarknet.yml
)
中的
`score_threshold`
修改为
`0.25`
,将
`nms_threshold`
修改为
`0.45`
,但mAP会下降较多;
-
YOLOX模型推理速度测试采用单卡V100,batch size=1进行测试,使用
**CUDA 10.2**
,
**CUDNN 7.6.5**
,TensorRT推理速度测试使用
**TensorRT 6.0.1.8**
。
-
参考
[
速度测试
](
#速度测试
)
以复现YOLOX推理速度测试结果,速度为tensorRT-FP16测速后的最快速度,不包含数据预处理和模型输出后处理(NMS)的耗时。
-
如果你设置了
`--run_benchmark=True`
, 你首先需要安装以下依赖
`pip install pynvml psutil GPUtil`
。
## 使用教程
### 1.
训练
### 1.训练
执行以下指令使用混合精度训练YOLOX
```
bash
python
-m
paddle.distributed.launch
--gpus
0,1,2,3,4,5,6,7 tools/train.py
-c
configs/yolox/yolox_s_300e_coco.yml
--
fleet
--
amp
--eval
python
-m
paddle.distributed.launch
--gpus
0,1,2,3,4,5,6,7 tools/train.py
-c
configs/yolox/yolox_s_300e_coco.yml
--amp
--eval
```
**注意:**
使用默认配置训练需要设置
`--fleet`
,
`--amp`
最好也设置
以避免显存溢出,
`--eval`
表示边训边验证。
-
`--amp`
表示开启混合精度训练
以避免显存溢出,
`--eval`
表示边训边验证。
### 2.
评估
### 2.评估
执行以下命令在单个GPU上评估COCO val2017数据集
```
bash
CUDA_VISIBLE_DEVICES
=
0 python tools/eval.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
```
### 3.
推理
### 3.推理
使用以下命令在单张GPU上预测图片,使用
`--infer_img`
推理单张图片以及使用
`--infer_dir`
推理文件中的所有图片。
```
bash
# 推理单张图片
...
...
@@ -45,16 +54,57 @@ CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/yolox/yolox_s_300e_coco.
CUDA_VISIBLE_DEVICES
=
0 python tools/infer.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
--infer_dir
=
demo
```
### 4. 部署
#### 4.1. 导出模型
### 4.导出模型
YOLOX在GPU上推理部署或benchmark测速等需要通过
`tools/export_model.py`
导出模型。
运行以下的命令进行导出:
当你
**使用Paddle Inference但不使用TensorRT**
时,运行以下的命令导出模型
```
bash
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
```
#### 4.2. Python部署
当你
**使用Paddle Inference且使用TensorRT**
时,需要指定
`-o trt=True`
来导出模型。
```
bash
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
trt
=
True
```
如果你想将YOLOX模型导出为
**ONNX格式**
,参考
[
PaddleDetection模型导出为ONNX格式教程
](
../../deploy/EXPORT_ONNX_MODEL.md
)
,运行以下命令:
```
bash
# 导出推理模型
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
--output_dir
=
output_inference
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
# 安装paddle2onnx
pip
install
paddle2onnx
# 转换成onnx格式
paddle2onnx
--model_dir
output_inference/yolox_s_300e_coco
--model_filename
model.pdmodel
--params_filename
model.pdiparams
--opset_version
11
--save_file
yolox_s_300e_coco.onnx
```
**注意:**
ONNX模型目前只支持batch_size=1
### 5.推理部署
YOLOX可以使用以下方式进行部署:
-
Paddle Inference
[
Python
](
../../deploy/python
)
&
[
C++
](
../../deploy/cpp
)
-
[
Paddle-TensorRT
](
../../deploy/TENSOR_RT.md
)
-
[
PaddleServing
](
https://github.com/PaddlePaddle/Serving
)
-
[
PaddleSlim模型量化
](
../slim
)
运行以下命令导出模型
```
bash
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
trt
=
True
```
**注意:**
-
trt=True表示
**使用Paddle Inference且使用TensorRT**
进行测速,速度会更快,默认不加即为False,表示
**使用Paddle Inference但不使用TensorRT**
进行测速。
-
如果是使用Paddle Inference在TensorRT FP16模式下部署,需要参考
[
Paddle Inference文档
](
https://www.paddlepaddle.org.cn/inference/master/user_guides/download_lib.html#python
)
,下载并安装与你的CUDA, CUDNN和TensorRT相应的wheel包。
#### 5.1.Python部署
`deploy/python/infer.py`
使用上述导出后的Paddle Inference模型用于推理和benchnark测速,如果设置了
`--run_benchmark=True`
, 首先需要安装以下依赖
`pip install pynvml psutil GPUtil`
。
```
bash
...
...
@@ -63,8 +113,37 @@ python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --i
# 推理文件夹下的所有图片
python deploy/python/infer.py
--model_dir
=
output_inference/yolox_s_300e_coco
--image_dir
=
demo/
--device
=
gpu
```
#### 5.2. C++部署
`deploy/cpp/build/main`
使用上述导出后的Paddle Inference模型用于C++推理部署, 首先按照
[
docs
](
../../deploy/cpp/docs
)
编译安装环境。
```
bash
# C++部署推理单张图片
./deploy/cpp/build/main
--model_dir
=
output_inference/yolox_s_300e_coco/
--image_file
=
demo/000000014439_640x640.jpg
--run_mode
=
paddle
--device
=
GPU
--threshold
=
0.5
--output_dir
=
cpp_infer_output/yolox_s_300e_coco
```
## 速度测试
为了公平起见,在
[
模型库
](
#模型库
)
中的速度测试结果均为不包含数据预处理和模型输出后处理(NMS)的数据(与
[
YOLOv4(AlexyAB)
](
https://github.com/AlexeyAB/darknet
)
测试方法一致),需要在导出模型时指定
`-o exclude_nms=True`
。测速需设置
`--run_benchmark=True`
, 首先需要安装以下依赖
`pip install pynvml psutil GPUtil`
。
**使用Paddle Inference但不使用TensorRT**
进行测速,执行以下命令:
```
bash
# 导出模型
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
exclude_nms
=
True
# 速度测试,使用run_benchmark=True
python deploy/python/infer.py
--model_dir
=
output_inference/yolox_s_300e_coco
--image_file
=
demo/000000014439_640x640.jpg
--run_mode
=
paddle
--device
=
gpu
--run_benchmark
=
True
```
**使用Paddle Inference且使用TensorRT**
进行测速,执行以下命令:
```
bash
# 导出模型,使用trt=True
python tools/export_model.py
-c
configs/yolox/yolox_s_300e_coco.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
exclude_nms
=
True
trt
=
True
#
benchmark测速
#
速度测试,使用run_benchmark=True
python deploy/python/infer.py
--model_dir
=
output_inference/yolox_s_300e_coco
--image_file
=
demo/000000014439_640x640.jpg
--device
=
gpu
--run_benchmark
=
True
# tensorRT-FP32测速
...
...
@@ -73,13 +152,10 @@ python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --i
# tensorRT-FP16测速
python deploy/python/infer.py
--model_dir
=
output_inference/yolox_s_300e_coco
--image_file
=
demo/000000014439_640x640.jpg
--device
=
gpu
--run_benchmark
=
True
--trt_max_shape
=
640
--trt_min_shape
=
640
--trt_opt_shape
=
640
--run_mode
=
trt_fp16
```
**注意:**
-
导出模型时指定
`-o exclude_nms=True`
仅作为测速时用,这样导出的模型其推理部署预测的结果不是最终检出框的结果。
-
[
模型库
](
#模型库
)
中的速度测试结果为tensorRT-FP16测速后的最快速度,为不包含数据预处理和模型输出后处理(NMS)的耗时。
#### 4.2. C++部署
`deploy/cpp/build/main`
使用上述导出后的Paddle Inference模型用于C++推理部署, 首先按照
[
docs
](
../../deploy/cpp/docs
)
编译安装环境。
```
bash
# C++部署推理单张图片
./deploy/cpp/build/main
--model_dir
=
output_inference/yolox_s_300e_coco/
--image_file
=
demo/000000014439_640x640.jpg
--run_mode
=
paddle
--device
=
GPU
--threshold
=
0.5
--output_dir
=
cpp_infer_output/yolox_s_300e_coco
```
## Citations
```
...
...
ppdet/modeling/architectures/yolox.py
浏览文件 @
ff4a2108
...
...
@@ -135,10 +135,5 @@ class YOLOX(BaseArch):
self
.
size_stride
*
size_factor
,
self
.
size_stride
*
int
(
size_factor
*
image_ratio
)
]
size
=
paddle
.
to_tensor
(
size
)
if
dist
.
get_world_size
()
>
1
and
paddle_distributed_is_initialized
(
):
dist
.
barrier
()
dist
.
broadcast
(
size
,
0
)
self
.
_input_size
=
size
self
.
_input_size
=
paddle
.
to_tensor
(
size
)
self
.
_step
+=
1
ppdet/modeling/backbones/csp_darknet.py
浏览文件 @
ff4a2108
...
...
@@ -18,7 +18,6 @@ import paddle.nn.functional as F
from
paddle
import
ParamAttr
from
paddle.regularizer
import
L2Decay
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.modeling.ops
import
get_activation
from
ppdet.modeling.initializer
import
conv_init_
from
..shape_spec
import
ShapeSpec
...
...
@@ -49,7 +48,6 @@ class BaseConv(nn.Layer):
out_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
self
.
act
=
get_activation
(
act
)
self
.
_init_weights
()
...
...
@@ -57,7 +55,10 @@ class BaseConv(nn.Layer):
conv_init_
(
self
.
conv
)
def
forward
(
self
,
x
):
return
self
.
act
(
self
.
bn
(
self
.
conv
(
x
)))
# use 'x * F.sigmoid(x)' replace 'silu'
x
=
self
.
bn
(
self
.
conv
(
x
))
y
=
x
*
F
.
sigmoid
(
x
)
return
y
class
DWConv
(
nn
.
Layer
):
...
...
@@ -78,7 +79,7 @@ class DWConv(nn.Layer):
stride
=
stride
,
groups
=
in_channels
,
bias
=
bias
,
act
=
act
,
)
act
=
act
)
self
.
pw_conv
=
BaseConv
(
in_channels
,
out_channels
,
...
...
@@ -274,7 +275,7 @@ class CSPDarkNet(nn.Layer):
return_idx (list): Index of stages whose feature maps are returned.
"""
__shared__
=
[
'depth_mult'
,
'width_mult'
,
'act'
]
__shared__
=
[
'depth_mult'
,
'width_mult'
,
'act'
,
'trt'
]
# in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
# 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
...
...
@@ -294,12 +295,12 @@ class CSPDarkNet(nn.Layer):
width_mult
=
1.0
,
depthwise
=
False
,
act
=
'silu'
,
trt
=
False
,
return_idx
=
[
2
,
3
,
4
]):
super
(
CSPDarkNet
,
self
).
__init__
()
self
.
arch
=
arch
self
.
return_idx
=
return_idx
Conv
=
DWConv
if
depthwise
else
BaseConv
arch_setting
=
self
.
arch_settings
[
arch
]
base_channels
=
int
(
arch_setting
[
0
][
0
]
*
width_mult
)
...
...
ppdet/modeling/heads/yolo_head.py
浏览文件 @
ff4a2108
...
...
@@ -26,6 +26,7 @@ from ..backbones.csp_darknet import BaseConv, DWConv
from
..losses
import
IouLoss
from
ppdet.modeling.assigners.simota_assigner
import
SimOTAAssigner
from
ppdet.modeling.bbox_utils
import
bbox_overlaps
from
ppdet.modeling.layers
import
MultiClassNMS
__all__
=
[
'YOLOv3Head'
,
'YOLOXHead'
]
...
...
@@ -150,7 +151,7 @@ class YOLOv3Head(nn.Layer):
@
register
class
YOLOXHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'width_mult'
,
'act'
]
__shared__
=
[
'num_classes'
,
'width_mult'
,
'act'
,
'trt'
,
'exclude_nms'
]
__inject__
=
[
'assigner'
,
'nms'
]
def
__init__
(
self
,
...
...
@@ -164,10 +165,14 @@ class YOLOXHead(nn.Layer):
act
=
'silu'
,
assigner
=
SimOTAAssigner
(
use_vfl
=
False
),
nms
=
'MultiClassNMS'
,
loss_weight
=
{
'cls'
:
1.0
,
'obj'
:
1.0
,
'iou'
:
5.0
,
'l1'
:
1.0
}):
loss_weight
=
{
'cls'
:
1.0
,
'obj'
:
1.0
,
'iou'
:
5.0
,
'l1'
:
1.0
,
},
trt
=
False
,
exclude_nms
=
False
):
super
(
YOLOXHead
,
self
).
__init__
()
self
.
_dtype
=
paddle
.
framework
.
get_default_dtype
()
self
.
num_classes
=
num_classes
...
...
@@ -178,6 +183,9 @@ class YOLOXHead(nn.Layer):
self
.
l1_epoch
=
l1_epoch
self
.
assigner
=
assigner
self
.
nms
=
nms
if
isinstance
(
self
.
nms
,
MultiClassNMS
)
and
trt
:
self
.
nms
.
trt
=
trt
self
.
exclude_nms
=
exclude_nms
self
.
loss_weight
=
loss_weight
self
.
iou_loss
=
IouLoss
(
loss_weight
=
1.0
)
# default loss_weight 2.5
...
...
@@ -400,5 +408,9 @@ class YOLOXHead(nn.Layer):
# scale bbox to origin image
scale_factor
=
scale_factor
.
flip
(
-
1
).
tile
([
1
,
2
]).
unsqueeze
(
1
)
pred_bboxes
/=
scale_factor
bbox_pred
,
bbox_num
,
_
=
self
.
nms
(
pred_bboxes
,
pred_scores
)
return
bbox_pred
,
bbox_num
if
self
.
exclude_nms
:
# `exclude_nms=True` just use in benchmark
return
pred_bboxes
.
sum
(),
pred_scores
.
sum
()
else
:
bbox_pred
,
bbox_num
,
_
=
self
.
nms
(
pred_bboxes
,
pred_scores
)
return
bbox_pred
,
bbox_num
ppdet/modeling/necks/yolo_fpn.py
浏览文件 @
ff4a2108
...
...
@@ -17,6 +17,7 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
,
serializable
from
ppdet.modeling.layers
import
DropBlock
from
ppdet.modeling.ops
import
get_act_fn
from
..backbones.darknet
import
ConvBNLayer
from
..shape_spec
import
ShapeSpec
from
..backbones.csp_darknet
import
BaseConv
,
DWConv
,
CSPLayer
...
...
@@ -995,18 +996,24 @@ class YOLOCSPPAN(nn.Layer):
"""
YOLO CSP-PAN, used in YOLOv5 and YOLOX.
"""
__shared__
=
[
'depth_mult'
,
'
ac
t'
]
__shared__
=
[
'depth_mult'
,
'
data_format'
,
'act'
,
'tr
t'
]
def
__init__
(
self
,
depth_mult
=
1.0
,
in_channels
=
[
256
,
512
,
1024
],
depthwise
=
False
,
act
=
'silu'
):
data_format
=
'NCHW'
,
act
=
'silu'
,
trt
=
False
):
super
(
YOLOCSPPAN
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
_out_channels
=
in_channels
Conv
=
DWConv
if
depthwise
else
BaseConv
self
.
data_format
=
data_format
act
=
get_act_fn
(
act
,
trt
=
trt
)
if
act
is
None
or
isinstance
(
act
,
(
str
,
dict
))
else
act
self
.
upsample
=
nn
.
Upsample
(
scale_factor
=
2
,
mode
=
"nearest"
)
# top-down fpn
...
...
@@ -1061,7 +1068,11 @@ class YOLOCSPPAN(nn.Layer):
feat_heigh
)
inner_outs
[
0
]
=
feat_heigh
upsample_feat
=
self
.
upsample
(
feat_heigh
)
upsample_feat
=
F
.
interpolate
(
feat_heigh
,
scale_factor
=
2.
,
mode
=
"nearest"
,
data_format
=
self
.
data_format
)
inner_out
=
self
.
fpn_blocks
[
len
(
self
.
in_channels
)
-
1
-
idx
](
paddle
.
concat
(
[
upsample_feat
,
feat_low
],
axis
=
1
))
...
...
ppdet/modeling/ops.py
浏览文件 @
ff4a2108
...
...
@@ -25,10 +25,22 @@ from paddle.fluid.layer_helper import LayerHelper
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
__all__
=
[
'roi_pool'
,
'roi_align'
,
'prior_box'
,
'generate_proposals'
,
'iou_similarity'
,
'box_coder'
,
'yolo_box'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'collect_fpn_proposals'
,
'matrix_nms'
,
'batch_norm'
,
'get_activation'
,
'mish'
,
'swish'
,
'identity'
'roi_pool'
,
'roi_align'
,
'prior_box'
,
'generate_proposals'
,
'iou_similarity'
,
'box_coder'
,
'yolo_box'
,
'multiclass_nms'
,
'distribute_fpn_proposals'
,
'collect_fpn_proposals'
,
'matrix_nms'
,
'batch_norm'
,
'mish'
,
'silu'
,
'swish'
,
'identity'
,
]
...
...
@@ -40,13 +52,17 @@ def mish(x):
return
F
.
mish
(
x
)
if
hasattr
(
F
,
mish
)
else
x
*
F
.
tanh
(
F
.
softplus
(
x
))
def
silu
(
x
):
return
F
.
silu
(
x
)
def
swish
(
x
):
return
x
*
F
.
sigmoid
(
x
)
TRT_ACT_SPEC
=
{
'swish'
:
swish
}
TRT_ACT_SPEC
=
{
'swish'
:
swish
,
'silu'
:
swish
}
ACT_SPEC
=
{
'mish'
:
mish
}
ACT_SPEC
=
{
'mish'
:
mish
,
'silu'
:
silu
}
def
get_act_fn
(
act
=
None
,
trt
=
False
):
...
...
@@ -106,18 +122,6 @@ def batch_norm(ch,
return
norm_layer
def
get_activation
(
name
=
"silu"
):
if
name
==
"silu"
:
module
=
nn
.
Silu
()
elif
name
==
"relu"
:
module
=
nn
.
ReLU
()
elif
name
==
"leakyrelu"
:
module
=
nn
.
LeakyReLU
(
0.1
)
else
:
raise
AttributeError
(
"Unsupported act type: {}"
.
format
(
name
))
return
module
@
paddle
.
jit
.
not_to_static
def
roi_pool
(
input
,
rois
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录