提交 eade1b72 编写于 作者: C cuicheng01

fix multilabel

上级 dab99e3e
...@@ -27,9 +27,8 @@ PreProcess: ...@@ -27,9 +27,8 @@ PreProcess:
- ToCHWImage: - ToCHWImage:
PostProcess: PostProcess:
main_indicator: MultiLabelTopk main_indicator: MultiLabelThreshOutput
MultiLabelTopk: MultiLabelThreshOutput:
topk: 5 threshold: 0.5
class_id_map_file: None
SavePreLabel: SavePreLabel:
save_dir: ./pre_label/ save_dir: ./pre_label/
...@@ -138,12 +138,29 @@ class Topk(object): ...@@ -138,12 +138,29 @@ class Topk(object):
return y return y
class MultiLabelTopk(Topk): class MultiLabelThreshOutput(object):
def __init__(self, topk=1, class_id_map_file=None): def __init__(self, threshold=0.5):
super().__init__() self.threshold = threshold
def __call__(self, x, file_names=None): def __call__(self, x, file_names=None):
return super().__call__(x, file_names, multilabel=True) y = []
for idx, probs in enumerate(x):
index = np.where(probs >= self.threshold)[0].astype("int32")
clas_id_list = []
score_list = []
for i in index:
clas_id_list.append(i.item())
score_list.append(probs[i].item())
result = {
"class_ids": clas_id_list,
"scores": np.around(
score_list, decimals=5).tolist(),
"label_names": []
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
class SavePreLabel(object): class SavePreLabel(object):
......
...@@ -107,7 +107,7 @@ Inference and prediction through predictive engines: ...@@ -107,7 +107,7 @@ Inference and prediction through predictive engines:
``` ```
python3 python/predict_cls.py \ python3 python/predict_cls.py \
-c configs/inference_multilabel_cls.yaml -c configs/inference_cls_multilabel.yaml
``` ```
Obtain an output silimar to the following: Obtain an output silimar to the following:
......
...@@ -100,7 +100,7 @@ cd ./deploy ...@@ -100,7 +100,7 @@ cd ./deploy
``` ```
python3 python/predict_cls.py \ python3 python/predict_cls.py \
-c configs/inference_multilabel_cls.yaml -c configs/inference_cls_multilabel.yaml
``` ```
得到类似下面的输出: 得到类似下面的输出:
......
...@@ -99,7 +99,7 @@ DataLoader: ...@@ -99,7 +99,7 @@ DataLoader:
use_shared_memory: True use_shared_memory: True
Infer: Infer:
infer_imgs: ./deploy/images/0517_2715693311.jpg infer_imgs: deploy/images/0517_2715693311.jpg
batch_size: 10 batch_size: 10
transforms: transforms:
- DecodeImage: - DecodeImage:
...@@ -116,9 +116,8 @@ Infer: ...@@ -116,9 +116,8 @@ Infer:
order: '' order: ''
- ToCHWImage: - ToCHWImage:
PostProcess: PostProcess:
name: MultiLabelTopk name: MultiLabelThreshOutput
topk: 5 threshold: 0.5
class_id_map_file: None
Metric: Metric:
Train: Train:
......
...@@ -16,8 +16,8 @@ import importlib ...@@ -16,8 +16,8 @@ import importlib
from . import topk, threshoutput from . import topk, threshoutput
from .topk import Topk, MultiLabelTopk from .topk import Topk
from .threshoutput import ThreshOutput from .threshoutput import ThreshOutput, MultiLabelThreshOutput
from .attr_rec import VehicleAttribute, PersonAttribute from .attr_rec import VehicleAttribute, PersonAttribute
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import paddle.nn.functional as F import paddle.nn.functional as F
...@@ -34,3 +35,28 @@ class ThreshOutput(object): ...@@ -34,3 +35,28 @@ class ThreshOutput(object):
result["file_name"] = file_names[idx] result["file_name"] = file_names[idx]
y.append(result) y.append(result)
return y return y
class MultiLabelThreshOutput(object):
def __init__(self, threshold=0.5):
self.threshold = threshold
def __call__(self, x, file_names=None):
y = []
x = F.sigmoid(x).numpy()
for idx, probs in enumerate(x):
index = np.where(probs >= self.threshold)[0].astype("int32")
clas_id_list = []
score_list = []
for i in index:
clas_id_list.append(i.item())
score_list.append(probs[i].item())
result = {
"class_ids": clas_id_list,
"scores": np.around(
score_list, decimals=5).tolist(),
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
...@@ -79,10 +79,3 @@ class Topk(object): ...@@ -79,10 +79,3 @@ class Topk(object):
y.append(result) y.append(result)
return y return y
class MultiLabelTopk(Topk):
def __init__(self, topk=1, class_id_map_file=None):
super().__init__()
def __call__(self, x, file_names=None):
return super().__call__(x, file_names, multilabel=True)
...@@ -501,7 +501,7 @@ class Engine(object): ...@@ -501,7 +501,7 @@ class Engine(object):
assert self.mode == "export" assert self.mode == "export"
use_multilabel = self.config["Global"].get( use_multilabel = self.config["Global"].get(
"use_multilabel", "use_multilabel",
False) and "ATTRMetric" in self.config["Metric"]["Eval"][0] False) or "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = ExportModel(self.config["Arch"], self.model, use_multilabel) model = ExportModel(self.config["Arch"], self.model, use_multilabel)
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model, load_dygraph_pretrain(model.base_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册