提交 7673d94a 编写于 作者: W Wei Shengyu

Merge branch 'PaddlePaddle:develop_reg' into develop_reg

#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. #copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
...@@ -17,11 +17,9 @@ import importlib ...@@ -17,11 +17,9 @@ import importlib
import paddle.nn as nn import paddle.nn as nn
from . import backbone from . import backbone, gears
from . import gears
from .backbone import * from .backbone import *
from .gears import * from .gears import build_gear
from .utils import * from .utils import *
__all__ = ["build_model", "RecModel"] __all__ = ["build_model", "RecModel"]
...@@ -38,34 +36,28 @@ def build_model(config): ...@@ -38,34 +36,28 @@ def build_model(config):
class RecModel(nn.Layer): class RecModel(nn.Layer):
def __init__(self, **config): def __init__(self, **config):
super().__init__() super().__init__()
backbone_config = config["Backbone"] backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name") backbone_name = backbone_config.pop("name")
self.backbone = eval(backbone_name)(**backbone_config) self.backbone = eval(backbone_name)(**backbone_config)
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer)
assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \ if "Neck" in config:
please specified a Stoplayer config" self.neck = build_gear(config["Neck"])
stop_layer_config = config["Stoplayer"]
self.backbone.stop_after(stop_layer_config["name"])
if stop_layer_config.get("embedding_size", 0) > 0:
self.neck = nn.Linear(stop_layer_config["output_dim"],
stop_layer_config["embedding_size"])
embedding_size = stop_layer_config["embedding_size"]
else: else:
self.neck = None self.neck = None
embedding_size = stop_layer_config["output_dim"]
assert "Head" in config, "Head should be specified in retrieval task \
please specify a Head config"
config["Head"]["embedding_size"] = embedding_size if "Head" in config:
self.head = build_head(config["Head"]) self.head = build_gear(config["Head"])
else:
self.head = None
def forward(self, x, label): def forward(self, x, label):
x = self.backbone(x) x = self.backbone(x)
if self.neck is not None: if self.neck is not None:
x = self.neck(x) x = self.neck(x)
y = self.head(x, label) y = x
if self.head is not None:
y = self.head(x, label)
return {"features": x, "logits": y} return {"features": x, "logits": y}
...@@ -16,12 +16,15 @@ from .arcmargin import ArcMargin ...@@ -16,12 +16,15 @@ from .arcmargin import ArcMargin
from .cosmargin import CosMargin from .cosmargin import CosMargin
from .circlemargin import CircleMargin from .circlemargin import CircleMargin
from .fc import FC from .fc import FC
from .vehicle_neck import VehicleNeck
__all__ = ['build_head'] __all__ = ['build_gear']
def build_head(config): def build_gear(config):
support_dict = ['ArcMargin', 'CosMargin', 'CircleMargin', 'FC'] support_dict = [
'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck'
]
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'head only support {}'.format(support_dict)) 'head only support {}'.format(support_dict))
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
class VehicleNeck(nn.Layer):
    """Neck layer: a single 2-D convolution followed by a flatten.

    All constructor arguments are forwarded unchanged to ``nn.Conv2D``.
    With the defaults (``kernel_size=1``, ``stride=1``, ``padding=0``) the
    convolution acts as a per-position linear projection from
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        padding_mode: Padding mode passed to ``nn.Conv2D``.
        weight_attr: Parameter attribute for the convolution weight.
        bias_attr: Parameter attribute for the convolution bias.
        data_format: Layout of the input passed to ``nn.Conv2D``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW'):
        super().__init__()
        self.conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
            weight_attr=weight_attr,
            # Fix: the original passed `weight_attr` here, so a caller's
            # `bias_attr` was ignored and the weight attribute was applied
            # to the bias instead.
            bias_attr=bias_attr,
            data_format=data_format)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Apply the convolution, then flatten the result.

        Args:
            x: Input feature map for ``nn.Conv2D``.
               NOTE(review): presumably the pooled backbone output -
               confirm against the configured backbone stop layer.

        Returns:
            The convolution output flattened by ``nn.Flatten`` with its
            default axes.
        """
        x = self.conv(x)
        x = self.flatten(x)
        return x
...@@ -19,11 +19,14 @@ Global: ...@@ -19,11 +19,14 @@ Global:
Arch: Arch:
name: "RecModel" name: "RecModel"
Backbone: Backbone:
name: "ResNet50" name: "ResNet50_last_stage_stride1"
Stoplayer: pretrained: True
name: "flatten_0" BackboneStopLayer:
output_dim: 2048 name: "adaptive_avg_pool2d_0"
embedding_size: 512 Neck:
name: "VehicleNeck"
in_channels: 2048
out_channels: 512
Head: Head:
name: "ArcMargin" name: "ArcMargin"
embedding_size: 512 embedding_size: 512
...@@ -88,7 +91,7 @@ DataLoader: ...@@ -88,7 +91,7 @@ DataLoader:
sampler: sampler:
name: DistributedRandomIdentitySampler name: DistributedRandomIdentitySampler
batch_size: 64 batch_size: 128
num_instances: 2 num_instances: 2
drop_last: False drop_last: False
shuffle: True shuffle: True
...@@ -114,7 +117,7 @@ DataLoader: ...@@ -114,7 +117,7 @@ DataLoader:
order: '' order: ''
sampler: sampler:
name: DistributedBatchSampler name: DistributedBatchSampler
batch_size: 64 batch_size: 128
drop_last: False drop_last: False
shuffle: False shuffle: False
loader: loader:
......
...@@ -19,16 +19,18 @@ Global: ...@@ -19,16 +19,18 @@ Global:
num_split: 1 num_split: 1
feature_normalize: True feature_normalize: True
# model architecture # model architecture
Arch: Arch:
name: "RecModel" name: "RecModel"
Backbone: Backbone:
name: "ResNet50" name: "ResNet50_last_stage_stride1"
Stoplayer: pretrained: True
name: "flatten_0" BackboneStopLayer:
output_dim: 2048 name: "adaptive_avg_pool2d_0"
embedding_size: 512 Neck:
name: "VehicleNeck"
in_channels: 2048
out_channels: 512
Head: Head:
name: "ArcMargin" name: "ArcMargin"
embedding_size: 512 embedding_size: 512
...@@ -41,9 +43,9 @@ Loss: ...@@ -41,9 +43,9 @@ Loss:
Train: Train:
- CELoss: - CELoss:
weight: 1.0 weight: 1.0
- TripletLossV2: - SupConLoss:
weight: 1.0 weight: 1.0
margin: 0.5 views: 2
Eval: Eval:
- CELoss: - CELoss:
weight: 1.0 weight: 1.0
...@@ -68,7 +70,7 @@ DataLoader: ...@@ -68,7 +70,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images/" image_root: "/work/dataset/VeRI-Wild/images/"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_train.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
...@@ -103,7 +105,7 @@ DataLoader: ...@@ -103,7 +105,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images" image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
...@@ -126,7 +128,7 @@ DataLoader: ...@@ -126,7 +128,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images" image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
......
...@@ -11,9 +11,11 @@ from .msmloss import MSMLoss ...@@ -11,9 +11,11 @@ from .msmloss import MSMLoss
from .npairsloss import NpairsLoss from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2 from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface from .pairwisecosface import PairwiseCosface
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
def __init__(self, config_list): def __init__(self, config_list):
super().__init__() super().__init__()
......
import paddle
from paddle import nn
class SupConLoss(nn.Layer):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.

    The incoming batch is regrouped into ``[num_ids, views, dim]``: every
    ``views`` consecutive samples are treated as views of one identity, and
    the label of the first sample in each group is used for the whole group.
    NOTE(review): this assumes the sampler emits samples of the same
    identity contiguously - confirm against the configured sampler.

    Args:
        views: Number of consecutive samples per identity in a batch.
        temperature: Temperature dividing the similarity logits.
        contrast_mode: ``'all'`` uses every view as an anchor, ``'one'``
            uses only the first view of each group.
        base_temperature: The loss is scaled by
            ``temperature / base_temperature``.
        normalize_feature: If True, L2-normalize features along the last
            axis before computing similarities.
    """

    def __init__(self,
                 views=16,
                 temperature=0.07,
                 contrast_mode='all',
                 base_temperature=0.07,
                 normalize_feature=True):
        super(SupConLoss, self).__init__()
        self.temperature = paddle.to_tensor(temperature)
        self.contrast_mode = contrast_mode
        self.base_temperature = paddle.to_tensor(base_temperature)
        # Kept as an attribute for backward compatibility; refreshed on every
        # forward call so it always reflects the current batch.
        self.num_ids = None
        self.views = views
        self.normalize_feature = normalize_feature

    def forward(self, features, labels, mask=None):
        """Compute loss for model.

        If both `labels` and `mask` are None, it degenerates to SimCLR
        unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: dict whose ``"features"`` entry is a tensor with
                ``num_ids * views`` rows (first axis = samples).
            labels: per-sample ground truth with ``num_ids * views``
                entries, or None.
            mask: contrastive mask of shape [num_ids, num_ids],
                mask_{i,j}=1 if group j has the same class as group i.
                Can be asymmetric. Only valid when `labels` is None.

        Returns:
            dict ``{"SupConLoss": loss}`` where ``loss`` is a scalar tensor.

        Raises:
            ValueError: if the batch size is not a multiple of ``views``,
                if both `labels` and `mask` are given, if the number of
                labels does not match the number of groups, or if
                ``contrast_mode`` is unknown.
        """
        features = features["features"]

        # Fix: derive the group count from the *current* batch. The original
        # cached it on the first call, which broke any later batch of a
        # different size (e.g. the last batch when drop_last is False).
        num_samples = features.shape[0]
        if num_samples % self.views != 0:
            raise ValueError(
                'Batch size {} is not a multiple of views {}'.format(
                    num_samples, self.views))
        self.num_ids = num_samples // self.views

        if self.normalize_feature:
            features = 1. * features / (paddle.expand_as(
                paddle.norm(
                    features, p=2, axis=-1, keepdim=True), features) + 1e-12)

        # Always rank 3 after this reshape: [num_ids, views, dim].
        features = features.reshape([self.num_ids, self.views, -1])
        # Fix: only regroup labels when they exist, so the label-free
        # (SimCLR) and explicit-mask paths are actually reachable.
        if labels is not None:
            labels = labels.reshape([self.num_ids, self.views])[:, 0]

        batch_size = features.shape[0]

        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = paddle.eye(batch_size, dtype='float32')
        elif labels is not None:
            labels = labels.reshape([-1, 1])
            if labels.shape[0] != batch_size:
                raise ValueError(
                    'Num of labels does not match num of features')
            mask = paddle.cast(
                paddle.equal(labels, paddle.t(labels)), 'float32')
        else:
            mask = paddle.cast(mask, 'float32')

        contrast_count = features.shape[1]
        # Stack the views along the batch axis: [views * num_ids, dim].
        contrast_feature = paddle.concat(
            paddle.unbind(
                features, axis=1), axis=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = paddle.divide(
            paddle.matmul(anchor_feature, paddle.t(contrast_feature)),
            self.temperature)
        # for numerical stability
        logits_max = paddle.max(anchor_dot_contrast, axis=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = paddle.tile(mask, [anchor_count, contrast_count])

        # Zero the diagonal entries of the tiled mask and of the exp-logits.
        logits_mask = 1 - paddle.eye(batch_size * anchor_count)
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = paddle.exp(logits) * logits_mask
        log_prob = logits - paddle.log(
            paddle.sum(exp_logits, axis=1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = paddle.sum((mask * log_prob),
                                       axis=1) / paddle.sum(mask, axis=1)

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = paddle.mean(loss.reshape([anchor_count, batch_size]))
        return {"SupConLoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册