Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
e263e885
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e263e885
编写于
10月 12, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
10月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update picodet full quant demo (#1460)
上级
38f6f578
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
611 addition
and
135 deletion
+611
-135
example/full_quantization/picodet/README.md
example/full_quantization/picodet/README.md
+19
-9
example/full_quantization/picodet/configs/picodet_npu.yaml
example/full_quantization/picodet/configs/picodet_npu.yaml
+35
-0
example/full_quantization/picodet/configs/picodet_npu_with_postprocess.yaml
...ization/picodet/configs/picodet_npu_with_postprocess.yaml
+2
-1
example/full_quantization/picodet/configs/picodet_reader.yml
example/full_quantization/picodet/configs/picodet_reader.yml
+13
-6
example/full_quantization/picodet/eval.py
example/full_quantization/picodet/eval.py
+43
-66
example/full_quantization/picodet/onnxruntime_eval.py
example/full_quantization/picodet/onnxruntime_eval.py
+128
-0
example/full_quantization/picodet/post_process.py
example/full_quantization/picodet/post_process.py
+227
-0
example/full_quantization/picodet/post_quant.py
example/full_quantization/picodet/post_quant.py
+102
-0
example/full_quantization/picodet/run.py
example/full_quantization/picodet/run.py
+42
-53
未找到文件。
example/full_quantization/
detection
/README.md
→
example/full_quantization/
picodet
/README.md
浏览文件 @
e263e885
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
| 模型 | 策略 | mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | 配置文件 | 模型 |
| 模型 | 策略 | mAP | TRT-FP32 | TRT-FP16 | TRT-INT8 | 配置文件 | 模型 |
| :-------- |:-------- |:--------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
| :-------- |:-------- |:--------: | :----------------: | :----------------: | :---------------: | :----------------------: | :---------------------: |
| PicoDet-S-NPU | Baseline | 30.1 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco_npu.yml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
|
| PicoDet-S-NPU | Baseline | 30.1 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/picodet/picodet_s_416_coco_npu.yml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
|
| PicoDet-S-NPU | 量化训练 | 29.7 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/
detection/configs/picodet_s_qat_di
s.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar
)
|
| PicoDet-S-NPU | 量化训练 | 29.7 | - | - | - |
[
config
](
https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/full_quantization/
picodet/configs/picodet_npu_with_postproces
s.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_npu_quant.tar
)
|
-
mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。
-
mAP的指标均在COCO val2017数据集中评测得到,IoU=0.5:0.95。
...
@@ -31,7 +31,7 @@
...
@@ -31,7 +31,7 @@
#### 3.1 准备环境
#### 3.1 准备环境
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddleSlim >= 2.3
-
PaddleSlim >= 2.3
.4
-
PaddleDet >= 2.4
-
PaddleDet >= 2.4
-
opencv-python
-
opencv-python
...
@@ -67,9 +67,6 @@ pip install paddledet
...
@@ -67,9 +67,6 @@ pip install paddledet
预测模型的格式为:
`model.pdmodel`
和
`model.pdiparams`
两个,带
`pdmodel`
的是模型文件,带
`pdiparams`
后缀的是权重文件。
预测模型的格式为:
`model.pdmodel`
和
`model.pdiparams`
两个,带
`pdmodel`
的是模型文件,带
`pdiparams`
后缀的是权重文件。
注:其他像
`__model__`
和
`__params__`
分别对应
`model.pdmodel`
和
`model.pdiparams`
文件。
根据
[
PaddleDetection文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA
)
导出Inference模型,具体可参考下方PicoDet-S-NPU模型的导出示例:
根据
[
PaddleDetection文档
](
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/docs/tutorials/GETTING_STARTED_cn.md#8-%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA
)
导出Inference模型,具体可参考下方PicoDet-S-NPU模型的导出示例:
-
下载代码
-
下载代码
```
```
...
@@ -77,13 +74,20 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
...
@@ -77,13 +74,20 @@ git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
```
-
导出预测模型
-
导出预测模型
PicoDet-S-NPU模型,包含
NMS
:如快速体验,可直接下载
[
PicoDet-S-NPU导出模型
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
PicoDet-S-NPU模型,包含
后处理
:如快速体验,可直接下载
[
PicoDet-S-NPU导出模型
](
https://bj.bcebos.com/v1/paddle-slim-models/act/picodet_s_416_coco_npu.tar
)
```
shell
```
shell
python tools/export_model.py
\
python tools/export_model.py
\
-c
configs/picodet/picodet_s_416_coco_npu.yml
\
-c
configs/picodet/picodet_s_416_coco_npu.yml
\
-o
weights
=
https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
\
-o
weights
=
https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
\
```
```
导出PicoDet-S-NPU不带后处理模型:
```
shell
python tools/export_model.py
\
-c
configs/picodet/picodet_s_416_coco_npu.yml
\
-o
weights
=
https://paddledet.bj.bcebos.com/models/picodet_s_416_coco_npu.pdparams
\
export.benchmark
=
True
```
#### 3.4 全量化并产出模型
#### 3.4 全量化并产出模型
...
@@ -92,14 +96,20 @@ python tools/export_model.py \
...
@@ -92,14 +96,20 @@ python tools/export_model.py \
-
单卡训练:
-
单卡训练:
```
```
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/picodet_
s_qat_di
s.yaml --save_dir='./output/'
python run.py --config_path=./configs/picodet_
npu_with_postproces
s.yaml --save_dir='./output/'
```
```
-
多卡训练:
-
多卡训练:
```
```
CUDA_VISIBLE_DEVICES=0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3
python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
--config_path=./configs/picodet_s_qat_dis.yaml --save_dir='./output/'
--config_path=./configs/picodet_npu_with_postprocess.yaml --save_dir='./output/'
```
-
不带后处理PicoDet模型训练:
```
export CUDA_VISIBLE_DEVICES=0
python run.py --config_path=./configs/picodet_npu.yaml --save_dir='./output/'
```
```
#### 3.5 测试模型精度
#### 3.5 测试模型精度
...
@@ -107,7 +117,7 @@ python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
...
@@ -107,7 +117,7 @@ python -m paddle.distributed.launch --log_dir=log --gpus 0,1,2,3 run.py \
使用eval.py脚本得到模型的mAP:
使用eval.py脚本得到模型的mAP:
```
```
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0
python eval.py --config_path=./configs/picodet_
s_qat_di
s.yaml
python eval.py --config_path=./configs/picodet_
npu_with_postproces
s.yaml
```
```
**注意**
:
**注意**
:
...
...
example/full_quantization/picodet/configs/picodet_npu.yaml
0 → 100644
浏览文件 @
e263e885
Global
:
reader_config
:
./configs/picodet_reader.yml
input_list
:
[
'
image'
]
include_post_process
:
False
Evaluation
:
True
model_dir
:
./picodet_s_416_coco_npu
model_filename
:
model.pdmodel
params_filename
:
model.pdiparams
Distillation
:
alpha
:
1.0
loss
:
l2
Quantization
:
use_pact
:
true
activation_quantize_type
:
'
moving_average_abs_max'
weight_bits
:
8
activation_bits
:
8
quantize_op_types
:
-
conv2d
-
depthwise_conv2d
TrainConfig
:
train_iter
:
8000
eval_iter
:
1000
learning_rate
:
type
:
CosineAnnealingDecay
learning_rate
:
0.00001
T_max
:
8000
optimizer_builder
:
optimizer
:
type
:
SGD
weight_decay
:
4.0e-05
example/full_quantization/
detection/configs/picodet_s_qat_di
s.yaml
→
example/full_quantization/
picodet/configs/picodet_npu_with_postproces
s.yaml
浏览文件 @
e263e885
Global
:
Global
:
reader_config
:
./configs/picodet_reader.yml
reader_config
:
./configs/picodet_reader.yml
input_list
:
[
'
image'
,
'
scale_factor'
]
input_list
:
[
'
image'
,
'
scale_factor'
]
include_post_process
:
True
Evaluation
:
True
Evaluation
:
True
model_dir
:
./picodet_s_416_coco_npu
/
model_dir
:
./picodet_s_416_coco_npu
model_filename
:
model.pdmodel
model_filename
:
model.pdmodel
params_filename
:
model.pdiparams
params_filename
:
model.pdiparams
...
...
example/full_quantization/
detection
/configs/picodet_reader.yml
→
example/full_quantization/
picodet
/configs/picodet_reader.yml
浏览文件 @
e263e885
...
@@ -7,26 +7,33 @@ TrainDataset:
...
@@ -7,26 +7,33 @@ TrainDataset:
!COCODataSet
!COCODataSet
image_dir
:
train2017
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
/paddle/
dataset/coco/
dataset_dir
:
dataset/coco/
EvalDataset
:
EvalDataset
:
!COCODataSet
!COCODataSet
image_dir
:
val2017
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
/paddle/
dataset/coco/
dataset_dir
:
dataset/coco/
worker_num
:
6
worker_num
:
0
eval_height
:
&eval_height
416
eval_height
:
&eval_height
416
eval_width
:
&eval_width
416
eval_width
:
&eval_width
416
eval_size
:
&eval_size
[
*eval_height
,
*eval_width
]
eval_size
:
&eval_size
[
*eval_height
,
*eval_width
]
Eval
Reader
:
Train
Reader
:
sample_transforms
:
sample_transforms
:
-
Decode
:
{}
-
Decode
:
{}
-
Resize
:
{
interp
:
2
,
target_size
:
*eval_size
,
keep_ratio
:
False
}
-
Resize
:
{
interp
:
2
,
target_size
:
*eval_size
,
keep_ratio
:
False
}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
Permute
:
{}
-
Permute
:
{}
batch_transforms
:
-
PadBatch
:
{
pad_to_stride
:
32
}
batch_size
:
8
batch_size
:
8
shuffle
:
false
shuffle
:
false
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
interp
:
2
,
target_size
:
*eval_size
,
keep_ratio
:
False
}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
Permute
:
{}
batch_size
:
1
shuffle
:
false
example/full_quantization/
detection
/eval.py
→
example/full_quantization/
picodet
/eval.py
浏览文件 @
e263e885
...
@@ -22,6 +22,8 @@ from ppdet.core.workspace import create
...
@@ -22,6 +22,8 @@ from ppdet.core.workspace import create
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
,
KeyPointTopDownCOCOEval
from
ppdet.metrics
import
COCOMetric
,
VOCMetric
,
KeyPointTopDownCOCOEval
from
paddleslim.common
import
load_config
as
load_slim_config
from
paddleslim.common
import
load_config
as
load_slim_config
from
post_process
import
PicoDetPostProcess
def
argsparser
():
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
@@ -40,37 +42,7 @@ def argsparser():
...
@@ -40,37 +42,7 @@ def argsparser():
return
parser
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
eval
(
metric
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval
():
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
exe
=
paddle
.
static
.
Executor
(
place
)
...
@@ -82,30 +54,46 @@ def eval():
...
@@ -82,30 +54,46 @@ def eval():
params_filename
=
global_config
[
"params_filename"
])
params_filename
=
global_config
[
"params_filename"
])
print
(
'Loaded model from: {}'
.
format
(
global_config
[
"model_dir"
]))
print
(
'Loaded model from: {}'
.
format
(
global_config
[
"model_dir"
]))
metric
=
global_config
[
'metric'
]
for
batch_id
,
data
in
enumerate
(
val_loader
):
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
convert_numpy_data
(
data
,
metric
)
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
batch_size
=
data_all
[
'image'
].
shape
[
0
]
data_input
=
{}
data_input
=
{}
for
k
,
v
in
data
.
items
():
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
feed_target_names
:
if
k
in
global_config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
val_program
,
outs
=
exe
.
run
(
val_program
,
feed
=
data_input
,
feed
=
data_input
,
fetch_list
=
fetch_targets
,
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
return_numpy
=
False
)
res
=
{}
if
not
global_config
[
'include_post_process'
]:
np_score_list
,
np_boxes_list
=
[],
[]
for
out
in
outs
:
for
i
,
out
in
enumerate
(
outs
):
v
=
np
.
array
(
out
)
np_out
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
if
i
<
4
:
res
[
'bbox'
]
=
v
num_classes
=
np_out
.
shape
[
-
1
]
else
:
np_score_list
.
append
(
res
[
'bbox_num'
]
=
v
np_out
.
reshape
(
batch_size
,
-
1
,
num_classes
))
else
:
box_reg_shape
=
np_out
.
shape
[
-
1
]
np_boxes_list
.
append
(
np_out
.
reshape
(
batch_size
,
-
1
,
box_reg_shape
))
post_processor
=
PicoDetPostProcess
(
data_all
[
'image'
].
shape
[
2
:],
data_all
[
'im_shape'
],
data_all
[
'scale_factor'
],
score_threshold
=
0.01
,
nms_threshold
=
0.6
)
res
=
post_processor
(
np_score_list
,
np_boxes_list
)
else
:
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
print
(
'Eval iter:'
,
batch_id
)
...
@@ -125,26 +113,15 @@ def main():
...
@@ -125,26 +113,15 @@ def main():
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
return_list
=
True
)
global
num_classes
num_classes
=
reader_cfg
[
'num_classes'
]
metric
=
None
metric
=
None
if
reader_cfg
[
'metric'
]
==
'COCO'
:
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
eval
(
metric
)
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
elif
reader_cfg
[
'metric'
]
==
'KeyPointTopDownCOCOEval'
:
anno_file
=
dataset
.
get_anno
()
metric
=
KeyPointTopDownCOCOEval
(
anno_file
,
len
(
dataset
),
17
,
'output_eval'
)
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
eval
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
example/full_quantization/picodet/onnxruntime_eval.py
0 → 100644
浏览文件 @
e263e885
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
from
ppdet.core.workspace
import
create
from
ppdet.metrics
import
COCOMetric
import
onnxruntime
as
ort
from
post_process
import
PicoDetPostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--reader_config'
,
type
=
str
,
default
=
'configs/picodet_reader.yml'
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
default
=
'onnx_file/picodet_s_416_npu_postprocessed.onnx'
,
help
=
"onnx filepath"
)
parser
.
add_argument
(
'--include_post_process'
,
type
=
bool
,
default
=
False
,
help
=
"Whether include post_process or not."
)
return
parser
def
eval
(
val_loader
,
metric
,
sess
):
inputs_name
=
[
a
.
name
for
a
in
sess
.
get_inputs
()]
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
batch_size
=
data_all
[
'image'
].
shape
[
0
]
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
k
in
inputs_name
:
data_input
[
k
]
=
np
.
array
(
v
)
outs
=
sess
.
run
(
None
,
data_input
)
if
not
FLAGS
.
include_post_process
:
np_score_list
,
np_boxes_list
=
[],
[]
for
i
,
out
in
enumerate
(
outs
):
np_out
=
np
.
array
(
out
)
if
i
<
4
:
num_classes
=
np_out
.
shape
[
-
1
]
np_score_list
.
append
(
np_out
.
reshape
(
batch_size
,
-
1
,
num_classes
))
else
:
box_reg_shape
=
np_out
.
shape
[
-
1
]
np_boxes_list
.
append
(
np_out
.
reshape
(
batch_size
,
-
1
,
box_reg_shape
))
post_processor
=
PicoDetPostProcess
(
data_all
[
'image'
].
shape
[
2
:],
data_all
[
'im_shape'
],
data_all
[
'scale_factor'
],
score_threshold
=
0.01
,
nms_threshold
=
0.6
)
res
=
post_processor
(
np_score_list
,
np_boxes_list
)
else
:
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
print
(
'Eval iter:'
,
batch_id
)
metric
.
accumulate
()
metric
.
log
()
metric
.
reset
()
def
main
():
reader_cfg
=
load_config
(
FLAGS
.
reader_config
)
dataset
=
reader_cfg
[
'EvalDataset'
]
val_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'EvalDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
providers
=
[
'CPUExecutionProvider'
]
sess_options
=
ort
.
SessionOptions
()
sess_options
.
optimized_model_filepath
=
"./optimize_model.onnx"
sess
=
ort
.
InferenceSession
(
FLAGS
.
model_path
,
providers
=
providers
,
sess_options
=
sess_options
)
eval
(
val_loader
,
metric
,
sess
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
# DataLoader need run on cpu
paddle
.
set_device
(
"cpu"
)
main
()
example/full_quantization/picodet/post_process.py
0 → 100644
浏览文件 @
e263e885
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
scipy.special
import
softmax
def
hard_nms
(
box_scores
,
iou_threshold
,
top_k
=-
1
,
candidate_size
=
200
):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores
=
box_scores
[:,
-
1
]
boxes
=
box_scores
[:,
:
-
1
]
picked
=
[]
indexes
=
np
.
argsort
(
scores
)
indexes
=
indexes
[
-
candidate_size
:]
while
len
(
indexes
)
>
0
:
current
=
indexes
[
-
1
]
picked
.
append
(
current
)
if
0
<
top_k
==
len
(
picked
)
or
len
(
indexes
)
==
1
:
break
current_box
=
boxes
[
current
,
:]
indexes
=
indexes
[:
-
1
]
rest_boxes
=
boxes
[
indexes
,
:]
iou
=
iou_of
(
rest_boxes
,
np
.
expand_dims
(
current_box
,
axis
=
0
),
)
indexes
=
indexes
[
iou
<=
iou_threshold
]
return
box_scores
[
picked
,
:]
def
iou_of
(
boxes0
,
boxes1
,
eps
=
1e-5
):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top
=
np
.
maximum
(
boxes0
[...,
:
2
],
boxes1
[...,
:
2
])
overlap_right_bottom
=
np
.
minimum
(
boxes0
[...,
2
:],
boxes1
[...,
2
:])
overlap_area
=
area_of
(
overlap_left_top
,
overlap_right_bottom
)
area0
=
area_of
(
boxes0
[...,
:
2
],
boxes0
[...,
2
:])
area1
=
area_of
(
boxes1
[...,
:
2
],
boxes1
[...,
2
:])
return
overlap_area
/
(
area0
+
area1
-
overlap_area
+
eps
)
def
area_of
(
left_top
,
right_bottom
):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw
=
np
.
clip
(
right_bottom
-
left_top
,
0.0
,
None
)
return
hw
[...,
0
]
*
hw
[...,
1
]
class
PicoDetPostProcess
(
object
):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def
__init__
(
self
,
input_shape
,
ori_shape
,
scale_factor
,
strides
=
[
8
,
16
,
32
,
64
],
score_threshold
=
0.4
,
nms_threshold
=
0.5
,
nms_top_k
=
1000
,
keep_top_k
=
100
):
self
.
ori_shape
=
ori_shape
self
.
input_shape
=
input_shape
self
.
scale_factor
=
scale_factor
self
.
strides
=
strides
self
.
score_threshold
=
score_threshold
self
.
nms_threshold
=
nms_threshold
self
.
nms_top_k
=
nms_top_k
self
.
keep_top_k
=
keep_top_k
def
warp_boxes
(
self
,
boxes
,
ori_shape
):
"""Apply transform to boxes
"""
width
,
height
=
ori_shape
[
1
],
ori_shape
[
0
]
n
=
len
(
boxes
)
if
n
:
# warp points
xy
=
np
.
ones
((
n
*
4
,
3
))
xy
[:,
:
2
]
=
boxes
[:,
[
0
,
1
,
2
,
3
,
0
,
3
,
2
,
1
]].
reshape
(
n
*
4
,
2
)
# x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy
=
(
xy
[:,
:
2
]
/
xy
[:,
2
:
3
]).
reshape
(
n
,
8
)
# rescale
# create new boxes
x
=
xy
[:,
[
0
,
2
,
4
,
6
]]
y
=
xy
[:,
[
1
,
3
,
5
,
7
]]
xy
=
np
.
concatenate
(
(
x
.
min
(
1
),
y
.
min
(
1
),
x
.
max
(
1
),
y
.
max
(
1
))).
reshape
(
4
,
n
).
T
# clip boxes
xy
[:,
[
0
,
2
]]
=
xy
[:,
[
0
,
2
]].
clip
(
0
,
width
)
xy
[:,
[
1
,
3
]]
=
xy
[:,
[
1
,
3
]].
clip
(
0
,
height
)
return
xy
.
astype
(
np
.
float32
)
else
:
return
boxes
def
__call__
(
self
,
scores
,
raw_boxes
):
batch_size
=
raw_boxes
[
0
].
shape
[
0
]
reg_max
=
int
(
raw_boxes
[
0
].
shape
[
-
1
]
/
4
-
1
)
out_boxes_num
=
[]
out_boxes_list
=
[]
for
batch_id
in
range
(
batch_size
):
# generate centers
decode_boxes
=
[]
select_scores
=
[]
for
stride
,
box_distribute
,
score
in
zip
(
self
.
strides
,
raw_boxes
,
scores
):
box_distribute
=
box_distribute
[
batch_id
]
score
=
score
[
batch_id
]
# centers
fm_h
=
self
.
input_shape
[
0
]
/
stride
fm_w
=
self
.
input_shape
[
1
]
/
stride
h_range
=
np
.
arange
(
fm_h
)
w_range
=
np
.
arange
(
fm_w
)
ww
,
hh
=
np
.
meshgrid
(
w_range
,
h_range
)
ct_row
=
(
hh
.
flatten
()
+
0.5
)
*
stride
ct_col
=
(
ww
.
flatten
()
+
0.5
)
*
stride
center
=
np
.
stack
((
ct_col
,
ct_row
,
ct_col
,
ct_row
),
axis
=
1
)
# box distribution to distance
reg_range
=
np
.
arange
(
reg_max
+
1
)
box_distance
=
box_distribute
.
reshape
((
-
1
,
reg_max
+
1
))
box_distance
=
softmax
(
box_distance
,
axis
=
1
)
box_distance
=
box_distance
*
np
.
expand_dims
(
reg_range
,
axis
=
0
)
box_distance
=
np
.
sum
(
box_distance
,
axis
=
1
).
reshape
((
-
1
,
4
))
box_distance
=
box_distance
*
stride
# top K candidate
topk_idx
=
np
.
argsort
(
score
.
max
(
axis
=
1
))[::
-
1
]
topk_idx
=
topk_idx
[:
self
.
nms_top_k
]
center
=
center
[
topk_idx
]
score
=
score
[
topk_idx
]
box_distance
=
box_distance
[
topk_idx
]
# decode box
decode_box
=
center
+
[
-
1
,
-
1
,
1
,
1
]
*
box_distance
select_scores
.
append
(
score
)
decode_boxes
.
append
(
decode_box
)
# nms
bboxes
=
np
.
concatenate
(
decode_boxes
,
axis
=
0
)
confidences
=
np
.
concatenate
(
select_scores
,
axis
=
0
)
picked_box_probs
=
[]
picked_labels
=
[]
for
class_index
in
range
(
0
,
confidences
.
shape
[
1
]):
probs
=
confidences
[:,
class_index
]
mask
=
probs
>
self
.
score_threshold
probs
=
probs
[
mask
]
if
probs
.
shape
[
0
]
==
0
:
continue
subset_boxes
=
bboxes
[
mask
,
:]
box_probs
=
np
.
concatenate
(
[
subset_boxes
,
probs
.
reshape
(
-
1
,
1
)],
axis
=
1
)
box_probs
=
hard_nms
(
box_probs
,
iou_threshold
=
self
.
nms_threshold
,
top_k
=
self
.
keep_top_k
,
)
picked_box_probs
.
append
(
box_probs
)
picked_labels
.
extend
([
class_index
]
*
box_probs
.
shape
[
0
])
if
len
(
picked_box_probs
)
==
0
:
out_boxes_list
.
append
(
np
.
empty
((
0
,
4
)))
out_boxes_num
.
append
(
0
)
else
:
picked_box_probs
=
np
.
concatenate
(
picked_box_probs
)
# resize output boxes
picked_box_probs
[:,
:
4
]
=
self
.
warp_boxes
(
picked_box_probs
[:,
:
4
],
self
.
ori_shape
[
batch_id
])
im_scale
=
np
.
concatenate
([
self
.
scale_factor
[
batch_id
][::
-
1
],
self
.
scale_factor
[
batch_id
][::
-
1
]
])
picked_box_probs
[:,
:
4
]
/=
im_scale
# clas score box
out_boxes_list
.
append
(
np
.
concatenate
(
[
np
.
expand_dims
(
np
.
array
(
picked_labels
),
axis
=-
1
),
np
.
expand_dims
(
picked_box_probs
[:,
4
],
axis
=-
1
),
picked_box_probs
[:,
:
4
]
],
axis
=
1
))
out_boxes_num
.
append
(
len
(
picked_labels
))
out_boxes_list
=
np
.
concatenate
(
out_boxes_list
,
axis
=
0
)
out_boxes_num
=
np
.
asarray
(
out_boxes_num
).
astype
(
np
.
int32
)
return
{
'bbox'
:
out_boxes_list
,
'bbox_num'
:
out_boxes_num
}
\ No newline at end of file
example/full_quantization/picodet/post_quant.py
0 → 100644
浏览文件 @
e263e885
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
numpy
as
np
import
argparse
import
paddle
from
ppdet.core.workspace
import
load_config
,
merge_config
from
ppdet.core.workspace
import
create
from
paddleslim.quant
import
quant_post_static
from
paddleslim.common
import
load_config
as
load_slim_config
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--config_path'
,
type
=
str
,
default
=
None
,
help
=
"path of compression strategy config."
,
required
=
True
)
parser
.
add_argument
(
'--save_dir'
,
type
=
str
,
default
=
'ptq_out'
,
help
=
"directory to save compressed model."
)
parser
.
add_argument
(
'--devices'
,
type
=
str
,
default
=
'gpu'
,
help
=
"which device used to compress."
)
parser
.
add_argument
(
'--algo'
,
type
=
str
,
default
=
'avg'
,
help
=
"post quant algo."
)
return
parser
def
reader_wrapper
(
reader
,
input_list
):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
def
main
():
all_config
=
load_slim_config
(
FLAGS
.
config_path
)
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'EvalReader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
place
=
paddle
.
CUDAPlace
(
0
)
if
FLAGS
.
devices
==
'gpu'
else
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
(
place
)
quant_post_static
(
executor
=
exe
,
model_dir
=
global_config
[
"model_dir"
],
quantize_model_path
=
FLAGS
.
save_dir
,
data_loader
=
train_loader
,
model_filename
=
global_config
[
"model_filename"
],
params_filename
=
global_config
[
"params_filename"
],
batch_size
=
32
,
batch_nums
=
10
,
algo
=
FLAGS
.
algo
,
hist_percent
=
0.999
,
is_full_quantize
=
False
,
bias_correction
=
False
,
onnx_format
=
True
,
skip_tensor_list
=
None
)
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
parser
=
argsparser
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
devices
in
[
'cpu'
,
'gpu'
,
'xpu'
,
'npu'
]
paddle
.
set_device
(
FLAGS
.
devices
)
main
()
example/full_quantization/
detection
/run.py
→
example/full_quantization/
picodet
/run.py
浏览文件 @
e263e885
...
@@ -24,6 +24,8 @@ from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
...
@@ -24,6 +24,8 @@ from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from
paddleslim.common
import
load_config
as
load_slim_config
from
paddleslim.common
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
from
paddleslim.auto_compression
import
AutoCompression
from
post_process
import
PicoDetPostProcess
def
argsparser
():
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
@@ -62,48 +64,48 @@ def reader_wrapper(reader, input_list):
...
@@ -62,48 +64,48 @@ def reader_wrapper(reader, input_list):
return
gen
return
gen
def
convert_numpy_data
(
data
,
metric
):
data_all
=
{}
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
if
isinstance
(
metric
,
VOCMetric
):
for
k
,
v
in
data_all
.
items
():
if
not
isinstance
(
v
[
0
],
np
.
ndarray
):
tmp_list
=
[]
for
t
in
v
:
tmp_list
.
append
(
np
.
array
(
t
))
data_all
[
k
]
=
np
.
array
(
tmp_list
)
else
:
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
return
data_all
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
def
eval_function
(
exe
,
compiled_test_program
,
test_feed_names
,
test_fetch_list
):
metric
=
global_config
[
'metric'
]
metric
=
global_config
[
'metric'
]
with
tqdm
(
with
tqdm
(
total
=
len
(
val_loader
),
total
=
len
(
val_loader
),
bar_format
=
'Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}'
,
bar_format
=
'Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}'
,
ncols
=
80
)
as
t
:
ncols
=
80
)
as
t
:
for
batch_id
,
data
in
enumerate
(
val_loader
):
for
data
in
val_loader
:
data_all
=
convert_numpy_data
(
data
,
metric
)
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
batch_size
=
data_all
[
'image'
].
shape
[
0
]
data_input
=
{}
data_input
=
{}
for
k
,
v
in
data
.
items
():
for
k
,
v
in
data
.
items
():
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
test_feed_names
:
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
compiled_test_program
,
outs
=
exe
.
run
(
compiled_test_program
,
feed
=
data_input
,
feed
=
data_input
,
fetch_list
=
test_fetch_list
,
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
return_numpy
=
False
)
res
=
{}
if
not
global_config
[
'include_post_process'
]:
for
out
in
outs
:
np_score_list
,
np_boxes_list
=
[],
[]
v
=
np
.
array
(
out
)
for
i
,
out
in
enumerate
(
outs
):
if
len
(
v
.
shape
)
>
1
:
if
i
<
4
:
res
[
'bbox'
]
=
v
np_score_list
.
append
(
else
:
np
.
array
(
out
).
reshape
(
batch_size
,
-
1
,
num_classes
))
res
[
'bbox_num'
]
=
v
else
:
np_boxes_list
.
append
(
np
.
array
(
out
).
reshape
(
batch_size
,
-
1
,
32
))
post_processor
=
PicoDetPostProcess
(
data_all
[
'image'
].
shape
[
2
:],
data_all
[
'im_shape'
],
data_all
[
'scale_factor'
],
score_threshold
=
0.01
,
nms_threshold
=
0.6
)
res
=
post_processor
(
np_score_list
,
np_boxes_list
)
else
:
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
metric
.
update
(
data_all
,
res
)
t
.
update
()
t
.
update
()
...
@@ -111,9 +113,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
...
@@ -111,9 +113,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
metric
.
log
()
metric
.
log
()
map_res
=
metric
.
get_results
()
map_res
=
metric
.
get_results
()
metric
.
reset
()
metric
.
reset
()
map_key
=
'keypoint'
if
'arch'
in
global_config
and
global_config
[
return
map_res
[
'bbox'
][
0
]
'arch'
]
==
'keypoint'
else
'bbox'
return
map_res
[
map_key
][
0
]
def
main
():
def
main
():
...
@@ -123,9 +123,9 @@ def main():
...
@@ -123,9 +123,9 @@ def main():
global_config
=
all_config
[
"Global"
]
global_config
=
all_config
[
"Global"
]
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
train_loader
=
create
(
'
Eval
Reader'
)(
reader_cfg
[
'TrainDataset'
],
train_loader
=
create
(
'
Train
Reader'
)(
reader_cfg
[
'TrainDataset'
],
reader_cfg
[
'worker_num'
],
reader_cfg
[
'worker_num'
],
return_list
=
True
)
return_list
=
True
)
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
train_loader
=
reader_wrapper
(
train_loader
,
global_config
[
'input_list'
])
if
'Evaluation'
in
global_config
.
keys
()
and
global_config
[
if
'Evaluation'
in
global_config
.
keys
()
and
global_config
[
...
@@ -139,23 +139,12 @@ def main():
...
@@ -139,23 +139,12 @@ def main():
reader_cfg
[
'worker_num'
],
reader_cfg
[
'worker_num'
],
batch_sampler
=
_eval_batch_sampler
,
batch_sampler
=
_eval_batch_sampler
,
return_list
=
True
)
return_list
=
True
)
metric
=
None
global
num_classes
if
reader_cfg
[
'metric'
]
==
'COCO'
:
num_classes
=
reader_cfg
[
'num_classes'
]
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
clsid2catid
=
{
v
:
k
for
k
,
v
in
dataset
.
catid2clsid
.
items
()}
anno_file
=
dataset
.
get_anno
()
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
elif
reader_cfg
[
'metric'
]
==
'VOC'
:
metric
=
VOCMetric
(
label_list
=
dataset
.
get_label_list
(),
class_num
=
reader_cfg
[
'num_classes'
],
map_type
=
reader_cfg
[
'map_type'
])
elif
reader_cfg
[
'metric'
]
==
'KeyPointTopDownCOCOEval'
:
anno_file
=
dataset
.
get_anno
()
metric
=
KeyPointTopDownCOCOEval
(
anno_file
,
len
(
dataset
),
17
,
'output_eval'
)
else
:
raise
ValueError
(
"metric currently only supports COCO and VOC."
)
global_config
[
'metric'
]
=
metric
global_config
[
'metric'
]
=
metric
else
:
else
:
eval_func
=
None
eval_func
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录