提交 0a3ecf60 编写于 作者: Z zhiboniu

add attribute strongbaseline

上级 675e60d5
...@@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny ...@@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
from ppcls.arch.backbone.model_zoo.strongbaseline_attr import StrongBaselineAttr
# help whl get all the models' api (class type) and components' api (func type) # help whl get all the models' api (class type) and components' api (func type)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url, get_weights_path_from_url
from ..legendary_models.resnet import ResNet50
MODEL_URLS = {"StrongBaselineAttr": "strongbaseline_attr_clas", }
__all__ = list(MODEL_URLS.keys())
class StrongBaselinePAR(nn.Layer):
def __init__(
self,
**config, ):
"""
A strong baseline for Pedestrian Attribute Recognition, see https://arxiv.org/abs/2107.03576
Args:
backbone (object): backbone instance
classifier (object): classifier instance
loss (object): loss instance
"""
super(StrongBaselinePAR, self).__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
self.backbone = eval(backbone_name)(**backbone_config)
def forward(self, x):
fc_feat = self.backbone(x)
output = F.sigmoid(fc_feat)
return output
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def load_pretrained(model, local_weight_path):
# local_weight_path = get_weights_path_from_url(model_url).replace(
# ".pdparams", "")
param_state_dict = paddle.load(local_weight_path + ".pdparams")
model_dict = model.state_dict()
model_dict_keys = list(model_dict.keys())
param_state_dict_keys = list(param_state_dict.keys())
# assert(len(model_dict_keys) == len(param_state_dict_keys)), "{} == {}".format(len(model_dict_keys), len(param_state_dict_keys))
for idx in range(len(model_dict.keys())):
model_key = model_dict_keys[idx]
param_key = param_state_dict_keys[idx]
if model_dict[model_key].shape == param_state_dict[param_key].shape:
model_dict[model_key] = param_state_dict[param_key]
else:
print("miss match idx: {} weights: {} vs {}; {} vs {}".format(
idx, model_key, param_key, model_dict[
model_key].shape, param_state_dict[param_key].shape))
model.set_dict(model_dict)
def StrongBaselineAttr(pretrained=True, use_ssld=False, **kwargs):
model = StrongBaselinePAR(**kwargs)
_load_pretrained(MODEL_URLS["StrongBaselineAttr"], model, None, None)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 5
eval_during_train: False
eval_interval: 1
epochs: 30
print_batch_step: 20
use_visualdl: False
# used for static mode and model export
image_shape: [3, 256, 192]
save_inference_dir: "./inference"
use_multilabel: True
metric_attr: True
# model architecture
Arch:
name: "StrongBaselineAttr"
Backbone:
name: "ResNet50"
class_num: 26
# loss function config for traing/eval process
Loss:
Train:
- BCELoss:
weight: 1.0
Eval:
- BCELoss:
weight: 1.0
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [12, 18, 24, 28]
values: [0.0001, 0.00001, 0.000001, 0.0000001]
regularizer:
name: 'L2'
coeff: 0.0005
clip_norm: 10
# data loader for train and eval
DataLoader:
Train:
dataset:
name: AttrDataset
image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
split: 'trainval'
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
# - ResizeImage:
# size: [192, 256]
- RandCropImage:
size: [192, 256]
scale: [0.9, 1.1]
ratio: [0.75, 0.75]
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: AttrDataset
image_root: "dataset/xingrenfenxi/data/"
cls_label_path: "dataset/xingrenfenxi/all_qiye.pkl"
split: 'test'
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: [192, 256]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- ATTRMetric:
...@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset ...@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.attr_dataset import AttrDataset
# sampler # sampler
......
...@@ -18,7 +18,7 @@ import time ...@@ -18,7 +18,7 @@ import time
import platform import platform
import paddle import paddle
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter, AttrMeter
from ppcls.utils import logger from ppcls.utils import logger
...@@ -32,6 +32,10 @@ def classification_eval(engine, epoch_id=0): ...@@ -32,6 +32,10 @@ def classification_eval(engine, epoch_id=0):
} }
print_batch_step = engine.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
if engine.eval_metric_func is not None and engine.config["Global"][
"metric_attr"]:
output_info["attr"] = AttrMeter(threshold=0.5)
metric_key = None metric_key = None
tic = time.time() tic = time.time()
accum_samples = 0 accum_samples = 0
...@@ -121,17 +125,22 @@ def classification_eval(engine, epoch_id=0): ...@@ -121,17 +125,22 @@ def classification_eval(engine, epoch_id=0):
output_info[key] = AverageMeter(key, '7.5f') output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], output_info[key].update(loss_dict[key].numpy()[0],
current_samples) current_samples)
# calc metric # calc metric
if engine.eval_metric_func is not None: if engine.eval_metric_func is not None:
metric_dict = engine.eval_metric_func(preds, labels) if engine.config["Global"]["metric_attr"]:
for key in metric_dict: metric_dict = engine.eval_metric_func(preds, labels)
if metric_key is None: metric_key = "attr"
metric_key = key output_info["attr"].update(metric_dict)
if key not in output_info: else:
output_info[key] = AverageMeter(key, '7.5f') metric_dict = engine.eval_metric_func(preds, labels)
for key in metric_dict:
output_info[key].update(metric_dict[key].numpy()[0], if metric_key is None:
current_samples) metric_key = key
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(metric_dict[key].numpy()[0],
current_samples)
time_info["batch_cost"].update(time.time() - tic) time_info["batch_cost"].update(time.time() - tic)
...@@ -144,10 +153,13 @@ def classification_eval(engine, epoch_id=0): ...@@ -144,10 +153,13 @@ def classification_eval(engine, epoch_id=0):
ips_msg = "ips: {:.5f} images/sec".format( ips_msg = "ips: {:.5f} images/sec".format(
batch_size / time_info["batch_cost"].avg) batch_size / time_info["batch_cost"].avg)
metric_msg = ", ".join([ if engine.config["Global"]["metric_attr"]:
"{}: {:.5f}".format(key, output_info[key].val) metric_msg = ""
for key in output_info else:
]) metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].val)
for key in output_info
])
logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format( logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
epoch_id, iter_id, epoch_id, iter_id,
len(engine.eval_dataloader), metric_msg, time_msg, ips_msg)) len(engine.eval_dataloader), metric_msg, time_msg, ips_msg))
...@@ -155,13 +167,28 @@ def classification_eval(engine, epoch_id=0): ...@@ -155,13 +167,28 @@ def classification_eval(engine, epoch_id=0):
tic = time.time() tic = time.time()
if engine.use_dali: if engine.use_dali:
engine.eval_dataloader.reset() engine.eval_dataloader.reset()
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) for key in output_info if engine.config["Global"]["metric_attr"]:
]) metric_msg = ", ".join([
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*output_info["attr"].res())
# do not try to save best eval.model ])
if engine.eval_metric_func is None: logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
return -1
# return 1st metric in the dict # do not try to save best eval.model
return output_info[metric_key].avg if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return output_info["attr"].res()[0]
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
# do not try to save best eval.model
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return output_info[metric_key].avg
...@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss ...@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss from .distillationloss import DistillationDKDLoss
from .multilabelloss import MultiLabelLoss from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss from .afdloss import AFDLoss
from .bceloss import BCELoss
from .deephashloss import DSHSDLoss from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss from .deephashloss import LCDSHLoss
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def ratio2weight(targets, ratio):
# print(targets)
pos_weights = targets * (1. - ratio)
neg_weights = (1. - targets) * ratio
weights = paddle.exp(neg_weights + pos_weights)
# for RAP dataloader, targets element may be 2, with or without smooth, some element must great than 1
weights = weights - weights * (targets > 1)
return weights
class BCELoss(nn.Layer):
"""BCE Loss.
Args:
"""
def __init__(self,
sample_weight=True,
size_sum=True,
smoothing=None,
weight=1.0):
super(BCELoss, self).__init__()
self.sample_weight = sample_weight
self.size_sum = size_sum
self.hyper = 0.8
self.smoothing = smoothing
def forward(self, logits, labels):
targets, ratio = labels
if self.smoothing is not None:
targets = (1 - self.smoothing) * targets + self.smoothing * (
1 - targets)
targets = paddle.cast(targets, 'float32')
loss_m = F.binary_cross_entropy_with_logits(
logits, targets, reduction='none')
targets_mask = paddle.cast(targets > 0.5, 'float32')
if self.sample_weight:
weight = ratio2weight(targets_mask, ratio[0])
weight = weight * (targets > -1)
loss_m = loss_m * weight
loss = loss_m.sum(1).mean() if self.size_sum else loss_m.sum()
return {"BCELoss": loss}
...@@ -20,6 +20,7 @@ from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk ...@@ -20,6 +20,7 @@ from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from .metrics import DistillationTopkAcc from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc from .metrics import GoogLeNetTopkAcc
from .metrics import HammingDistance, AccuracyScore from .metrics import HammingDistance, AccuracyScore
from .metrics import ATTRMetric
class CombinedMetrics(nn.Layer): class CombinedMetrics(nn.Layer):
......
...@@ -22,6 +22,8 @@ from sklearn.metrics import accuracy_score as accuracy_metric ...@@ -22,6 +22,8 @@ from sklearn.metrics import accuracy_score as accuracy_metric
from sklearn.metrics import multilabel_confusion_matrix from sklearn.metrics import multilabel_confusion_matrix
from sklearn.preprocessing import binarize from sklearn.preprocessing import binarize
from easydict import EasyDict
class TopkAcc(nn.Layer): class TopkAcc(nn.Layer):
def __init__(self, topk=(1, 5)): def __init__(self, topk=(1, 5)):
...@@ -308,3 +310,59 @@ class AccuracyScore(MutiLabelMetric): ...@@ -308,3 +310,59 @@ class AccuracyScore(MutiLabelMetric):
sum(tps) + sum(tns) + sum(fns) + sum(fps)) sum(tps) + sum(tns) + sum(fns) + sum(fps))
metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy) metric_dict["AccuracyScore"] = paddle.to_tensor(accuracy)
return metric_dict return metric_dict
def get_attr_metrics(gt_label, preds_probs, threshold):
"""
index: evaluated label index
"""
pred_label = (preds_probs > threshold).astype(int)
eps = 1e-20
result = EasyDict()
has_fuyi = gt_label == -1
pred_label[has_fuyi] = -1
###############################
# label metrics
# TP + FN
result.gt_pos = np.sum((gt_label == 1), axis=0).astype(float)
# TN + FP
result.gt_neg = np.sum((gt_label == 0), axis=0).astype(float)
# TP
result.true_pos = np.sum((gt_label == 1) * (pred_label == 1),
axis=0).astype(float)
# TN
result.true_neg = np.sum((gt_label == 0) * (pred_label == 0),
axis=0).astype(float)
# FP
result.false_pos = np.sum(((gt_label == 0) * (pred_label == 1)),
axis=0).astype(float)
# FN
result.false_neg = np.sum(((gt_label == 1) * (pred_label == 0)),
axis=0).astype(float)
################
# instance metrics
result.gt_pos_ins = np.sum((gt_label == 1), axis=1).astype(float)
result.true_pos_ins = np.sum((pred_label == 1), axis=1).astype(float)
# true positive
result.intersect_pos = np.sum((gt_label == 1) * (pred_label == 1),
axis=1).astype(float)
# IOU
result.union_pos = np.sum(((gt_label == 1) + (pred_label == 1)),
axis=1).astype(float)
return result
class ATTRMetric(nn.Layer):
def __init__(self, threshold=0.5):
super().__init__()
self.threshold = threshold
def __call__(self, output, target):
metric_dict = get_attr_metrics(target[0].numpy(),
output.numpy(), self.threshold)
return metric_dict
...@@ -61,3 +61,87 @@ class AverageMeter(object): ...@@ -61,3 +61,87 @@ class AverageMeter(object):
def value(self): def value(self):
return '{self.name}: {self.val:{self.fmt}}{self.postfix}'.format( return '{self.name}: {self.val:{self.fmt}}{self.postfix}'.format(
self=self) self=self)
class AttrMeter(object):
"""
Computes and stores the average and current value
Code was based on https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self, threshold=0.5):
self.threshold = threshold
self.reset()
def reset(self):
self.gt_pos = 0
self.gt_neg = 0
self.true_pos = 0
self.true_neg = 0
self.false_pos = 0
self.false_neg = 0
self.gt_pos_ins = []
self.true_pos_ins = []
self.intersect_pos = []
self.union_pos = []
def update(self, metric_dict):
self.gt_pos += metric_dict['gt_pos']
self.gt_neg += metric_dict['gt_neg']
self.true_pos += metric_dict['true_pos']
self.true_neg += metric_dict['true_neg']
self.false_pos += metric_dict['false_pos']
self.false_neg += metric_dict['false_neg']
self.gt_pos_ins += metric_dict['gt_pos_ins'].tolist()
self.true_pos_ins += metric_dict['true_pos_ins'].tolist()
self.intersect_pos += metric_dict['intersect_pos'].tolist()
self.union_pos += metric_dict['union_pos'].tolist()
def res(self):
import numpy as np
eps = 1e-20
label_pos_recall = 1.0 * self.true_pos / (
self.gt_pos + eps) # true positive
label_neg_recall = 1.0 * self.true_neg / (
self.gt_neg + eps) # true negative
# mean accuracy
label_ma = (label_pos_recall + label_neg_recall) / 2
label_pos_recall = np.mean(label_pos_recall)
label_neg_recall = np.mean(label_neg_recall)
label_prec = (self.true_pos / (self.true_pos + self.false_pos + eps))
label_acc = (self.true_pos /
(self.true_pos + self.false_pos + self.false_neg + eps))
label_f1 = np.mean(2 * label_prec * label_pos_recall /
(label_prec + label_pos_recall + eps))
ma = (np.mean(label_ma))
self.gt_pos_ins = np.array(self.gt_pos_ins)
self.true_pos_ins = np.array(self.true_pos_ins)
self.intersect_pos = np.array(self.intersect_pos)
self.union_pos = np.array(self.union_pos)
instance_acc = self.intersect_pos / (self.union_pos + eps)
instance_prec = self.intersect_pos / (self.true_pos_ins + eps)
instance_recall = self.intersect_pos / (self.gt_pos_ins + eps)
instance_f1 = 2 * instance_prec * instance_recall / (
instance_prec + instance_recall + eps)
instance_acc = np.mean(instance_acc)
instance_prec = np.mean(instance_prec)
instance_recall = np.mean(instance_recall)
instance_f1 = 2 * instance_prec * instance_recall / (
instance_prec + instance_recall + eps)
instance_acc = np.mean(instance_acc)
instance_prec = np.mean(instance_prec)
instance_recall = np.mean(instance_recall)
instance_f1 = np.mean(instance_f1)
res = [
ma, label_f1, label_pos_recall, label_neg_recall, instance_f1,
instance_acc, instance_prec, instance_recall
]
return res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册