提交 4afca16d 编写于 作者: W weishengyu

Merge branch 'develop_reg' of https://github.com/weisy11/PaddleClas into develop_reg

#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......@@ -17,11 +17,9 @@ import importlib
import paddle.nn as nn
from . import backbone
from . import gears
from . import backbone, gears
from .backbone import *
from .gears import *
from .gears import build_gear
from .utils import *
__all__ = ["build_model", "RecModel"]
......@@ -38,34 +36,28 @@ def build_model(config):
class RecModel(nn.Layer):
def __init__(self, **config):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
self.backbone = eval(backbone_name)(**backbone_config)
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer)
assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
please specified a Stoplayer config"
stop_layer_config = config["Stoplayer"]
self.backbone.stop_after(stop_layer_config["name"])
if stop_layer_config.get("embedding_size", 0) > 0:
self.neck = nn.Linear(stop_layer_config["output_dim"],
stop_layer_config["embedding_size"])
embedding_size = stop_layer_config["embedding_size"]
if "Neck" in config:
self.neck = build_gear(config["Neck"])
else:
self.neck = None
embedding_size = stop_layer_config["output_dim"]
assert "Head" in config, "Head should be specified in retrieval task \
please specify a Head config"
config["Head"]["embedding_size"] = embedding_size
self.head = build_head(config["Head"])
if "Head" in config:
self.head = build_gear(config["Head"])
else:
self.head = None
def forward(self, x, label):
x = self.backbone(x)
if self.neck is not None:
x = self.neck(x)
y = self.head(x, label)
y = x
if self.head is not None:
y = self.head(x, label)
return {"features": x, "logits": y}
......@@ -16,12 +16,15 @@ from .arcmargin import ArcMargin
from .cosmargin import CosMargin
from .circlemargin import CircleMargin
from .fc import FC
from .vehicle_neck import VehicleNeck
__all__ = ['build_head']
__all__ = ['build_gear']
def build_head(config):
support_dict = ['ArcMargin', 'CosMargin', 'CircleMargin', 'FC']
def build_gear(config):
support_dict = [
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck'
]
module_name = config.pop('name')
assert module_name in support_dict, Exception(
'head only support {}'.format(support_dict))
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
class VehicleNeck(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format='NCHW'):
super().__init__()
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
padding_mode=padding_mode,
weight_attr=weight_attr,
bias_attr=weight_attr,
data_format=data_format)
self.flatten = nn.Flatten()
def forward(self, x):
x = self.conv(x)
x = self.flatten(x)
return x
......@@ -19,11 +19,14 @@ Global:
Arch:
name: "RecModel"
Backbone:
name: "ResNet50"
Stoplayer:
name: "flatten_0"
output_dim: 2048
embedding_size: 512
name: "ResNet50_last_stage_stride1"
pretrained: True
BackboneStopLayer:
name: "adaptive_avg_pool2d_0"
Neck:
name: "VehicleNeck"
in_channels: 2048
out_channels: 512
Head:
name: "ArcMargin"
embedding_size: 512
......@@ -88,7 +91,7 @@ DataLoader:
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
batch_size: 128
num_instances: 2
drop_last: False
shuffle: True
......@@ -114,7 +117,7 @@ DataLoader:
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
batch_size: 128
drop_last: False
shuffle: False
loader:
......
......@@ -19,16 +19,18 @@ Global:
num_split: 1
feature_normalize: True
# model architecture
Arch:
name: "RecModel"
Backbone:
name: "ResNet50"
Stoplayer:
name: "flatten_0"
output_dim: 2048
embedding_size: 512
name: "ResNet50_last_stage_stride1"
pretrained: True
BackboneStopLayer:
name: "adaptive_avg_pool2d_0"
Neck:
name: "VehicleNeck"
in_channels: 2048
out_channels: 512
Head:
name: "ArcMargin"
embedding_size: 512
......@@ -41,9 +43,9 @@ Loss:
Train:
- CELoss:
weight: 1.0
- TripletLossV2:
- SupConLoss:
weight: 1.0
margin: 0.5
views: 2
Eval:
- CELoss:
weight: 1.0
......@@ -68,7 +70,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images/"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_train.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
transform_ops:
- ResizeImage:
size: 224
......@@ -103,7 +105,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
transform_ops:
- ResizeImage:
size: 224
......@@ -126,7 +128,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
transform_ops:
- ResizeImage:
size: 224
......
......@@ -11,9 +11,11 @@ from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
class CombinedLoss(nn.Layer):
def __init__(self, config_list):
super().__init__()
......
import paddle
from paddle import nn
class SupConLoss(nn.Layer):
"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
It also supports the unsupervised contrastive loss in SimCLR"""
def __init__(self,
views=16,
temperature=0.07,
contrast_mode='all',
base_temperature=0.07,
normalize_feature=True):
super(SupConLoss, self).__init__()
self.temperature = paddle.to_tensor(temperature)
self.contrast_mode = contrast_mode
self.base_temperature = paddle.to_tensor(base_temperature)
self.num_ids = None
self.views = views
self.normalize_feature = normalize_feature
def forward(self, features, labels, mask=None):
"""Compute loss for model. If both `labels` and `mask` are None,
it degenerates to SimCLR unsupervised loss:
https://arxiv.org/pdf/2002.05709.pdf
Args:
features: hidden vector of shape [bsz, n_views, ...].
labels: ground truth of shape [bsz].
mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
has the same class as sample i. Can be asymmetric.
Returns:
A loss scalar.
"""
features = features["features"]
if self.num_ids is None:
self.num_ids = int(features.shape[0] / self.views)
if self.normalize_feature:
features = 1. * features / (paddle.expand_as(
paddle.norm(
features, p=2, axis=-1, keepdim=True), features) + 1e-12)
features = features.reshape([self.num_ids, self.views, -1])
labels = labels.reshape([self.num_ids, self.views])[:, 0]
if len(features.shape) < 3:
raise ValueError('`features` needs to be [bsz, n_views, ...],'
'at least 3 dimensions are required')
if len(features.shape) > 3:
features = features.reshape(
[features.shape[0], features.shape[1], -1])
batch_size = features.shape[0]
if labels is not None and mask is not None:
raise ValueError('Cannot define both `labels` and `mask`')
elif labels is None and mask is None:
mask = paddle.eye(batch_size, dtype='float32')
elif labels is not None:
labels = labels.reshape([-1, 1])
if labels.shape[0] != batch_size:
raise ValueError(
'Num of labels does not match num of features')
mask = paddle.cast(
paddle.equal(labels, paddle.t(labels)), 'float32')
else:
mask = paddle.cast(mask, 'float32')
contrast_count = features.shape[1]
contrast_feature = paddle.concat(
paddle.unbind(
features, axis=1), axis=0)
if self.contrast_mode == 'one':
anchor_feature = features[:, 0]
anchor_count = 1
elif self.contrast_mode == 'all':
anchor_feature = contrast_feature
anchor_count = contrast_count
else:
raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
# compute logits
anchor_dot_contrast = paddle.divide(
paddle.matmul(anchor_feature, paddle.t(contrast_feature)),
self.temperature)
# for numerical stability
logits_max = paddle.max(anchor_dot_contrast, axis=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# tile mask
mask = paddle.tile(mask, [anchor_count, contrast_count])
logits_mask = 1 - paddle.eye(batch_size * anchor_count)
mask = mask * logits_mask
# compute log_prob
exp_logits = paddle.exp(logits) * logits_mask
log_prob = logits - paddle.log(
paddle.sum(exp_logits, axis=1, keepdim=True))
# compute mean of log-likelihood over positive
mean_log_prob_pos = paddle.sum((mask * log_prob),
axis=1) / paddle.sum(mask, axis=1)
# loss
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = paddle.mean(loss.reshape([anchor_count, batch_size]))
return {"SupConLoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册