Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
eade1b72
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eade1b72
编写于
8月 22, 2022
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multilabel
上级
dab99e3e
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
60 addition
and
26 deletion
+60
-26
deploy/configs/inference_cls_multilabel.yaml
deploy/configs/inference_cls_multilabel.yaml
+5
-6
deploy/python/postprocess.py
deploy/python/postprocess.py
+21
-4
docs/en/quick_start/quick_start_multilabel_classification_en.md
...n/quick_start/quick_start_multilabel_classification_en.md
+1
-1
docs/zh_CN/quick_start/quick_start_multilabel_classification.md
...h_CN/quick_start/quick_start_multilabel_classification.md
+1
-1
ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
...figs/quick_start/professional/MobileNetV1_multilabel.yaml
+3
-4
ppcls/data/postprocess/__init__.py
ppcls/data/postprocess/__init__.py
+2
-2
ppcls/data/postprocess/threshoutput.py
ppcls/data/postprocess/threshoutput.py
+26
-0
ppcls/data/postprocess/topk.py
ppcls/data/postprocess/topk.py
+0
-7
ppcls/engine/engine.py
ppcls/engine/engine.py
+1
-1
未找到文件。
deploy/configs/inference_cls_multilabel.yaml
浏览文件 @
eade1b72
...
...
@@ -25,11 +25,10 @@ PreProcess:
order
:
'
'
channel_num
:
3
-
ToCHWImage
:
PostProcess
:
main_indicator
:
MultiLabelTopk
MultiLabelTopk
:
topk
:
5
class_id_map_file
:
None
main_indicator
:
MultiLabelThreshOutput
MultiLabelThreshOutput
:
threshold
:
0.5
SavePreLabel
:
save_dir
:
./pre_label/
\ No newline at end of file
save_dir
:
./pre_label/
deploy/python/postprocess.py
浏览文件 @
eade1b72
...
...
@@ -138,12 +138,29 @@ class Topk(object):
return
y
class
MultiLabelT
opk
(
Topk
):
def
__init__
(
self
,
t
opk
=
1
,
class_id_map_file
=
None
):
s
uper
().
__init__
()
class
MultiLabelT
hreshOutput
(
object
):
def
__init__
(
self
,
t
hreshold
=
0.5
):
s
elf
.
threshold
=
threshold
def
__call__
(
self
,
x
,
file_names
=
None
):
return
super
().
__call__
(
x
,
file_names
,
multilabel
=
True
)
y
=
[]
for
idx
,
probs
in
enumerate
(
x
):
index
=
np
.
where
(
probs
>=
self
.
threshold
)[
0
].
astype
(
"int32"
)
clas_id_list
=
[]
score_list
=
[]
for
i
in
index
:
clas_id_list
.
append
(
i
.
item
())
score_list
.
append
(
probs
[
i
].
item
())
result
=
{
"class_ids"
:
clas_id_list
,
"scores"
:
np
.
around
(
score_list
,
decimals
=
5
).
tolist
(),
"label_names"
:
[]
}
if
file_names
is
not
None
:
result
[
"file_name"
]
=
file_names
[
idx
]
y
.
append
(
result
)
return
y
class
SavePreLabel
(
object
):
...
...
docs/en/quick_start/quick_start_multilabel_classification_en.md
浏览文件 @
eade1b72
...
...
@@ -107,7 +107,7 @@ Inference and prediction through predictive engines:
```
python3 python/predict_cls.py \
-c configs/inference_
multilabel_cls
.yaml
-c configs/inference_
cls_multilabel
.yaml
```
Obtain an output silimar to the following:
...
...
docs/zh_CN/quick_start/quick_start_multilabel_classification.md
浏览文件 @
eade1b72
...
...
@@ -100,7 +100,7 @@ cd ./deploy
```
python3 python/predict_cls.py \
-c configs/inference_
multilabel_cls
.yaml
-c configs/inference_
cls_multilabel
.yaml
```
得到类似下面的输出:
...
...
ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
浏览文件 @
eade1b72
...
...
@@ -99,7 +99,7 @@ DataLoader:
use_shared_memory
:
True
Infer
:
infer_imgs
:
./
deploy/images/0517_2715693311.jpg
infer_imgs
:
deploy/images/0517_2715693311.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
...
...
@@ -116,9 +116,8 @@ Infer:
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
MultiLabelTopk
topk
:
5
class_id_map_file
:
None
name
:
MultiLabelThreshOutput
threshold
:
0.5
Metric
:
Train
:
...
...
ppcls/data/postprocess/__init__.py
浏览文件 @
eade1b72
...
...
@@ -16,8 +16,8 @@ import importlib
from
.
import
topk
,
threshoutput
from
.topk
import
Topk
,
MultiLabelTopk
from
.threshoutput
import
ThreshOutput
from
.topk
import
Topk
from
.threshoutput
import
ThreshOutput
,
MultiLabelThreshOutput
from
.attr_rec
import
VehicleAttribute
,
PersonAttribute
...
...
ppcls/data/postprocess/threshoutput.py
浏览文件 @
eade1b72
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle.nn.functional
as
F
...
...
@@ -34,3 +35,28 @@ class ThreshOutput(object):
result
[
"file_name"
]
=
file_names
[
idx
]
y
.
append
(
result
)
return
y
class
MultiLabelThreshOutput
(
object
):
def
__init__
(
self
,
threshold
=
0.5
):
self
.
threshold
=
threshold
def
__call__
(
self
,
x
,
file_names
=
None
):
y
=
[]
x
=
F
.
sigmoid
(
x
).
numpy
()
for
idx
,
probs
in
enumerate
(
x
):
index
=
np
.
where
(
probs
>=
self
.
threshold
)[
0
].
astype
(
"int32"
)
clas_id_list
=
[]
score_list
=
[]
for
i
in
index
:
clas_id_list
.
append
(
i
.
item
())
score_list
.
append
(
probs
[
i
].
item
())
result
=
{
"class_ids"
:
clas_id_list
,
"scores"
:
np
.
around
(
score_list
,
decimals
=
5
).
tolist
(),
}
if
file_names
is
not
None
:
result
[
"file_name"
]
=
file_names
[
idx
]
y
.
append
(
result
)
return
y
ppcls/data/postprocess/topk.py
浏览文件 @
eade1b72
...
...
@@ -79,10 +79,3 @@ class Topk(object):
y
.
append
(
result
)
return
y
class
MultiLabelTopk
(
Topk
):
def
__init__
(
self
,
topk
=
1
,
class_id_map_file
=
None
):
super
().
__init__
()
def
__call__
(
self
,
x
,
file_names
=
None
):
return
super
().
__call__
(
x
,
file_names
,
multilabel
=
True
)
ppcls/engine/engine.py
浏览文件 @
eade1b72
...
...
@@ -501,7 +501,7 @@ class Engine(object):
assert
self
.
mode
==
"export"
use_multilabel
=
self
.
config
[
"Global"
].
get
(
"use_multilabel"
,
False
)
and
"ATTRMetric"
in
self
.
config
[
"Metric"
][
"Eval"
][
0
]
False
)
or
"ATTRMetric"
in
self
.
config
[
"Metric"
][
"Eval"
][
0
]
model
=
ExportModel
(
self
.
config
[
"Arch"
],
self
.
model
,
use_multilabel
)
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
load_dygraph_pretrain
(
model
.
base_model
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录