Commit bba0cf8f authored by: D dongshuilong

add CompCars train

Parent: d58fd3b7
...
@@ -21,7 +21,7 @@
from . import backbone
from . import head
from .backbone import *
from .head import *
from .utils import *

__all__ = ["build_model", "RecModel"]
...
@@ -43,20 +43,24 @@ class RecModel(nn.Layer):
        backbone_name = backbone_config.pop("name")
        self.backbone = eval(backbone_name)(**backbone_config)

        assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
            please specified a Stoplayer config"
        stop_layer_config = config["Stoplayer"]
        self.backbone.stop_after(stop_layer_config["name"])

        if stop_layer_config.get("embedding_size", 0) > 0:
-            self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
+            # self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
+            self.neck = nn.Conv2D(stop_layer_config["output_dim"],
+                                  stop_layer_config["embedding_size"])
            embedding_size = stop_layer_config["embedding_size"]
        else:
            self.neck = None
            embedding_size = stop_layer_config["output_dim"]

        assert "Head" in config, "Head should be specified in retrieval task \
            please specify a Head config"
        config["Head"]["embedding_size"] = embedding_size
        self.head = build_head(config["Head"])
...
@@ -65,4 +69,4 @@ class RecModel(nn.Layer):
        if self.neck is not None:
            x = self.neck(x)
        y = self.head(x, label)
-        return {"features":x, "logits":y}
+        return {"features": x, "logits": y}
...
@@ -16,35 +16,44 @@
import paddle
import paddle.nn as nn
import math


class ArcMargin(nn.Layer):
-    def __init__(self, embedding_size,
-                 class_num,
-                 margin=0.5,
-                 scale=80.0,
-                 easy_margin=False):
+    def __init__(self,
+                 embedding_size,
+                 class_num,
+                 margin=0.5,
+                 scale=80.0,
+                 easy_margin=False):
        super(ArcMargin, self).__init__()
        self.embedding_size = embedding_size
        self.class_num = class_num
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
-        weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal())
-        self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False)
+        weight_attr = paddle.ParamAttr(
+            initializer=paddle.nn.initializer.XavierNormal())
+        self.fc = nn.Linear(
+            self.embedding_size,
+            self.class_num,
+            weight_attr=weight_attr,
+            bias_attr=False)
    def forward(self, input, label):
-        input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
+        input_norm = paddle.sqrt(
+            paddle.sum(paddle.square(input), axis=1, keepdim=True))
        input = paddle.divide(input, input_norm)

        weight = self.fc.weight
-        weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
+        weight_norm = paddle.sqrt(
+            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
        weight = paddle.divide(weight, weight_norm)

        cos = paddle.matmul(input, weight)
        sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        phi = cos * cos_m - sin * sin_m

        th = math.cos(self.margin) * (-1)
        mm = math.sin(self.margin) * self.margin
...
@@ -55,11 +64,12 @@ class ArcMargin(nn.Layer):
        one_hot = paddle.nn.functional.one_hot(label, self.class_num)
        one_hot = paddle.squeeze(one_hot, axis=[1])
-        output = paddle.multiply(one_hot, phi) + paddle.multiply((1.0 - one_hot), cos)
+        output = paddle.multiply(one_hot, phi) + paddle.multiply(
+            (1.0 - one_hot), cos)
        output = output * self.scale
        return output

    def _paddle_where_more_than(self, target, limit, x, y):
-        mask = paddle.cast( x = (target > limit), dtype='float32')
+        mask = paddle.cast(x=(target > limit), dtype='float32')
        output = paddle.multiply(mask, x) + paddle.multiply((1.0 - mask), y)
        return output
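For reference, `phi` above is the ArcFace margined logit: with L2-normalized features and weights, `cos` is cos(theta) for the angle between a feature and its class weight, and the angle-addition identity gives cos(theta + m) = cos(theta)·cos(m) − sin(theta)·sin(m); `th` and `mm` feed the usual fallback when theta + m exceeds pi. A quick numeric check of the identity:

```python
import math

# cos(theta + m) computed the way ArcMargin.forward does it.
theta, m, s = 0.8, 0.5, 32.0
phi = math.cos(theta) * math.cos(m) - math.sin(theta) * math.sin(m)
assert abs(phi - math.cos(theta + m)) < 1e-12

# The margined logit is then scaled by s before softmax / cross-entropy.
print(s * phi)
```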
...
@@ -46,8 +46,8 @@ class CELoss(nn.Layer):
        if self.epsilon is not None:
            class_num = logits.shape[-1]
            label = self._labelsmoothing(label, class_num)
-            x = -F.log_softmax(x, axis=-1)
-            loss = paddle.sum(x * label, axis=-1)
+            x = -F.log_softmax(logits, axis=-1)
+            loss = paddle.sum(logits * label, axis=-1)
        else:
            if label.shape[-1] == logits.shape[-1]:
                label = F.softmax(label, axis=-1)
...
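This branch is label-smoothing cross entropy: the smoothed target mixes the one-hot label with a uniform distribution, and the loss weights `-log_softmax(logits)` by that target. A self-contained sketch of the conventional form (epsilon and shapes here are illustrative assumptions):

```python
import paddle
import paddle.nn.functional as F

epsilon, class_num = 0.1, 431               # illustrative values
logits = paddle.rand([4, class_num])
label = F.one_hot(paddle.randint(0, class_num, [4]), class_num)

# Smooth the one-hot target, then take the expected negative log-likelihood.
smoothed = label * (1.0 - epsilon) + epsilon / class_num
x = -F.log_softmax(logits, axis=-1)
loss = paddle.sum(x * smoothed, axis=-1).mean()  # conventional x * label form
print(float(loss))
```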
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 431
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 160
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: "./inference"

# model architecture
RecModel:
  Backbone: "ResNet50"
  Stoplayer: "adaptive_avg_pool2d_0"
  embedding_size: 512
  Head:
    name: "ArcMargin"
    embedding_size: 512
    class_num: 431
    margin: 0.15
    scale: 32

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
    - TripletLossV2:
        weight: 1.0
        margin: 0.5

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: MultiStepDecay
    learning_rate: 0.01
    decay_epochs: [30, 60, 70, 80, 90, 100, 120, 140]
    gamma: 0.5
    verbose: False
    last_epoch: -1
  regularizer:
    name: 'L2'
    coeff: 0.0005

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "CompCars"
      image_root: "/work/dataset/CompCars/image/"
      label_root: "/work/dataset/CompCars/label/"
      bbox_crop: True
      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
      transform_ops:
        - ResizeImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - AugMix:
            prob: 0.5
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0., 0., 0.]
    sampler:
      name: DistributedRandomIdentitySampler
      batch_size: 128
      num_instances: 2
      drop_last: False
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: False

  Eval:
    # TODO: modify to the latest trainer
    dataset:
      name: "CompCars"
      image_root: "/work/dataset/CompCars/image/"
      label_root: "/work/dataset/CompCars/label/"
      cls_label_path: "/work/dataset/CompCars/train_test_split/classification/test_label.txt"
      bbox_crop: True
      transform_ops:
        - ResizeImage:
            size: 224
        - NormalizeImage:
            scale: 0.00392157
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 6
      use_shared_memory: False

Infer:
  infer_imgs: "docs/images/whl/demo.jpg"
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"

Metric:
  Train:
    - Topk:
        k: [1, 5]
  Eval:
    - Topk:
        k: [1, 5]
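The Loss section lists two weighted terms (CELoss and TripletLossV2), and the sampler feeds them identity-balanced batches: batch_size 128 with num_instances 2 gives 64 identities per batch, so every anchor has at least one positive for triplet mining. A hedged sketch of how such a weighted loss list is typically reduced (the real combination lives in the trainer and is not shown in this commit):

```python
import paddle

# Each loss returns a dict like {"CELoss": tensor}; weights come from the
# YAML above. This combination function is an illustrative assumption.
def combine_losses(loss_dicts, weights):
    total = paddle.to_tensor(0.0)
    for d, w in zip(loss_dicts, weights):
        for v in d.values():
            total = total + w * v
    return total

ce = {"CELoss": paddle.to_tensor(1.2)}
tri = {"TripletLoss": paddle.to_tensor(0.4)}
print(float(combine_losses([ce, tri], [1.0, 1.0])))  # 1.6
```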
...
@@ -25,14 +25,17 @@ from . import samplers
from .dataset.imagenet_dataset import ImageNetDataset
from .dataset.multilabel_dataset import MultiLabelDataset
from .dataset.common_dataset import create_operators
+from .dataset.vehicle_dataset import CompCars, VeriWild

# sampler
from .samplers import DistributedRandomIdentitySampler
from .preprocess import transform


def build_dataloader(config, mode, device, seed=None):
-    assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
+    assert mode in ['Train', 'Eval', 'Test'
+                    ], "Mode should be Train, Eval or Test."
    # build dataset
    config_dataset = config[mode]['dataset']
    config_dataset = copy.deepcopy(config_dataset)
...
@@ -76,7 +79,7 @@ def build_dataloader(config, mode, device, seed=None):
        batch_ops = create_operators(batch_transform)
        batch_collate_fn = mix_collate_fn
    else:
        batch_collate_fn = None

    # build dataloader
    config_loader = config[mode]['loader']
...
@@ -105,9 +108,10 @@ def build_dataloader(config, mode, device, seed=None):
        collate_fn=batch_collate_fn)

    logger.info("build data_loader({}) success...".format(data_loader))
    return data_loader


'''
# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
...
...
@@ -14,17 +14,10 @@
from __future__ import print_function

-import io
-import tarfile
import numpy as np
-from PIL import Image
+# all use default backend
-import paddle
from paddle.io import Dataset
-import pickle
-import os
import cv2
-import random

from ppcls.data import preprocess
from ppcls.data.preprocess import transform
...
@@ -65,7 +58,7 @@ class CommonDataset(Dataset):
        self.labels = []
        self._load_anno()

    def _load_anno(self):
        pass

    def __getitem__(self, idx):
...
@@ -89,4 +82,3 @@ class CommonDataset(Dataset):
    @property
    def class_num(self):
        return len(set(self.labels))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import paddle
from paddle.io import Dataset
import os
import cv2

from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators


class CompCars(Dataset):
    def __init__(self,
                 image_root,
                 cls_label_path,
                 label_root=None,
                 transform_ops=None,
                 bbox_crop=False):
        self._img_root = image_root
        self._cls_path = cls_label_path
        self._label_root = label_root
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self._bbox_crop = bbox_crop
        self._dtype = paddle.get_default_dtype()
        self._load_anno()

    def _load_anno(self):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        if self._bbox_crop:
            assert os.path.exists(self._label_root)
        self.images = []
        self.labels = []
        self.bboxes = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
                l = l.strip().split()
                if not self._bbox_crop:
                    self.images.append(os.path.join(self._img_root, l[0]))
                    self.labels.append(int(l[1]))
                else:
                    label_path = os.path.join(self._label_root,
                                              l[0].split('.')[0] + '.txt')
                    assert os.path.exists(label_path)
                    bbox = open(label_path).readlines()[-1].strip().split()
                    bbox = [int(x) for x in bbox]
                    self.images.append(os.path.join(self._img_root, l[0]))
                    self.labels.append(int(l[1]))
                    self.bboxes.append(bbox)
                assert os.path.exists(self.images[-1])

    def __getitem__(self, idx):
        img = cv2.imread(self.images[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self._bbox_crop:
            bbox = self.bboxes[idx]
            img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
        if self._transform_ops:
            img = transform(img, self._transform_ops)
        img = img.transpose((2, 0, 1))
        return (img, self.labels[idx])

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))


class VeriWild(Dataset):
    def __init__(
            self,
            image_root,
            cls_label_path,
            transform_ops=None, ):
        self._img_root = image_root
        self._cls_path = cls_label_path
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self._dtype = paddle.get_default_dtype()
        self._load_anno()

    def _load_anno(self):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        self.images = []
        self.labels = []
        self.cameras = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
                l = l.strip().split()
                self.images.append(os.path.join(self._img_root, l[0]))
                self.labels.append(int(l[1]))
                self.cameras.append(int(l[2]))
                assert os.path.exists(self.images[-1])

    def __getitem__(self, idx):
        try:
            img = cv2.imread(self.images[idx])
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
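A usage sketch for the new dataset (paths mirror the config above and are placeholders). Note that `_transform_ops` is only assigned when `transform_ops` is truthy, so `__getitem__` would raise `AttributeError` without it; passing at least one op sidesteps that:

```python
from ppcls.data.dataset.vehicle_dataset import CompCars

dataset = CompCars(
    image_root="/work/dataset/CompCars/image/",
    cls_label_path="/work/dataset/CompCars/train_test_split/"
    "classification/train_label.txt",
    label_root="/work/dataset/CompCars/label/",
    transform_ops=[{"ResizeImage": {"size": 224}}],  # same dict form as YAML
    bbox_crop=True)  # crop each image to the bbox from its label file

img, label = dataset[0]  # HWC is transposed to CHW before returning
print(img.shape, label, dataset.class_num)
```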
...
@@ -29,11 +29,13 @@ from PIL import Image
from .autoaugment import ImageNetPolicy
from .functional import augmentations


class OperatorParamError(ValueError):
    """ OperatorParamError
    """
    pass


class DecodeImage(object):
    """ decode image """
...
@@ -235,7 +237,12 @@ class AugMix(object):
    """ Perform AugMix augmentation and compute mixture.
    """

-    def __init__(self, prob=0.5, aug_prob_coeff=0.1, mixture_width=3, mixture_depth=1, aug_severity=1):
+    def __init__(self,
+                 prob=0.5,
+                 aug_prob_coeff=0.1,
+                 mixture_width=3,
+                 mixture_depth=1,
+                 aug_severity=1):
        """
        Args:
            prob: Probability of taking augmix
...
@@ -264,14 +271,16 @@ class AugMix(object):
        ws = np.float32(
            np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
-        m = np.float32(np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
+        m = np.float32(
+            np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))

        # image = Image.fromarray(image)
        mix = np.zeros([image.shape[1], image.shape[0], 3])
        for i in range(self.mixture_width):
            image_aug = image.copy()
            image_aug = Image.fromarray(image_aug)
-            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(1, 4)
+            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(
+                1, 4)
            for _ in range(depth):
                op = np.random.choice(self.augmentations)
                image_aug = op(image_aug, self.aug_severity)
...
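For context on the hunk above: AugMix draws per-branch weights from a Dirichlet distribution and a blend factor from a Beta, mixes `mixture_width` augmented copies, then blends with the original image. The final blend is outside the visible hunk, so the sketch below follows the standard AugMix formulation under that assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
aug_prob_coeff, mixture_width = 0.1, 3
image = rng.random((224, 224, 3)).astype(np.float32)

ws = np.float32(rng.dirichlet([aug_prob_coeff] * mixture_width))
m = np.float32(rng.beta(aug_prob_coeff, aug_prob_coeff))

mix = np.zeros_like(image)
for w in ws:
    image_aug = image.copy()  # stand-in for a random chain of ops
    mix += w * image_aug

mixed = (1 - m) * image + m * mix  # assumed final blend, per the AugMix paper
print(mixed.shape, float(m))
```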
...
@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
-from ppcls.arch.loss_metrics import build_loss
+from ppcls.losses import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
...
...
@@ -5,17 +5,20 @@ from __future__ import print_function
import paddle
import paddle.nn as nn


class TripletLossV2(nn.Layer):
    """Triplet loss with hard positive/negative mining.
    Args:
        margin (float): margin for triplet.
    """

-    def __init__(self, margin=0.5):
+    def __init__(self, margin=0.5, normalize_feature=True):
        super(TripletLossV2, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
+        self.normalize_feature = normalize_feature

-    def forward(self, input, target, normalize_feature=True):
+    def forward(self, input, target):
        """
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
...
@@ -23,28 +26,25 @@ class TripletLossV2(nn.Layer):
        """
        inputs = input["features"]

-        if normalize_feature:
-            inputs = 1. * inputs / (paddle.expand_as(
-                paddle.norm(inputs, p=2, axis=-1, keepdim=True), inputs) +
-                                    1e-12)
+        if self.normalize_feature:
+            inputs = 1. * inputs / (paddle.expand_as(
+                paddle.norm(
+                    inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12)

        bs = inputs.shape[0]

        # compute distance
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
-        dist = paddle.addmm(input=dist,
-                            x=inputs,
-                            y=inputs.t(),
-                            alpha=-2.0,
-                            beta=1.0)
+        dist = paddle.addmm(
+            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        dist = paddle.clip(dist, min=1e-12).sqrt()

        # hard negative mining
-        is_pos = paddle.expand(target, (bs, bs)).equal(
-            paddle.expand(target, (bs, bs)).t())
-        is_neg = paddle.expand(target, (bs, bs)).not_equal(
-            paddle.expand(target, (bs, bs)).t())
+        is_pos = paddle.expand(target, (
+            bs, bs)).equal(paddle.expand(target, (bs, bs)).t())
+        is_neg = paddle.expand(target, (
+            bs, bs)).not_equal(paddle.expand(target, (bs, bs)).t())

        # `dist_ap` means distance(anchor, positive)
        ## both `dist_ap` and `relative_p_inds` with shape [N, 1]
...
@@ -56,14 +56,14 @@ class TripletLossV2(nn.Layer):
        dist_an, relative_n_inds = paddle.min(
            paddle.reshape(dist[is_neg], (bs, -1)), axis=1, keepdim=True)
        '''
-        dist_ap = paddle.max(paddle.reshape(paddle.masked_select(dist, is_pos),
-                                            (bs, -1)),
-                             axis=1,
-                             keepdim=True)
+        dist_ap = paddle.max(paddle.reshape(
+            paddle.masked_select(dist, is_pos), (bs, -1)),
+                             axis=1,
+                             keepdim=True)
        # `dist_an` means distance(anchor, negative)
        # both `dist_an` and `relative_n_inds` with shape [N, 1]
-        dist_an = paddle.min(paddle.reshape(paddle.masked_select(dist, is_neg),
-                                            (bs, -1)),
-                             axis=1,
-                             keepdim=True)
+        dist_an = paddle.min(paddle.reshape(
+            paddle.masked_select(dist, is_neg), (bs, -1)),
+                             axis=1,
+                             keepdim=True)
        # shape [N]
...
@@ -84,6 +84,7 @@ class TripletLoss(nn.Layer):
    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
...
@@ -101,15 +102,12 @@ class TripletLoss(nn.Layer):
        # Compute pairwise distance, replace by the official when merged
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
-        dist = paddle.addmm(input=dist,
-                            x=inputs,
-                            y=inputs.t(),
-                            alpha=-2.0,
-                            beta=1.0)
+        dist = paddle.addmm(
+            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        dist = paddle.clip(dist, min=1e-12).sqrt()

-        mask = paddle.equal(target.expand([bs, bs]),
-                            target.expand([bs, bs]).t())
+        mask = paddle.equal(
+            target.expand([bs, bs]), target.expand([bs, bs]).t())
        mask_numpy_idx = mask.numpy()
        dist_ap, dist_an = [], []
        for i in range(bs):
...
@@ -118,18 +116,16 @@ class TripletLoss(nn.Layer):
            # dist_ap.append(dist_ap_i)
            dist_ap.append(
                max([
-                    dist[i][j]
-                    if mask_numpy_idx[i][j] == True else float("-inf")
-                    for j in range(bs)
+                    dist[i][j] if mask_numpy_idx[i][j] == True else float(
+                        "-inf") for j in range(bs)
                ]).unsqueeze(0))
            # dist_an_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i] == False].min(), dtype='float64').unsqueeze(0)
            # dist_an_i.stop_gradient = False
            # dist_an.append(dist_an_i)
            dist_an.append(
                min([
-                    dist[i][k]
-                    if mask_numpy_idx[i][k] == False else float("inf")
-                    for k in range(bs)
+                    dist[i][k] if mask_numpy_idx[i][k] == False else float(
+                        "inf") for k in range(bs)
                ]).unsqueeze(0))

        dist_ap = paddle.concat(dist_ap, axis=0)
...
@@ -139,4 +135,3 @@ class TripletLoss(nn.Layer):
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}