提交 0224c4fb 编写于 作者: K kbChen 提交者: qingqing01

add npairs loss (#1891)

* add npairs loss
* fix metric learning readme
上级 6fd31ee2
...@@ -76,7 +76,7 @@ python train_pair.py \ ...@@ -76,7 +76,7 @@ python train_pair.py \
``` ```
## Evaluation ## Evaluation
Evaluation is to evaluate the performance of a trained model. One can download [pretrained models](#supported-models) and set its path to ```path_to_pretrain_model```. Then Recall@Rank-1 can be obtained by running the following command: Evaluation is to evaluate the performance of a trained model. You should set model path to ```path_to_pretrain_model```. Then Recall@Rank-1 can be obtained by running the following command:
``` ```
python eval.py \ python eval.py \
--model=ResNet50 \ --model=ResNet50 \
...@@ -103,9 +103,11 @@ For comparation, many metric learning models with different neural networks and ...@@ -103,9 +103,11 @@ For comparation, many metric learning models with different neural networks and
|fine-tuned with triplet | 78.37% | 79.21% |fine-tuned with triplet | 78.37% | 79.21%
|fine-tuned with quadruplet | 78.10% | 79.59% |fine-tuned with quadruplet | 78.10% | 79.59%
|fine-tuned with eml | 79.32% | 80.11% |fine-tuned with eml | 79.32% | 80.11%
|fine-tuned with npairs | - | 79.81%
## Reference ## Reference
- ArcFace: Additive Angular Margin Loss for Deep Face Recognition [link](https://arxiv.org/abs/1801.07698) - ArcFace: Additive Angular Margin Loss for Deep Face Recognition [link](https://arxiv.org/abs/1801.07698)
- Margin Sample Mining Loss: A Deep Learning Based Method for Person Re-identification [link](https://arxiv.org/abs/1710.00478) - Margin Sample Mining Loss: A Deep Learning Based Method for Person Re-identification [link](https://arxiv.org/abs/1710.00478)
- Large Scale Strongly Supervised Ensemble Metric Learning, with Applications to Face Verification and Retrieval [link](https://arxiv.org/abs/1212.6094) - Large Scale Strongly Supervised Ensemble Metric Learning, with Applications to Face Verification and Retrieval [link](https://arxiv.org/abs/1212.6094)
- Improved Deep Metric Learning with Multi-class N-pair Loss Objective [link](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
...@@ -103,9 +103,11 @@ python infer.py \ ...@@ -103,9 +103,11 @@ python infer.py \
|使用triplet微调 | 78.37% | 79.21% |使用triplet微调 | 78.37% | 79.21%
|使用quadruplet微调 | 78.10% | 79.59% |使用quadruplet微调 | 78.10% | 79.59%
|使用eml微调 | 79.32% | 80.11% |使用eml微调 | 79.32% | 80.11%
|使用npairs微调 | - | 79.81%
## 引用 ## 引用
- ArcFace: Additive Angular Margin Loss for Deep Face Recognition [链接](https://arxiv.org/abs/1801.07698) - ArcFace: Additive Angular Margin Loss for Deep Face Recognition [链接](https://arxiv.org/abs/1801.07698)
- Margin Sample Mining Loss: A Deep Learning Based Method for Person Re-identification [链接](https://arxiv.org/abs/1710.00478) - Margin Sample Mining Loss: A Deep Learning Based Method for Person Re-identification [链接](https://arxiv.org/abs/1710.00478)
- Large Scale Strongly Supervised Ensemble Metric Learning, with Applications to Face Verification and Retrieval [链接](https://arxiv.org/abs/1212.6094) - Large Scale Strongly Supervised Ensemble Metric Learning, with Applications to Face Verification and Retrieval [链接](https://arxiv.org/abs/1212.6094)
- Improved Deep Metric Learning with Multi-class N-pair Loss Objective [链接](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
...@@ -6,3 +6,4 @@ from .arcmarginloss import ArcMarginLoss ...@@ -6,3 +6,4 @@ from .arcmarginloss import ArcMarginLoss
from .tripletloss import TripletLoss from .tripletloss import TripletLoss
from .quadrupletloss import QuadrupletLoss from .quadrupletloss import QuadrupletLoss
from .emlloss import EmlLoss from .emlloss import EmlLoss
from .npairsloss import NpairsLoss
...@@ -38,7 +38,7 @@ class EmlLoss(): ...@@ -38,7 +38,7 @@ class EmlLoss():
loss = loss1 + loss2 - bias loss = loss1 + loss2 - bias
return loss return loss
def loss(self, input): def loss(self, input, label=None):
samples_each_class = self.samples_each_class samples_each_class = self.samples_each_class
batch_size = self.cal_loss_batch_size batch_size = self.cal_loss_batch_size
#input = fluid.layers.l2_normalize(input, axis=1) #input = fluid.layers.l2_normalize(input, axis=1)
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from utility import get_gpu_num
class NpairsLoss():
def __init__(self,
train_batch_size = 160,
samples_each_class=2,
reg_lambda=0.01):
self.samples_each_class = samples_each_class
assert(self.samples_each_class == 2)
self.train_batch_size = train_batch_size
num_gpus = get_gpu_num()
assert(train_batch_size % num_gpus == 0)
self.cal_loss_batch_size = train_batch_size // num_gpus
assert(self.cal_loss_batch_size % samples_each_class == 0)
self.reg_lambda = reg_lambda
def loss(self, input, label=None):
reg_lambda = self.reg_lambda
samples_each_class = self.samples_each_class
batch_size = self.cal_loss_batch_size
num_class = batch_size // samples_each_class
fea_dim = input.shape[1]
input = fluid.layers.reshape(input, shape = [-1, fea_dim])
feature = fluid.layers.reshape(input, shape = [-1, samples_each_class, fea_dim])
label = fluid.layers.reshape(label, shape = [-1, samples_each_class])
label = fluid.layers.cast(label, dtype='float32')
if samples_each_class == 2:
anchor_fea, positive_fea = fluid.layers.split(feature, num_or_sections = 2, dim = 1)
anchor_lab, positive_lab = fluid.layers.split(label, num_or_sections = 2, dim = 1)
else:
anchor_fea, positive_fea = fluid.layers.split(feature, num_or_sections = [1, samples_each_class-1], dim = 1)
anchor_lab, positive_lab = fluid.layers.split(label, num_or_sections = [1, samples_each_class-1], dim = 1)
anchor_fea = fluid.layers.reshape(anchor_fea, shape = [-1, fea_dim])
positive_fea = fluid.layers.reshape(positive_fea, shape = [-1, fea_dim])
positive_fea_trans = fluid.layers.transpose(positive_fea, perm = [1, 0])
similarity_matrix = fluid.layers.mul(anchor_fea, positive_fea_trans)
anchor_lab = fluid.layers.expand(x=anchor_lab, expand_times=[1, batch_size-num_class])
positive_lab_tran = fluid.layers.transpose(positive_lab, perm = [1, 0])
positive_lab_tran = fluid.layers.expand(x=positive_lab_tran, expand_times=[num_class, 1])
label_remapped = fluid.layers.equal(anchor_lab, positive_lab_tran)
label_remapped = fluid.layers.cast(label_remapped, dtype='float32') / (samples_each_class-1)
label_remapped.stop_gradient = True
out = fluid.layers.softmax(input=similarity_matrix, use_cudnn=False)
xentloss = fluid.layers.cross_entropy(input=out, label=label_remapped, soft_label=True)
xentloss = fluid.layers.mean(x=xentloss)
reg = fluid.layers.reduce_mean(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1))
l2loss = 0.5 * reg_lambda * reg
return xentloss + l2loss
...@@ -19,7 +19,7 @@ class QuadrupletLoss(): ...@@ -19,7 +19,7 @@ class QuadrupletLoss():
self.cal_loss_batch_size = train_batch_size // num_gpus self.cal_loss_batch_size = train_batch_size // num_gpus
assert(self.cal_loss_batch_size % samples_each_class == 0) assert(self.cal_loss_batch_size % samples_each_class == 0)
def loss(self, input): def loss(self, input, label=None):
#input = fluid.layers.l2_normalize(input, axis=1) #input = fluid.layers.l2_normalize(input, axis=1)
input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0) input = fluid.layers.elementwise_div(input, input_norm, axis=0)
......
...@@ -8,7 +8,7 @@ class TripletLoss(): ...@@ -8,7 +8,7 @@ class TripletLoss():
def __init__(self, margin=0.1): def __init__(self, margin=0.1):
self.margin = margin self.margin = margin
def loss(self, input): def loss(self, input, label=None):
margin = self.margin margin = self.margin
fea_dim = input.shape[1] # number of channels fea_dim = input.shape[1] # number of channels
#input = fluid.layers.l2_normalize(input, axis=1) #input = fluid.layers.l2_normalize(input, axis=1)
......
...@@ -32,7 +32,7 @@ add_arg('image_shape', str, "3,224,224", "input image size") ...@@ -32,7 +32,7 @@ add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('class_dim', int, 11318 , "Class number.") add_arg('class_dim', int, 11318 , "Class number.")
add_arg('lr', float, 0.01, "set learning rate.") add_arg('lr', float, 0.01, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('lr_steps', str, "30000", "step of lr") add_arg('lr_steps', str, "15000,25000", "step of lr")
add_arg('total_iter_num', int, 30000, "total_iter_num") add_arg('total_iter_num', int, 30000, "total_iter_num")
add_arg('display_iter_step', int, 10, "display_iter_step.") add_arg('display_iter_step', int, 10, "display_iter_step.")
add_arg('test_iter_step', int, 1000, "test_iter_step.") add_arg('test_iter_step', int, 1000, "test_iter_step.")
......
...@@ -19,6 +19,7 @@ import reader ...@@ -19,6 +19,7 @@ import reader
from losses import TripletLoss from losses import TripletLoss
from losses import QuadrupletLoss from losses import QuadrupletLoss
from losses import EmlLoss from losses import EmlLoss
from losses import NpairsLoss
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import fmt_time, recall_topk, get_gpu_num from utility import fmt_time, recall_topk, get_gpu_num
...@@ -46,6 +47,7 @@ add_arg('model_save_dir', str, "output", "model save directory") ...@@ -46,6 +47,7 @@ add_arg('model_save_dir', str, "output", "model save directory")
add_arg('loss_name', str, "triplet", "Set the loss type to use.") add_arg('loss_name', str, "triplet", "Set the loss type to use.")
add_arg('samples_each_class', int, 2, "samples_each_class.") add_arg('samples_each_class', int, 2, "samples_each_class.")
add_arg('margin', float, 0.1, "margin.") add_arg('margin', float, 0.1, "margin.")
add_arg('npairs_reg_lambda', float, 0.01, "npairs reg lambda.")
# yapf: enable # yapf: enable
model_list = [m for m in dir(models) if "__" not in m] model_list = [m for m in dir(models) if "__" not in m]
...@@ -90,7 +92,13 @@ def net_config(image, label, model, args, is_train): ...@@ -90,7 +92,13 @@ def net_config(image, label, model, args, is_train):
train_batch_size = args.train_batch_size, train_batch_size = args.train_batch_size,
samples_each_class = args.samples_each_class, samples_each_class = args.samples_each_class,
) )
cost = metricloss.loss(out) elif args.loss_name == "npairs":
metricloss = NpairsLoss(
train_batch_size = args.train_batch_size,
samples_each_class = args.samples_each_class,
reg_lambda = args.npairs_reg_lambda,
)
cost = metricloss.loss(out, label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
return avg_cost, out return avg_cost, out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册