提交 26d5b7d1 编写于 作者: Z zhiboniu

adapted dataset and loss

上级 aa8f4c16
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -14,7 +14,6 @@ Global: ...@@ -14,7 +14,6 @@ Global:
image_shape: [3, 256, 192] image_shape: [3, 256, 192]
save_inference_dir: "./inference" save_inference_dir: "./inference"
use_multilabel: True use_multilabel: True
metric_attr: True
# model architecture # model architecture
Arch: Arch:
...@@ -26,11 +25,15 @@ Arch: ...@@ -26,11 +25,15 @@ Arch:
# loss function config for training/eval process # loss function config for training/eval process
Loss: Loss:
Train: Train:
- BCELoss: - MultiLabelLoss:
weight: 1.0 weight: 1.0
weight_ratio: True
size_sum: True
Eval: Eval:
- BCELoss: - MultiLabelLoss:
weight: 1.0 weight: 1.0
weight_ratio: True
size_sum: True
Optimizer: Optimizer:
name: Adam name: Adam
...@@ -47,10 +50,10 @@ Optimizer: ...@@ -47,10 +50,10 @@ Optimizer:
DataLoader: DataLoader:
Train: Train:
dataset: dataset:
name: AttrDataset name: MultiLabelDataset
image_root: "dataset/xingrenfenxi/data/" image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl" cls_label_path: "dataset/xingrenfenxi/trainval.txt"
split: 'trainval' label_ratio: True
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
...@@ -80,10 +83,10 @@ DataLoader: ...@@ -80,10 +83,10 @@ DataLoader:
use_shared_memory: True use_shared_memory: True
Eval: Eval:
dataset: dataset:
name: AttrDataset name: MultiLabelDataset
image_root: "dataset/xingrenfenxi/data/" image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl" cls_label_path: "dataset/xingrenfenxi/test.txt"
split: 'test' label_ratio: True
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
......
...@@ -30,7 +30,6 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset ...@@ -30,7 +30,6 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.attr_dataset import AttrDataset
# sampler # sampler
......
...@@ -10,4 +10,3 @@ from ppcls.data.dataloader.mix_sampler import MixSampler ...@@ -10,4 +10,3 @@ from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.attr_dataset import AttrDataset
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import os
import pickle
from .common_dataset import CommonDataset
from ppcls.data.preprocess import transform
class AttrDataset(CommonDataset):
    """Attribute dataset whose annotations come from a pickled object.

    The annotation object is loaded from ``self._cls_path`` and is accessed
    through ``image_name``, ``label`` and ``partition`` (plus an optional
    ``label_idx.eval`` used to select a subset of label columns).
    """

    def _load_anno(self, seed=None, split='trainval'):
        """Fill ``self.images``, ``self.labels`` and ``self.label_ratio``.

        Args:
            seed: Not used by this method; kept so existing callers still work.
            split: Key into ``dataset_info.partition`` choosing which sample
                indices are loaded.

        Raises:
            AssertionError: If the annotation file or image root is missing,
                or if ``split`` is not a key of the partition mapping.
        """
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        anno_path = self._cls_path
        image_dir = self._img_root
        self.images = []
        self.labels = []

        # NOTE(review): pickle can execute arbitrary code while loading; the
        # annotation file must come from a trusted source.
        # Read-only mode is enough, and the context manager closes the handle
        # (the previous 'rb+' open was never closed and needed write access).
        with open(anno_path, 'rb') as anno_file:
            dataset_info = pickle.load(anno_file)

        img_id = dataset_info.image_name
        attr_label = dataset_info.label
        # Label value 2 is folded into 0 before any statistics are computed.
        attr_label[attr_label == 2] = 0

        if 'label_idx' in dataset_info.keys():
            # Keep only the label columns listed for evaluation.
            eval_attr_idx = dataset_info.label_idx.eval
            attr_label = attr_label[:, eval_attr_idx]

        assert split in dataset_info.partition.keys(
        ), f'split {split} is not exist'
        img_idx = dataset_info.partition[split]
        if isinstance(img_idx, list):
            img_idx = img_idx[0]  # default partition 0

        img_id = [img_id[i] for i in img_idx]
        label = attr_label[img_idx]
        # Per-column mean of the selected labels, returned with every sample.
        self.label_ratio = label.mean(0)
        print("label_ratio:", self.label_ratio)

        for img_i, label_i in zip(img_id, label):
            self.images.append(os.path.join(image_dir, img_i))
            self.labels.append(np.int64(label_i))

    def __getitem__(self, idx):
        """Return ``(img, [label, label_ratio])`` for sample ``idx``.

        On any exception the failure is logged and a randomly chosen sample
        is returned instead.
        """
        try:
            with open(self.images[idx], 'rb') as f:
                img = f.read()
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, [self.labels[idx], self.label_ratio])
        except Exception as ex:
            # `logger` was referenced here without being imported anywhere in
            # this file; import it locally so the handler itself cannot fail
            # with NameError.
            from ppcls.utils import logger
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)
...@@ -48,7 +48,7 @@ class CommonDataset(Dataset): ...@@ -48,7 +48,7 @@ class CommonDataset(Dataset):
image_root, image_root,
cls_label_path, cls_label_path,
transform_ops=None, transform_ops=None,
split='trainval'): label_ratio=False):
self._img_root = image_root self._img_root = image_root
self._cls_path = cls_label_path self._cls_path = cls_label_path
if transform_ops: if transform_ops:
...@@ -56,7 +56,10 @@ class CommonDataset(Dataset): ...@@ -56,7 +56,10 @@ class CommonDataset(Dataset):
self.images = [] self.images = []
self.labels = [] self.labels = []
self._load_anno(split=split) if label_ratio:
self.label_ratio = self._load_anno(label_ratio=label_ratio)
else:
self._load_anno()
def _load_anno(self): def _load_anno(self):
pass pass
......
...@@ -25,7 +25,7 @@ from .common_dataset import CommonDataset ...@@ -25,7 +25,7 @@ from .common_dataset import CommonDataset
class MultiLabelDataset(CommonDataset): class MultiLabelDataset(CommonDataset):
def _load_anno(self): def _load_anno(self, label_ratio=False):
assert os.path.exists(self._cls_path) assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root) assert os.path.exists(self._img_root)
self.images = [] self.images = []
...@@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset): ...@@ -41,6 +41,8 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels) self.labels.append(labels)
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
if label_ratio:
return np.array(self.labels).mean(0)
def __getitem__(self, idx): def __getitem__(self, idx):
try: try:
...@@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset): ...@@ -50,7 +52,10 @@ class MultiLabelDataset(CommonDataset):
img = transform(img, self._transform_ops) img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32") label = np.array(self.labels[idx]).astype("float32")
return (img, label) if self.label_ratio is not None:
return (img, [label, self.label_ratio])
else:
return (img, label)
except Exception as ex: except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}". logger.error("Exception occured when parse line: {} with msg: {}".
......
...@@ -32,8 +32,8 @@ def classification_eval(engine, epoch_id=0): ...@@ -32,8 +32,8 @@ def classification_eval(engine, epoch_id=0):
} }
print_batch_step = engine.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
if engine.eval_metric_func is not None and engine.config["Global"][ if engine.eval_metric_func is not None and engine.config["Arch"][
"metric_attr"]: "name"] == "StrongBaselineAttr":
output_info["attr"] = AttrMeter(threshold=0.5) output_info["attr"] = AttrMeter(threshold=0.5)
metric_key = None metric_key = None
...@@ -128,7 +128,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -128,7 +128,7 @@ def classification_eval(engine, epoch_id=0):
# calc metric # calc metric
if engine.eval_metric_func is not None: if engine.eval_metric_func is not None:
if engine.config["Global"]["metric_attr"]: if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_dict = engine.eval_metric_func(preds, labels) metric_dict = engine.eval_metric_func(preds, labels)
metric_key = "attr" metric_key = "attr"
output_info["attr"].update(metric_dict) output_info["attr"].update(metric_dict)
...@@ -153,7 +153,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -153,7 +153,7 @@ def classification_eval(engine, epoch_id=0):
ips_msg = "ips: {:.5f} images/sec".format( ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg) batch_size / time_info["batch_cost"].avg)
if engine.config["Global"]["metric_attr"]: if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_msg = "" metric_msg = ""
else: else:
metric_msg = ", ".join([ metric_msg = ", ".join([
...@@ -168,7 +168,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -168,7 +168,7 @@ def classification_eval(engine, epoch_id=0):
if engine.use_dali: if engine.use_dali:
engine.eval_dataloader.reset() engine.eval_dataloader.reset()
if engine.config["Global"]["metric_attr"]: if engine.config["Arch"]["name"] == "StrongBaselineAttr":
metric_msg = ", ".join([ metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}". "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*output_info["attr"].res()) format(*output_info["attr"].res())
......
...@@ -26,7 +26,6 @@ from .distillationloss import DistillationKLDivLoss ...@@ -26,7 +26,6 @@ from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss from .distillationloss import DistillationDKDLoss
from .multilabelloss import MultiLabelLoss from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss from .afdloss import AFDLoss
from .bceloss import BCELoss
from .deephashloss import DSHSDLoss from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss from .deephashloss import LCDSHLoss
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def ratio2weight(targets, ratio):
    """Turn per-label ratios into element-wise loss weights.

    The weight is ``exp((1 - targets) * ratio + targets * (1 - ratio))``.
    Entries where ``targets > 1`` end up with weight zero.

    Args:
        targets: Target tensor the weights are computed for.
        ratio: Ratio value(s) combined with ``targets``.

    Returns:
        The weight tensor produced by the expression above.
    """
    positive_part = targets * (1. - ratio)
    negative_part = (1. - targets) * ratio
    weights = paddle.exp(negative_part + positive_part)
    # for RAP dataloader, targets element may be 2, with or without smooth,
    # some element must be greater than 1; those entries are zeroed here
    ignored = weights * (targets > 1)
    return weights - ignored
class BCELoss(nn.Layer):
    """Binary cross-entropy on logits with optional ratio-based weighting.

    Args:
        sample_weight: When truthy, the element-wise loss is multiplied by
            the weights from ``ratio2weight`` (built from the binarized
            targets and ``ratio[0]``).
        size_sum: When truthy, the loss is summed over axis 1 and then
            averaged; otherwise every element is summed.
        smoothing: Optional label-smoothing factor; ``None`` disables it.
        weight: Accepted by the constructor but not used inside this class.
    """

    def __init__(self,
                 sample_weight=True,
                 size_sum=True,
                 smoothing=None,
                 weight=1.0):
        super().__init__()
        self.sample_weight = sample_weight
        self.size_sum = size_sum
        self.hyper = 0.8  # stored, but not read anywhere in this class
        self.smoothing = smoothing

    def forward(self, logits, labels):
        """Compute the loss and return it as ``{"BCELoss": loss}``.

        Args:
            logits: Raw model outputs.
            labels: A pair unpacked as ``(targets, ratio)``.
        """
        targets, ratio = labels

        eps = self.smoothing
        if eps is not None:
            targets = (1 - eps) * targets + eps * (1 - targets)
        targets = paddle.cast(targets, 'float32')

        per_element = F.binary_cross_entropy_with_logits(
            logits, targets, reduction='none')
        positive_mask = paddle.cast(targets > 0.5, 'float32')

        if self.sample_weight:
            scale = ratio2weight(positive_mask, ratio[0])
            # Zero the weight wherever the target is not greater than -1.
            scale = scale * (targets > -1)
            per_element = per_element * scale

        if self.size_sum:
            loss = per_element.sum(1).mean()
        else:
            loss = per_element.sum()
        return {"BCELoss": loss}
...@@ -19,12 +19,13 @@ class MultiLabelLoss(nn.Layer): ...@@ -19,12 +19,13 @@ class MultiLabelLoss(nn.Layer):
Multi-label loss Multi-label loss
""" """
def __init__(self, epsilon=None, weight_ratio=None): def __init__(self, epsilon=None, size_sum=False, weight_ratio=False):
super().__init__() super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1): if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None epsilon = None
self.epsilon = epsilon self.epsilon = epsilon
self.weight_ratio = weight_ratio self.weight_ratio = weight_ratio
self.size_sum = size_sum
def _labelsmoothing(self, target, class_num): def _labelsmoothing(self, target, class_num):
if target.ndim == 1 or target.shape[-1] != class_num: if target.ndim == 1 or target.shape[-1] != class_num:
...@@ -36,18 +37,21 @@ class MultiLabelLoss(nn.Layer): ...@@ -36,18 +37,21 @@ class MultiLabelLoss(nn.Layer):
return soft_target return soft_target
def _binary_crossentropy(self, input, target, class_num): def _binary_crossentropy(self, input, target, class_num):
if self.weight_ratio:
target, label_ratio = target
if self.epsilon is not None: if self.epsilon is not None:
target = self._labelsmoothing(target, class_num) target = self._labelsmoothing(target, class_num)
cost = F.binary_cross_entropy_with_logits(logit=input, label=target) cost = F.binary_cross_entropy_with_logits(
logit=input, label=target, reduction='none')
if self.weight_ratio is not None: if self.weight_ratio:
targets_mask = paddle.cast(target > 0.5, 'float32') targets_mask = paddle.cast(target > 0.5, 'float32')
weight = ratio2weight(targets_mask, weight = ratio2weight(targets_mask, paddle.to_tensor(label_ratio))
paddle.to_tensor(self.weight_ratio))
weight = weight * (target > -1) weight = weight * (target > -1)
cost = cost * weight cost = cost * weight
import pdb
pdb.set_trace() if self.size_sum:
cost = cost.sum(1).mean() if self.size_sum else cost.mean()
return cost return cost
......
...@@ -9,3 +9,4 @@ scipy ...@@ -9,3 +9,4 @@ scipy
scikit-learn==0.23.2 scikit-learn==0.23.2
gast==0.3.3 gast==0.3.3
faiss-cpu==1.7.1.post2 faiss-cpu==1.7.1.post2
easydict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册