提交 8a8eb34d 编写于 作者: D dongshuilong

add vehicle neck and fix bugs

上级 505c9309
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. #copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
#Licensed under the Apache License, Version 2.0 (the "License"); #Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License. #you may not use this file except in compliance with the License.
...@@ -17,9 +17,7 @@ import importlib ...@@ -17,9 +17,7 @@ import importlib
import paddle.nn as nn import paddle.nn as nn
from . import backbone from . import backbone, gears
from . import gears
from .backbone import * from .backbone import *
from .gears import build_gear from .gears import build_gear
from .utils import * from .utils import *
...@@ -40,10 +38,10 @@ class RecModel(nn.Layer): ...@@ -40,10 +38,10 @@ class RecModel(nn.Layer):
super().__init__() super().__init__()
backbone_config = config["Backbone"] backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name") backbone_name = backbone_config.pop("name")
self.backbone = getattr(backbone_name)(**backbone_config) self.backbone = eval(backbone_name)(**backbone_config)
if "BackboneStopLayer" in config: if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"] backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_layer(backbone_stop_layer) self.backbone.stop_after(backbone_stop_layer)
if "Neck" in config: if "Neck" in config:
self.neck = build_gear(config["Neck"]) self.neck = build_gear(config["Neck"])
...@@ -55,10 +53,11 @@ class RecModel(nn.Layer): ...@@ -55,10 +53,11 @@ class RecModel(nn.Layer):
else: else:
self.head = None self.head = None
def forward(self, x): def forward(self, x, label):
y = self.backbone(x) x = self.backbone(x)
if self.neck is not None: if self.neck is not None:
y = self.neck(y) x = self.neck(x)
y = x
if self.head is not None: if self.head is not None:
y = self.head(y) y = self.head(x, label)
return y return {"features": x, "logits": y}
...@@ -21,8 +21,8 @@ Arch: ...@@ -21,8 +21,8 @@ Arch:
Backbone: Backbone:
name: "ResNet50_last_stage_stride1" name: "ResNet50_last_stage_stride1"
pretrained: True pretrained: True
BackboneStoplayer: BackboneStopLayer:
name: "adaptive_avg_pool2d_1" name: "adaptive_avg_pool2d_0"
Neck: Neck:
name: "VehicleNeck" name: "VehicleNeck"
in_channels: 2048 in_channels: 2048
...@@ -91,7 +91,7 @@ DataLoader: ...@@ -91,7 +91,7 @@ DataLoader:
sampler: sampler:
name: DistributedRandomIdentitySampler name: DistributedRandomIdentitySampler
batch_size: 64 batch_size: 128
num_instances: 2 num_instances: 2
drop_last: False drop_last: False
shuffle: True shuffle: True
...@@ -117,7 +117,7 @@ DataLoader: ...@@ -117,7 +117,7 @@ DataLoader:
order: '' order: ''
sampler: sampler:
name: DistributedBatchSampler name: DistributedBatchSampler
batch_size: 64 batch_size: 128
drop_last: False drop_last: False
shuffle: False shuffle: False
loader: loader:
......
...@@ -25,8 +25,8 @@ Arch: ...@@ -25,8 +25,8 @@ Arch:
Backbone: Backbone:
name: "ResNet50_last_stage_stride1" name: "ResNet50_last_stage_stride1"
pretrained: True pretrained: True
BackboneStoplayer: BackboneStopLayer:
name: "adaptive_avg_pool2d_1" name: "adaptive_avg_pool2d_0"
Neck: Neck:
name: "VehicleNeck" name: "VehicleNeck"
in_channels: 2048 in_channels: 2048
...@@ -43,9 +43,9 @@ Loss: ...@@ -43,9 +43,9 @@ Loss:
Train: Train:
- CELoss: - CELoss:
weight: 1.0 weight: 1.0
- TripletLossV2: - SupConLoss:
weight: 1.0 weight: 1.0
margin: 0.5 views: 2
Eval: Eval:
- CELoss: - CELoss:
weight: 1.0 weight: 1.0
...@@ -70,7 +70,7 @@ DataLoader: ...@@ -70,7 +70,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images/" image_root: "/work/dataset/VeRI-Wild/images/"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_train.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
...@@ -105,7 +105,7 @@ DataLoader: ...@@ -105,7 +105,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images" image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
...@@ -128,7 +128,7 @@ DataLoader: ...@@ -128,7 +128,7 @@ DataLoader:
dataset: dataset:
name: "VeriWild" name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images" image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test.txt" cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
transform_ops: transform_ops:
- ResizeImage: - ResizeImage:
size: 224 size: 224
......
...@@ -11,6 +11,7 @@ from .msmloss import MSMLoss ...@@ -11,6 +11,7 @@ from .msmloss import MSMLoss
from .npairsloss import NpairsLoss from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2 from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
......
...@@ -32,6 +32,7 @@ class SupConLoss(nn.Layer): ...@@ -32,6 +32,7 @@ class SupConLoss(nn.Layer):
Returns: Returns:
A loss scalar. A loss scalar.
""" """
features = features["features"]
if self.num_ids is None: if self.num_ids is None:
self.num_ids = int(features.shape[0] / self.views) self.num_ids = int(features.shape[0] / self.views)
...@@ -104,4 +105,4 @@ class SupConLoss(nn.Layer): ...@@ -104,4 +105,4 @@ class SupConLoss(nn.Layer):
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = paddle.mean(loss.reshape([anchor_count, batch_size])) loss = paddle.mean(loss.reshape([anchor_count, batch_size]))
return loss return {"SupConLoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册