提交 8a8eb34d 编写于 作者: D dongshuilong

add vehicle neck and fix bugs

上级 505c9309
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
......@@ -17,9 +17,7 @@ import importlib
import paddle.nn as nn
from . import backbone
from . import gears
from . import backbone, gears
from .backbone import *
from .gears import build_gear
from .utils import *
......@@ -40,10 +38,10 @@ class RecModel(nn.Layer):
super().__init__()
backbone_config = config["Backbone"]
backbone_name = backbone_config.pop("name")
self.backbone = getattr(backbone_name)(**backbone_config)
self.backbone = eval(backbone_name)(**backbone_config)
if "BackboneStopLayer" in config:
backbone_stop_layer = config["BackboneStopLayer"]
self.backbone.stop_layer(backbone_stop_layer)
backbone_stop_layer = config["BackboneStopLayer"]["name"]
self.backbone.stop_after(backbone_stop_layer)
if "Neck" in config:
self.neck = build_gear(config["Neck"])
......@@ -55,10 +53,11 @@ class RecModel(nn.Layer):
else:
self.head = None
def forward(self, x):
y = self.backbone(x)
def forward(self, x, label):
x = self.backbone(x)
if self.neck is not None:
y = self.neck(y)
x = self.neck(x)
y = x
if self.head is not None:
y = self.head(y)
return y
y = self.head(x, label)
return {"features": x, "logits": y}
......@@ -21,8 +21,8 @@ Arch:
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
BackboneStoplayer:
name: "adaptive_avg_pool2d_1"
BackboneStopLayer:
name: "adaptive_avg_pool2d_0"
Neck:
name: "VehicleNeck"
in_channels: 2048
......@@ -91,7 +91,7 @@ DataLoader:
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
batch_size: 128
num_instances: 2
drop_last: False
shuffle: True
......@@ -117,7 +117,7 @@ DataLoader:
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
batch_size: 128
drop_last: False
shuffle: False
loader:
......
......@@ -25,8 +25,8 @@ Arch:
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
BackboneStoplayer:
name: "adaptive_avg_pool2d_1"
BackboneStopLayer:
name: "adaptive_avg_pool2d_0"
Neck:
name: "VehicleNeck"
in_channels: 2048
......@@ -43,9 +43,9 @@ Loss:
Train:
- CELoss:
weight: 1.0
- TripletLossV2:
- SupConLoss:
weight: 1.0
margin: 0.5
views: 2
Eval:
- CELoss:
weight: 1.0
......@@ -70,7 +70,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images/"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_train.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/train_list_start0.txt"
transform_ops:
- ResizeImage:
size: 224
......@@ -105,7 +105,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test_query.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
transform_ops:
- ResizeImage:
size: 224
......@@ -128,7 +128,7 @@ DataLoader:
dataset:
name: "VeriWild"
image_root: "/work/dataset/VeRI-Wild/images"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/debug_test.txt"
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
transform_ops:
- ResizeImage:
size: 224
......
......@@ -11,6 +11,7 @@ from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
from .trihardloss import TriHardLoss
from .triplet import TripletLoss, TripletLossV2
from .supconloss import SupConLoss
class CombinedLoss(nn.Layer):
......
......@@ -32,6 +32,7 @@ class SupConLoss(nn.Layer):
Returns:
A loss scalar.
"""
features = features["features"]
if self.num_ids is None:
self.num_ids = int(features.shape[0] / self.views)
......@@ -104,4 +105,4 @@ class SupConLoss(nn.Layer):
loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
loss = paddle.mean(loss.reshape([anchor_count, batch_size]))
return loss
return {"SupConLoss": loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册