未验证 提交 3647da6d 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2383 from HydrogenSulfate/xbm_final

Add XBM model
......@@ -27,7 +27,7 @@ import cv2
import numpy as np
import importlib
from PIL import Image
from paddle.vision.transforms import ToTensor, Normalize
from paddle.vision.transforms import ToTensor, Normalize, Resize, CenterCrop
from paddleclas.deploy.python.det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
......
......@@ -97,9 +97,9 @@ The following table summarizes the accuracy metrics of the 3 configurations of t
| configuration file | recall@1(\%) | mAP(\%) | reference recall@1(\%) | reference mAP(\%) | pretrained model download address | inference model download address |
| ------------------ | ------------ | ------- | ---------------------- | ----------------- | --------------------------------- | -------------------------------- |
| baseline.yaml | 88.45 | 74.37 | 87.7 | 74.0 | [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/baseline_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/baseline_infer.tar) |
| softmax_triplet.yaml | 94.29 | 85.57 | 94.1 | 85.7 | [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_infer.tar) |
| softmax_triplet_with_center.yaml | 94.50 | 85.82 | 94.5 | 85.9 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_with_center_infer.tar) |
| baseline.yaml | 88.45 | 74.37 | 87.7 | 74.0 | [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/baseline_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/baseline_infer.tar) |
| softmax_triplet.yaml | 94.29 | 85.57 | 94.1 | 85.7 | [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_infer.tar) |
| softmax_triplet_with_center.yaml | 94.50 | 85.82 | 94.5 | 85.9 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams) | [ Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_infer.tar) |
Note: The above reference indicators are obtained by using the author's open source code to train on our equipment for many times. Due to different system environment, torch version, CUDA version and other reasons, there may be slight differences with the indicators provided by the author.
......@@ -171,14 +171,14 @@ Prepare the `*.pdparams` model parameter file for evaluation. You can use the tr
-o Global.pretrained_model="./output/RecModel/latest"
```
- Take the trained model as an example, download [softmax_triplet_with_center_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams) to `PaddleClas/ In the pretrained_models` folder, execute the following command to evaluate.
- Take the trained model as an example, download [softmax_triplet_with_center_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams) to `PaddleClas/ In the pretrained_models` folder, execute the following command to evaluate.
```shell
# download model
cd PaddleClas
mkdir pretrained_models
cd pretrained_models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams
cd..
# Evaluate
python3.7 tools/eval.py \
......@@ -260,7 +260,7 @@ You can convert the model file saved during training into an inference model and
- Or download and unzip the inference model we provide
```shell
cd PaddleClas/deploy
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_with_center_infer.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_infer.tar
tar xf softmax_triplet_with_center_infer.tar
cd ../
```
......
......@@ -97,9 +97,9 @@
| 配置文件 | recall@1(\%) | mAP(\%) | 参考recall@1(\%) | 参考mAP(\%) | 预训练模型下载地址 | inference模型下载地址 |
| -------------------------------- | ------------ | ------- | ---------------- | ----------- | --------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| baseline.yaml | 88.45 | 74.37 | 87.7 | 74.0 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/baseline_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/baseline_infer.tar) |
| softmax_triplet.yaml | 94.29 | 85.57 | 94.1 | 85.7 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_infer.tar) |
| softmax_triplet_with_center.yaml | 94.50 | 85.82 | 94.5 | 85.9 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_with_center_infer.tar) |
| baseline.yaml | 88.45 | 74.37 | 87.7 | 74.0 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/baseline_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/baseline_infer.tar) |
| softmax_triplet.yaml | 94.29 | 85.57 | 94.1 | 85.7 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_infer.tar) |
| softmax_triplet_with_center.yaml | 94.50 | 85.82 | 94.5 | 85.9 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_infer.tar) |
注:上述参考指标由使用作者开源的代码在我们的设备上训练多次得到,由于系统环境、torch版本、CUDA版本不同等原因,与作者提供的指标可能存在略微差异。
......@@ -171,14 +171,14 @@
-o Global.pretrained_model="./output/RecModel/latest"
```
- 以训练好的模型为例,下载 [softmax_triplet_with_center_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams)`PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。
- 以训练好的模型为例,下载 [softmax_triplet_with_center_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams)`PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。
```shell
# 下载模型
cd PaddleClas
mkdir pretrained_models
cd pretrained_models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/softmax_triplet_with_center_pretrained.pdparams
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_pretrained.pdparams
cd ..
# 评估
python3.7 tools/eval.py \
......@@ -259,7 +259,7 @@
- 或者下载并解压我们提供的 inference 模型
```shell
cd PaddleClas/deploy
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/softmax_triplet_with_center_infer.tar
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/reid/softmax_triplet_with_center_infer.tar
tar xf softmax_triplet_with_center_infer.tar
cd ../
```
......
# Metric Learning
----
## 目录
- [1. 简介](#1-简介)
- [2. 应用](#2-应用)
- [3. 算法](#3-算法)
- [3.1 Classification based](#31-classification-based)
- [3.2 Pairwise based](#32-pairwise-based)
- [4. 相关模型](#4-相关模型)
- [4.1 Cross-Batch Memory for Embedding Learning](#41-cross-batch-memory-for-embedding-learning)
* [1. 简介](#1)
* [2. 应用](#2)
* [3. 算法](#3)
* [3.1 Classification based](#3.1)
* [3.2 Pairwise based](#3.2)
<a name='1'></a>
## 1. 简介
在机器学习中,我们经常会遇到度量数据间距离的问题。一般来说,对于可度量的数据,我们可以直接通过欧式距离(Euclidean Distance),向量内积(Inner Product)或者是余弦相似度(Cosine Similarity)来进行计算。但对于非结构化数据来说,我们却很难进行这样的操作,如计算一段视频和一首音乐的匹配程度。由于数据格式的不同,我们难以直接进行上述的向量运算,但先验知识告诉我们 ED(laugh_video, laugh_music) < ED(laugh_video, blue_music), 如何去有效得表征这种”距离”关系呢? 这就是 Metric Learning 所要研究的课题。
在机器学习中,我们经常会遇到度量数据间距离的问题。一般来说,对于可度量的数据,我们可以直接通过欧式距离(Euclidean Distance),向量内积(Inner Product)或者是余弦相似度(Cosine Similarity)来进行计算。但对于非结构化数据来说,我们却很难进行这样的操作,如计算一段视频和一首音乐的匹配程度。由于数据格式的不同,我们难以直接进行上述的向量运算,但先验知识告诉我们 ED(laugh_video, laugh_music) < ED(laugh_video, blue_music), 如何去有效得表征这种”距离”关系呢? 这就是 Metric Learning 所要研究的课题。
Metric learning 全称是 Distance Metric Learning,它是通过机器学习的形式,根据训练数据,自动构造出一种基于特定任务的度量函数。Metric Learning 的目标是学习一个变换函数(线性非线性均可)L,将数据点从原始的向量空间映射到一个新的向量空间,在新的向量空间里相似点的距离更近,非相似点的距离更远,使得度量更符合任务的要求,如下图所示。 Deep Metric Learning,就是用深度神经网络来拟合这个变换函数。
Metric learning 全称是 Distance Metric Learning,它是通过机器学习的形式,根据训练数据,自动构造出一种基于特定任务的度量函数。Metric Learning 的目标是学习一个变换函数(线性非线性均可)L,将数据点从原始的向量空间映射到一个新的向量空间,在新的向量空间里相似点的距离更近,非相似点的距离更远,使得度量更符合任务的要求,如下图所示。 Deep Metric Learning,就是用深度神经网络来拟合这个变换函数。
![example](../../images/ml_illustration.jpg)
<a name='2'></a>
## 2. 应用
Metric Learning 技术在生活实际中应用广泛,如我们耳熟能详的人脸识别(Face Recognition)、行人重识别(Person ReID)、图像检索(Image Retrieval)、细粒度分类(Fine-grained classification)等。随着深度学习在工业实践中越来越广泛的应用,目前大家研究的方向基本都偏向于 Deep Metric Learning(DML).
一般来说, DML 包含三个部分: 特征提取网络来 map embedding, 一个采样策略来将一个 mini-batch 里的样本组合成很多个 sub-set, 最后 loss function 在每个 sub-set 上计算 loss. 如下图所示:
![image](../../images/ml_pipeline.jpg)
<a name='3'></a>
## 3. 算法
Metric Learning 主要有如下两种学习范式:
<a name='3.1'></a>
### 3.1 Classification based:
这是一类基于分类标签的 Metric Learning 方法。这类方法通过将每个样本分类到正确的类别中,来学习有效的特征表示,学习过程中需要每个样本的显式标签参与 Loss 计算。常见的算法有 [L2-Softmax](https://arxiv.org/abs/1703.09507), [Large-margin Softmax](https://arxiv.org/abs/1612.02295), [Angular Softmax](https://arxiv.org/pdf/1704.08063.pdf), [NormFace](https://arxiv.org/abs/1704.06369), [AM-Softmax](https://arxiv.org/abs/1801.05599), [CosFace](https://arxiv.org/abs/1801.09414), [ArcFace](https://arxiv.org/abs/1801.07698)等。
这类方法也被称作是 proxy-based, 因为其本质上优化的是样本和一堆 proxies 之间的相似度。
<a name='3.2'></a>
### 3.2 Pairwise based:
这是一类基于样本对的学习范式。他以样本对作为输入,通过直接学习样本对之间的相似度来得到有效的特征表示,常见的算法包括:[Contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), [Triplet loss](https://arxiv.org/abs/1503.03832), [Lifted-Structure loss](https://arxiv.org/abs/1511.06452), [N-pair loss](https://papers.nips.cc/paper/2016/file/6b180037abbebea991d8b1232f8a8ca9-Paper.pdf), [Multi-Similarity loss](https://arxiv.org/pdf/1904.06627.pdf)
### 3.1 Classification based
这是一类基于分类标签的 Metric Learning 方法。这类方法通过将每个样本分类到正确的类别中,来学习有效的特征表示,学习过程中需要每个样本的显式标签参与 Loss 计算。常见的算法有 [L2-Softmax](https://arxiv.org/abs/1703.09507), [Large-margin Softmax](https://arxiv.org/abs/1612.02295), [Angular Softmax](https://arxiv.org/pdf/1704.08063.pdf), [NormFace](https://arxiv.org/abs/1704.06369), [AM-Softmax](https://arxiv.org/abs/1801.05599), [CosFace](https://arxiv.org/abs/1801.09414), [ArcFace](https://arxiv.org/abs/1801.07698)等。
这类方法也被称作是 proxy-based, 因为其本质上优化的是样本和一堆 proxies 之间的相似度。
### 3.2 Pairwise based
这是一类基于样本对的学习范式。他以样本对作为输入,通过直接学习样本对之间的相似度来得到有效的特征表示,常见的算法包括:[Contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), [Triplet loss](https://arxiv.org/abs/1503.03832), [Lifted-Structure loss](https://arxiv.org/abs/1511.06452), [N-pair loss](https://papers.nips.cc/paper/2016/file/6b180037abbebea991d8b1232f8a8ca9-Paper.pdf), [Multi-Similarity loss](https://arxiv.org/pdf/1904.06627.pdf)
2020 年发表的[CircleLoss](https://arxiv.org/abs/2002.10857),从一个全新的视角统一了两种学习范式,让研究人员和从业者对 Metric Learning 问题有了更进一步的思考。
## 4. 相关模型
### 4.1 Cross-Batch Memory for Embedding Learning
作者提出了利用历史模型产生的特征来近似当前模型的特征,在一定程度上解耦了候选正负样本的数量与 mini-batch 大小的关系,让informative样本的挖掘更加高效,解决了以往的 metric learning 方法只能利用 mini-batch 内部的样本的局限。
具体信息可参考[xbm.md](../training/metric_learning/xbm.md)
简体中文 | English(TODO)
# Cross-Batch Memory for Embedding Learning
论文出处:[Cross-Batch Memory for Embedding Learning](https://arxiv.org/pdf/1912.06798.pdf)
## 目录
- [1. 原理介绍](#1-原理介绍)
- [2. 精度指标](#2-精度指标)
- [3. 数据准备](#3-数据准备)
- [4. 模型训练](#4-模型训练)
- [5. 模型评估与推理部署](#5-模型评估与推理部署)
- [5.1 模型评估](#51-模型评估)
- [5.2 模型推理](#52-模型推理)
- [5.2.1 推理模型准备](#521-推理模型准备)
- [5.2.2 基于 Python 预测引擎推理](#522-基于-python-预测引擎推理)
- [5.2.3 基于 C++ 预测引擎推理](#523-基于-c-预测引擎推理)
- [5.4 服务化部署](#54-服务化部署)
- [5.5 端侧部署](#55-端侧部署)
- [5.6 Paddle2ONNX 模型转换与预测](#56-paddle2onnx-模型转换与预测)
- [6. 总结](#6-总结)
- [6.1 方法总结与对比](#61-方法总结与对比)
- [6.2 使用建议/FAQ](#62-使用建议faq)
- [7. 参考资料](#7-参考资料)
## 1. 原理介绍
作者首先通过实验观察到 metric learning 方法学习到的特征在模型稳定后变化总是在一个可接受的范围内,因此提出了利用历史模型产生的特征来近似当前模型的特征,将这些历史特征按FIFO原则放入一个队列中,在一定程度上解耦了候选正负样本的数量与 mini-batch 大小的关系,且由于历史特征不记录梯度,开销相对较小,最终让informative样本的挖掘更加高效,解决了以往的 metric learning 方法只能利用 mini-batch 内部的样本的局限。
<img src="../../../images/algorithm_introduction/xbm.jpg" width="80%">
## 2. 精度指标
以下表格总结了复现的 Cross-Batch Memory for Embedding Learning 在 SOP 数据集上的精度指标,
| 配置文件 | recall@1(\%) | mAP(\%) | 参考recall@1(\%) | 参考mAP(\%) | 预训练模型下载地址 | inference模型下载地址 |
| -------- | ------------ | ------- | ---------------- | ----------- | ------------------ | --------------------- |
| xbm_resnet50.yaml | 81.0 | 61.9 | 80.6 | - | [xbm_resnet50_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/xbm/xbm_resnet50_pretrained.pdparams) | [xbm_resnet50_infer.tar](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/xbm/xbm_resnet50_infer.tar) |
接下来主要以`xbm_resnet50.yaml`配置和训练好的模型文件为例,展示在 SOP 数据集上进行训练、测试、推理的过程。
## 3. 数据准备
下载 [Stanford Online Products (SOP)](https://aistudio.baidu.com/aistudio/datasetdetail/5103) 数据集,解压到 `PaddleClas/dataset/` 下,将文件夹重命名为 `SOP` ,并组织成以下文件结构:
```shell
PaddleClas/dataset/
└── SOP/
├── coffee_maker_final/ # coffee_maker_final类别图片文件夹
├── kettle_final/ # kettle_final类别图片文件夹
├── ... # ...类别图片文件夹
├── ... # ...类别图片文件夹
├── train_list.txt # 训练集路径文件
└── test_list.txt # gallery(query)集路径文件
```
## 4. 模型训练
1. 执行以下命令开始训练
单卡训练:
```shell
python3.7 tools/train.py -c ./ppcls/configs/metric_learning/xbm_resnet50.yaml
```
注:单卡训练大约需要4个小时。
2. 查看训练日志和保存的模型参数文件
训练过程中会在屏幕上实时打印loss等指标信息,同时会保存日志文件`train.log`、模型参数文件`*.pdparams`、优化器参数文件`*.pdopt`等内容到`Global.output_dir`指定的文件夹下,默认在`PaddleClas/output/RecModel/`文件夹下。
## 5. 模型评估与推理部署
### 5.1 模型评估
准备用于评估的`*.pdparams`模型参数文件,可以使用训练好的模型,也可以使用[4. 模型训练](#4-模型训练)中保存的模型。
- 以训练过程中保存的`latest.pdparams`为例,执行如下命令即可进行评估。
```shell
python3.7 tools/eval.py \
-c ./ppcls/configs/metric_learning/xbm_resnet50.yaml \
-o Global.pretrained_model="./output/RecModel/latest"
```
- 以训练好的模型为例,下载 [xbm_resnet50_pretrained.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/xbm/xbm_resnet50_pretrained.pdparams)`PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。
```shell
# 下载模型
cd PaddleClas
mkdir pretrained_models
cd pretrained_models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/xbm/xbm_resnet50_pretrained.pdparams
cd ..
# 评估
python3.7 tools/eval.py \
-c ./ppcls/configs/metric_learning/xbm_resnet50.yaml \
-o Global.pretrained_model="pretrained_models/xbm_resnet50_pretrained"
```
注:`pretrained_model` 后填入的地址不需要加 `.pdparams` 后缀,在程序运行时会自动补上。
- 查看输出结果
```log
...
...
ppcls INFO: query feature calculation process: [0/237]
ppcls INFO: query feature calculation process: [20/237]
ppcls INFO: query feature calculation process: [40/237]
ppcls INFO: query feature calculation process: [60/237]
ppcls INFO: query feature calculation process: [80/237]
ppcls INFO: query feature calculation process: [100/237]
ppcls INFO: query feature calculation process: [120/237]
ppcls INFO: query feature calculation process: [140/237]
ppcls INFO: query feature calculation process: [160/237]
ppcls INFO: query feature calculation process: [180/237]
ppcls INFO: query feature calculation process: [200/237]
ppcls INFO: query feature calculation process: [220/237]
ppcls INFO: Build query done, all feat shape: [60502, 128], begin to eval..
ppcls INFO: re_ranking=False
ppcls INFO: [Eval][Epoch 0][Avg]recall1: 0.81083, recall5: 0.89263, mAP: 0.62097
```
默认评估日志保存在`PaddleClas/output/RecModel/eval.log`中,可以看到我们提供的 `xbm_pretrained.pdparams` 模型在 SOP 数据集上的评估指标为recall@1=0.81083,recall@5=0.89263,mAP=0.62097
- 使用re-ranking功能提升评估精度
可参考 [ReID #41-模型评估](../../algorithm_introduction/ReID.md) 文档的re-ranking使用方法。
**注**:目前re-ranking的计算复杂度较高,因此默认不启用。
### 5.2 模型推理
#### 5.2.1 推理模型准备
可以将训练过程中保存的模型文件转换成 inference 模型并推理,或者使用我们提供的转换好的 inference 模型直接进行推理
- 将训练过程中保存的模型文件转换成 inference 模型,同样以 `latest.pdparams` 为例,执行以下命令进行转换
```shell
python3.7 tools/export_model.py \
-c ./ppcls/configs/metric_learning/xbm_resnet50.yaml \
-o Global.pretrained_model="output/RecModel/latest" \
-o Global.save_inference_dir="./deploy/xbm_resnet50_infer"
```
- 或者下载并解压我们提供的 inference 模型
```shell
cd ./deploy
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/metric_learning/xbm/xbm_resnet50_infer.tar
tar -xf xbm_resnet50_infer.tar
cd ../
```
#### 5.2.2 基于 Python 预测引擎推理
1. 修改 `PaddleClas/deploy/configs/inference_rec.yaml`
-`infer_imgs:` 后的路径段改为 SOP 中 query 文件夹下的任意一张图片路径(下方配置使用的是`111085122871_0.jpg`图片的路径)
-`rec_inference_model_dir:` 后的字段改为解压出来的 xbm_resnet50_infer 文件夹路径
-`transform_ops:` 字段下的预处理配置改为 `xbm_resnet50.yaml``Eval.Query.dataset` 下的预处理配置
```yaml
Global:
infer_imgs: "../dataset/SOP/bicycle_final/111085122871_0.JPG"
rec_inference_model_dir: "./xbm_resnet50_infer"
batch_size: 1
use_gpu: False
enable_mkldnn: True
cpu_num_threads: 10
enable_benchmark: False
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
RecPreProcess:
transform_ops:
- Resize:
size: 256
- CenterCrop:
size: 224
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
RecPostProcess: null
```
2. 执行推理命令
```shell
cd ./deploy/
python3.7 python/predict_rec.py -c ./configs/inference_rec.yaml
```
3. 查看输出结果,实际结果为一个长度2048的向量,表示输入图片经过模型转换后得到的特征向量
```log
111085122871_0.JPG: [ 0.02560742 0.05221584 ... 0.11635944 -0.18817757
0.07170864]
```
推理时的输出向量储存在[predict_rec.py](../../../../deploy/python/predict_rec.py#L131)的 `result_dict` 变量中。
4. 批量预测,将配置文件中`infer_imgs:`后的路径改为为文件夹即可,如`../dataset/SOP/bicycle_final`,会预测并逐个输出query下所有图片的特征向量。
#### 5.2.3 基于 C++ 预测引擎推理
PaddleClas 提供了基于 C++ 预测引擎推理的示例,您可以参考[服务器端 C++ 预测](../../deployment/image_classification/cpp/linux.md)来完成相应的推理部署。如果您使用的是 Windows 平台,可以参考基于 Visual Studio 2019 Community CMake 编译指南完成相应的预测库编译和模型预测工作。
### 5.4 服务化部署
Paddle Serving 提供高性能、灵活易用的工业级在线推理服务。Paddle Serving 支持 RESTful、gRPC、bRPC 等多种协议,提供多种异构硬件和多种操作系统环境下推理解决方案。更多关于Paddle Serving 的介绍,可以参考Paddle Serving 代码仓库。
PaddleClas 提供了基于 Paddle Serving 来完成模型服务化部署的示例,您可以参考[模型服务化部署](../../deployment/PP-ShiTu/paddle_serving.md)来完成相应的部署工作。
### 5.5 端侧部署
Paddle Lite 是一个高性能、轻量级、灵活性强且易于扩展的深度学习推理框架,定位于支持包括移动端、嵌入式以及服务器端在内的多硬件平台。更多关于 Paddle Lite 的介绍,可以参考Paddle Lite 代码仓库。
PaddleClas 提供了基于 Paddle Lite 来完成模型端侧部署的示例,您可以参考[端侧部署](../../deployment/image_classification/paddle_lite.md)来完成相应的部署工作。
### 5.6 Paddle2ONNX 模型转换与预测
Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式。通过 ONNX 可以完成将 Paddle 模型到多种推理引擎的部署,包括TensorRT/OpenVINO/MNN/TNN/NCNN,以及其它对 ONNX 开源格式进行支持的推理引擎或硬件。更多关于 Paddle2ONNX 的介绍,可以参考Paddle2ONNX 代码仓库。
PaddleClas 提供了基于 Paddle2ONNX 来完成 inference 模型转换 ONNX 模型并作推理预测的示例,您可以参考[Paddle2ONNX 模型转换与预测](../../deployment/image_classification/paddle2onnx.md)来完成相应的部署工作。
### 6. 总结
#### 6.1 方法总结与对比
上述算法能快速地迁移至多数的检索模型中,能进一步提升检索模型的性能,
#### 6.2 使用建议/FAQ
SOP 数据集比较小,可以尝试训练多次取最高精度。
### 7. 参考资料
1. [Cross-Batch Memory for Embedding Learning](https://arxiv.org/pdf/1912.06798.pdf)
......@@ -72,6 +72,7 @@ from .model_zoo.convnext import ConvNeXt_tiny
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
from .variant_models.vgg_variant import VGG19Sigmoid
from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh
from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu
......
from paddle.nn import Conv2D
from paddle import nn
from ..legendary_models.resnet import ResNet50, MODEL_URLS, _load_pretrained
__all__ = ["ResNet50_last_stage_stride1"]
__all__ = ["ResNet50_last_stage_stride1", "ResNet50_adaptive_max_pool2d"]
def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
def replace_function(conv, pattern):
new_conv = Conv2D(
new_conv = nn.Conv2D(
in_channels=conv._in_channels,
out_channels=conv._out_channels,
kernel_size=conv._kernel_size,
......@@ -21,3 +21,15 @@ def ResNet50_last_stage_stride1(pretrained=False, use_ssld=False, **kwargs):
model.upgrade_sublayer(pattern, replace_function)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model
def ResNet50_adaptive_max_pool2d(pretrained=False, use_ssld=False, **kwargs):
def replace_function(pool, pattern):
new_pool = nn.AdaptiveMaxPool2D(output_size=1)
return new_pool
pattern = ["avg_pool"]
model = ResNet50(pretrained=False, use_ssld=use_ssld, **kwargs)
model.upgrade_sublayer(pattern, replace_function)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 35
iter_per_epoch: &iter_per_epoch 1000
print_batch_step: 20
use_visualdl: False
eval_mode: retrieval
retrieval_feature_from: features # 'backbone' or 'features'
re_ranking: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: RecModel
infer_output_key: features
infer_add_softmax: False
Backbone:
name: ResNet50_adaptive_max_pool2d
pretrained: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/others/resnet50-19c8e357_torch2paddle.pdparams
stem_act: null
BackboneStopLayer:
name: flatten
Neck:
name: FC
embedding_size: 2048
class_num: &feat_dim 128
weight_attr:
initializer:
name: KaimingNormal
fan_in: *feat_dim
negative_slope: 0.0
nonlinearity: leaky_relu
bias_attr:
initializer:
name: Constant
value: 0.0
# loss function config for traing/eval process
Loss:
Train:
- ContrastiveLoss_XBM:
weight: 1.0
xbm_size: 55000
xbm_weight: 1.0
start_iter: 1000
margin: 0.5
embedding_size: *feat_dim
epsilon: 1.0e-5
normalize_feature: True
feature_from: features
Eval:
- ContrastiveLoss:
weight: 1.0
margin: 0.5
embedding_size: *feat_dim
normalize_feature: True
epsilon: 1.0e-5
feature_from: features
Optimizer:
name: Adam
lr:
name: ReduceOnPlateau
learning_rate: 0.0001
mode: max
factor: 0.1
patience: 4
threshold: 0.001
threshold_mode: rel
cooldown: 2
min_lr: 0.000005
epsilon: 1e-8
by_epoch: True
regularizer:
name: L2
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: VeriWild
image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/train_list.txt
backend: pil
transform_ops:
- Resize:
size: 256
- RandomResizedCrop:
scale: [0.2, 1]
size: 224
- RandomHorizontalFlip:
prob: 0.5
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
max_iters: *iter_per_epoch
loader:
num_workers: 8
use_shared_memory: True
Eval:
Gallery:
dataset:
name: VeriWild
image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/test_list.txt
backend: pil
transform_ops:
- Resize:
size: 256
- CenterCrop:
size: 224
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Query:
dataset:
name: VeriWild
image_root: ./dataset/SOP
cls_label_path: ./dataset/SOP/test_list.txt
backend: pil
transform_ops:
- Resize:
size: 256
- CenterCrop:
size: 224
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
......@@ -20,7 +20,7 @@ import random
from collections import defaultdict
import numpy as np
from paddle.io import DistributedBatchSampler, Sampler
from paddle.io import DistributedBatchSampler
class DistributedRandomIdentitySampler(DistributedBatchSampler):
......@@ -31,14 +31,23 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
batch_size (int): batch size
num_instances (int): number of instance(s) within an class
drop_last (bool): whether to discard the data at the end
max_iters (int): max iteration(s). Default to None.
"""
def __init__(self, dataset, batch_size, num_instances, drop_last, **args):
def __init__(self,
dataset,
batch_size,
num_instances,
drop_last,
max_iters=None,
**args):
assert batch_size % num_instances == 0, \
f"batch_size({batch_size}) must be divisible by num_instances({num_instances}) when using DistributedRandomIdentitySampler"
self.dataset = dataset
self.batch_size = batch_size
self.num_instances = num_instances
self.drop_last = drop_last
self.max_iters = max_iters
self.num_pids_per_batch = self.batch_size // self.num_instances
self.index_dic = defaultdict(list)
for index, pid in enumerate(self.dataset.labels):
......@@ -53,8 +62,9 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
num = self.num_instances
self.length += num - num % self.num_instances
def __iter__(self):
def _prepare_batch(self):
batch_idxs_dict = defaultdict(list)
count = []
for pid in self.pids:
idxs = copy.deepcopy(self.index_dic[pid])
if len(idxs) < self.num_instances:
......@@ -67,27 +77,59 @@ class DistributedRandomIdentitySampler(DistributedBatchSampler):
if len(batch_idxs) == self.num_instances:
batch_idxs_dict[pid].append(batch_idxs)
batch_idxs = []
count = [len(batch_idxs_dict[pid]) for pid in self.pids]
count = np.array(count)
avai_pids = copy.deepcopy(self.pids)
final_idxs = []
while len(avai_pids) >= self.num_pids_per_batch:
selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
for pid in selected_pids:
batch_idxs = batch_idxs_dict[pid].pop(0)
final_idxs.extend(batch_idxs)
if len(batch_idxs_dict[pid]) == 0:
avai_pids.remove(pid)
_sample_iter = iter(final_idxs)
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
return batch_idxs_dict, avai_pids, count
def __iter__(self):
# prepare
batch_idxs_dict, avai_pids, count = self._prepare_batch()
# sample
if self.max_iters is not None:
for _ in range(self.max_iters):
final_idxs = []
if len(avai_pids) < self.num_pids_per_batch:
batch_idxs_dict, avai_pids, count = self._prepare_batch()
selected_pids = np.random.choice(avai_pids,
self.num_pids_per_batch,
False, count / count.sum())
for pid in selected_pids:
batch_idxs = batch_idxs_dict[pid].pop(0)
final_idxs.extend(batch_idxs)
pid_idx = avai_pids.index(pid)
if len(batch_idxs_dict[pid]) == 0:
avai_pids.pop(pid_idx)
count = np.delete(count, pid_idx)
else:
count[pid_idx] = len(batch_idxs_dict[pid])
yield final_idxs
else:
final_idxs = []
while len(avai_pids) >= self.num_pids_per_batch:
selected_pids = random.sample(avai_pids,
self.num_pids_per_batch)
for pid in selected_pids:
batch_idxs = batch_idxs_dict[pid].pop(0)
final_idxs.extend(batch_idxs)
if len(batch_idxs_dict[pid]) == 0:
avai_pids.remove(pid)
_sample_iter = iter(final_idxs)
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
yield batch_indices
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
yield batch_indices
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
yield batch_indices
def __len__(self):
if self.drop_last:
if self.max_iters is not None:
return self.max_iters
elif self.drop_last:
return self.length // self.batch_size
else:
return (self.length + self.batch_size - 1) // self.batch_size
......@@ -25,11 +25,12 @@ class ImageNetDataset(CommonDataset):
Args:
image_root (str): image root, path to `ILSVRC2012`
cls_label_path (str): path to annotation file `train_list.txt` or 'val_list.txt`
cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
transform_ops (list, optional): list of transform op(s). Defaults to None.
delimiter (str, optional): delimiter. Defaults to None.
relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
"""
def __init__(self,
image_root,
cls_label_path,
......
......@@ -19,8 +19,7 @@ import paddle
from paddle.io import Dataset
import os
import cv2
from ppcls.data import preprocess
from PIL import Image
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators
......@@ -89,15 +88,30 @@ class CompCars(Dataset):
class VeriWild(Dataset):
def __init__(self, image_root, cls_label_path, transform_ops=None):
"""Dataset for Vehicle and other similar data structure, such as VeRI-Wild, SOP, Inshop...
Args:
image_root (str): image root
cls_label_path (str): path to annotation file
transform_ops (List[Callable], optional): list of transform op(s). Defaults to None.
backend (str, optional): pil or cv2. Defaults to "cv2".
relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False.
"""
def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
backend="cv2",
relabel=False):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self.backend = backend
self._dtype = paddle.get_default_dtype()
self._load_anno()
self._load_anno(relabel)
def _load_anno(self):
def _load_anno(self, relabel):
assert os.path.exists(
self._cls_path), f"path {self._cls_path} does not exist."
assert os.path.exists(
......@@ -107,22 +121,40 @@ class VeriWild(Dataset):
self.cameras = []
with open(self._cls_path) as fd:
lines = fd.readlines()
if relabel:
label_set = set()
for line in lines:
line = line.strip().split()
label_set.add(np.int64(line[1]))
label_map = {
oldlabel: newlabel
for newlabel, oldlabel in enumerate(label_set)
}
for line in lines:
line = line.strip().split()
self.images.append(os.path.join(self._img_root, line[0]))
self.labels.append(np.int64(line[1]))
if relabel:
self.labels.append(label_map[np.int64(line[1])])
else:
self.labels.append(np.int64(line[1]))
if len(line) >= 3:
self.cameras.append(np.int64(line[2]))
assert os.path.exists(self.images[-1])
assert os.path.exists(self.images[-1]), \
f"path {self.images[-1]} does not exist."
self.has_camera = len(self.cameras) > 0
def __getitem__(self, idx):
try:
with open(self.images[idx], 'rb') as f:
img = f.read()
if self.backend == "cv2":
with open(self.images[idx], 'rb') as f:
img = f.read()
else:
img = Image.open(self.images[idx]).convert("RGB")
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
if self.backend == "cv2":
img = img.transpose((2, 0, 1))
if self.has_camera:
return (img, self.labels[idx], self.cameras[idx])
else:
......
......@@ -24,6 +24,7 @@ from ppcls.data.preprocess.ops.grid import GridMask
from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.data.preprocess.ops.operators import ResizeImage
from ppcls.data.preprocess.ops.operators import CropImage
from ppcls.data.preprocess.ops.operators import CenterCrop, Resize
from ppcls.data.preprocess.ops.operators import RandCropImage
from ppcls.data.preprocess.ops.operators import RandCropImageV2
from ppcls.data.preprocess.ops.operators import RandFlipImage
......@@ -34,6 +35,7 @@ from ppcls.data.preprocess.ops.operators import Pad
from ppcls.data.preprocess.ops.operators import ToTensor
from ppcls.data.preprocess.ops.operators import Normalize
from ppcls.data.preprocess.ops.operators import RandomHorizontalFlip
from ppcls.data.preprocess.ops.operators import RandomResizedCrop
from ppcls.data.preprocess.ops.operators import CropWithPadding
from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment
from ppcls.data.preprocess.ops.operators import ColorJitter
......@@ -41,7 +43,6 @@ from ppcls.data.preprocess.ops.operators import RandomCropImage
from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.ops.operators import RandomRot90
from ppcls.data.preprocess.ops.operators import BlurImage
from .ops.operators import format_data
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
......@@ -26,6 +26,7 @@ import cv2
import numpy as np
from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
from paddle.vision.transforms import ColorJitter as RawColorJitter
from paddle.vision.transforms import CenterCrop, Resize
from paddle.vision.transforms import RandomRotation as RawRandomRotation
from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop
from paddle.vision.transforms import functional as F
......@@ -509,7 +510,7 @@ class RandFlipImage(object):
assert flip_code in [-1, 0, 1
], "flip_code should be a value in [-1, 0, 1]"
self.flip_code = flip_code
@format_data
def __call__(self, img):
if random.randint(0, 1) == 1:
......
......@@ -23,7 +23,6 @@ from paddle import nn
import numpy as np
import random
from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
......@@ -43,6 +42,7 @@ from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
from ppcls.engine.train.utils import type_name
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
......@@ -347,6 +347,10 @@ class Engine(object):
self.max_iter = len(self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader)
if self.config["Global"].get("iter_per_epoch", None):
# set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
self.max_iter = self.config["Global"].get("iter_per_epoch")
self.max_iter = self.max_iter // self.update_freq * self.update_freq
for epoch_id in range(best_metric["epoch"] + 1,
......@@ -370,6 +374,13 @@ class Engine(object):
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
for i in range(len(self.lr_sch)):
if getattr(self.lr_sch[i], "by_epoch", False) and \
type_name(self.lr_sch[i]) == "ReduceOnPlateau":
self.lr_sch[i].step(acc)
if acc > best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
......
......@@ -15,15 +15,24 @@ from __future__ import absolute_import, division, print_function
import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler
def train_epoch(engine, epoch_id, print_batch_step):
tic = time.time()
for iter_id, batch in enumerate(engine.train_dataloader):
if iter_id >= engine.max_iter:
break
if not hasattr(engine, "train_dataloader_iter"):
engine.train_dataloader_iter = iter(engine.train_dataloader)
for iter_id in range(engine.max_iter):
# fetch data batch from dataloader
try:
batch = engine.train_dataloader_iter.next()
except Exception:
engine.train_dataloader_iter = iter(engine.train_dataloader)
batch = engine.train_dataloader_iter.next()
profiler.add_profiler_step(engine.config["profiler_options"])
if iter_id == 5:
for key in engine.time_info:
......@@ -37,7 +46,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
# image input
if engine.amp:
amp_level = engine.config['AMP'].get("level", "O1").upper()
amp_level = engine.config["AMP"].get("level", "O1").upper()
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
......@@ -89,7 +98,8 @@ def train_epoch(engine, epoch_id, print_batch_step):
# step lr(by epoch)
for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], "by_epoch", False):
if getattr(engine.lr_sch[i], "by_epoch", False) and \
type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
engine.lr_sch[i].step()
......
......@@ -53,14 +53,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
ips_msg = "ips: {:.5f} samples/s".format(
batch_size / trainer.time_info["batch_cost"].avg)
eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
) * trainer.max_iter - iter_id
) * trainer.time_info["batch_cost"].avg
eta_sec = (
(trainer.config["Global"]["epochs"] - epoch_id + 1
) * trainer.max_iter - iter_id) * trainer.time_info["batch_cost"].avg
eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
epoch_id, trainer.config["Global"]["epochs"], iter_id,
trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg,
eta_msg))
trainer.max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
for i, lr in enumerate(trainer.lr_sch):
logger.scaler(
......@@ -74,3 +73,8 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
value=trainer.output_info[key].avg,
step=trainer.global_step,
writer=trainer.vdl_writer)
def type_name(object: object) -> str:
"""get class name of an object"""
return object.__class__.__name__
......@@ -7,6 +7,8 @@ from ppcls.utils import logger
from .celoss import CELoss, MixCELoss
from .googlenetloss import GoogLeNetLoss
from .centerloss import CenterLoss
from .contrasiveloss import ContrastiveLoss
from .contrasiveloss import ContrastiveLoss_XBM
from .emlloss import EmlLoss
from .msmloss import MSMLoss
from .npairsloss import NpairsLoss
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Dict
import paddle
import paddle.nn as nn
from ppcls.loss.xbm import CrossBatchMemory
class ContrastiveLoss(nn.Layer):
"""ContrastiveLoss
Args:
margin (float): margin
embedding_size (int): number of embedding's dimension
normalize_feature (bool, optional): whether to normalize embedding. Defaults to True.
epsilon (float, optional): epsilon. Defaults to 1e-5.
feature_from (str, optional): which key embedding from input dict. Defaults to "features".
"""
def __init__(self,
margin: float,
embedding_size: int,
normalize_feature=True,
epsilon: float=1e-5,
feature_from: str="features"):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.embedding_size = embedding_size
self.normalize_feature = normalize_feature
self.epsilon = epsilon
self.feature_from = feature_from
def forward(self, input: Dict[str, paddle.Tensor],
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
feats = input[self.feature_from]
labels = target
# normalize along feature dim
if self.normalize_feature:
feats = nn.functional.normalize(feats, p=2, axis=1)
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
loss = self._compute_loss(feats, target, feats, target)
return {'ContrastiveLoss': loss}
def _compute_loss(self,
inputs_q: paddle.Tensor,
targets_q: paddle.Tensor,
inputs_k: paddle.Tensor,
targets_k: paddle.Tensor) -> paddle.Tensor:
batch_size = inputs_q.shape[0]
# Compute similarity matrix
sim_mat = paddle.matmul(inputs_q, inputs_k.t())
loss = []
for i in range(batch_size):
pos_pair_ = paddle.masked_select(sim_mat[i],
targets_q[i] == targets_k)
pos_pair_ = paddle.masked_select(pos_pair_,
pos_pair_ < 1 - self.epsilon)
neg_pair_ = paddle.masked_select(sim_mat[i],
targets_q[i] != targets_k)
neg_pair = paddle.masked_select(neg_pair_, neg_pair_ > self.margin)
pos_loss = paddle.sum(-pos_pair_ + 1)
if len(neg_pair) > 0:
neg_loss = paddle.sum(neg_pair)
else:
neg_loss = 0
loss.append(pos_loss + neg_loss)
loss = sum(loss) / batch_size
return loss
class ContrastiveLoss_XBM(ContrastiveLoss):
"""ContrastiveLoss with CrossBatchMemory
Args:
xbm_size (int): size of memory bank
xbm_weight (int): weight of CrossBatchMemory's loss
start_iter (int): store embeddings after start_iter
margin (float): margin
embedding_size (int): number of embedding's dimension
epsilon (float, optional): epsilon. Defaults to 1e-5.
normalize_feature (bool, optional): whether to normalize embedding. Defaults to True.
feature_from (str, optional): which key embedding from input dict. Defaults to "features".
"""
def __init__(self,
xbm_size: int,
xbm_weight: int,
start_iter: int,
margin: float,
embedding_size: int,
epsilon: float=1e-5,
normalize_feature=True,
feature_from: str="features"):
super(ContrastiveLoss_XBM, self).__init__(
margin, embedding_size, normalize_feature, epsilon, feature_from)
self.xbm = CrossBatchMemory(xbm_size, embedding_size)
self.xbm_weight = xbm_weight
self.start_iter = start_iter
self.iter = 0
def __call__(self, input: Dict[str, paddle.Tensor],
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
feats = input[self.feature_from]
labels = target
# normalize along feature dim
if self.normalize_feature:
feats = nn.functional.normalize(feats, p=2, axis=1)
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
loss = self._compute_loss(feats, labels, feats, labels)
# compute contrastive loss from memory bank
self.iter += 1
if self.iter > self.start_iter:
self.xbm.enqueue_dequeue(feats.detach(), labels.detach())
xbm_feats, xbm_labels = self.xbm.get()
xbm_loss = self._compute_loss(feats, labels, xbm_feats, xbm_labels)
loss = loss + self.xbm_weight * xbm_loss
return {'ContrastiveLoss_XBM': loss}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple
import paddle
class CrossBatchMemory(object):
"""
CrossBatchMemory Implementation. refer to "Cross-Batch Memory for Embedding Learning".
code heavily based on https://github.com/msight-tech/research-xbm/blob/master/ret_benchmark/modeling/xbm.py
Args:
size (int): Size of memory bank
embedding_size (int): number of embedding dimension for memory bank
"""
def __init__(self, size: int, embedding_size: int):
self.size = size
self.embedding_size = embedding_size
self.feats = paddle.zeros([self.size, self.embedding_size])
self.targets = paddle.zeros([self.size, ], dtype="int64")
self.ptr = 0
# self.accumulated_size = 0
@property
def _is_full(self) -> bool:
# return self.accumulated_size >= self.size
return self.targets[-1].item() != 0 # author's usage
def get(self) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""return features and targets in memory bank
Returns:
Tuple[paddle.Tensor, paddle.Tensor]: [features, targets]
"""
if self._is_full:
return self.feats, self.targets
else:
return self.feats[:self.ptr], self.targets[:self.ptr]
def enqueue_dequeue(self, feats: paddle.Tensor,
targets: paddle.Tensor) -> None:
"""put newest feats and targets into memory bank and pop oldest feats and targets from momory bank
Args:
feats (paddle.Tensor): features to enque
targets (paddle.Tensor): targets to enque
"""
input_size = len(targets)
if self.ptr + input_size > self.size:
self.feats[-input_size:] = feats
self.targets[-input_size:] = targets
self.ptr = 0
else:
self.feats[self.ptr:self.ptr + input_size] = feats
self.targets[self.ptr:self.ptr + input_size] = targets
self.ptr += input_size
# self.accumulated_size += input_size
......@@ -14,7 +14,7 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import types
from abc import abstractmethod
from typing import Union
......@@ -392,3 +392,86 @@ class MultiStepDecay(LRBase):
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
class ReduceOnPlateau(LRBase):
"""ReduceOnPlateau learning rate decay
Args:
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'``, the learning
rate will reduce when ``loss`` stops ascending. Defaults to ``'min'``.
factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
It should be less than 1.0. Defaults to 0.1.
patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
Defaults to 10.
threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
This make tiny changes of ``loss`` will be ignored. Defaults to 1e-4.
threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
change of ``loss`` is ``threshold`` . Defaults to ``'rel'`` .
cooldown (int, optional): The number of epochs to wait before resuming normal operation. Defaults to 0.
min_lr (float, optional): The lower bound of the learning rate after reduction. Defaults to 0.
epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
the update is ignored. Defaults to 1e-8.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
last_epoch (int, optional): last epoch. Defaults to -1.
by_epoch (bool, optional): learning rate decays by epoch when by_epoch is True, else by iter. Defaults to False.
"""
def __init__(self,
epochs,
step_each_epoch,
learning_rate,
mode='min',
factor=0.1,
patience=10,
threshold=1e-4,
threshold_mode='rel',
cooldown=0,
min_lr=0,
epsilon=1e-8,
warmup_epoch=0,
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
**kwargs):
super(ReduceOnPlateau, self).__init__(
epochs, step_each_epoch, learning_rate, warmup_epoch,
warmup_start_lr, last_epoch, by_epoch)
self.mode = mode
self.factor = factor
self.patience = patience
self.threshold = threshold
self.threshold_mode = threshold_mode
self.cooldown = cooldown
self.min_lr = min_lr
self.epsilon = epsilon
def __call__(self):
learning_rate = lr.ReduceOnPlateau(
learning_rate=self.learning_rate,
mode=self.mode,
factor=self.factor,
patience=self.patience,
threshold=self.threshold,
threshold_mode=self.threshold_mode,
cooldown=self.cooldown,
min_lr=self.min_lr,
epsilon=self.epsilon)
if self.warmup_steps > 0:
learning_rate = self.linear_warmup(learning_rate)
# NOTE: Implement get_lr() method for class `ReduceOnPlateau`,
# which is called in `log_info` function
def get_lr(self):
return self.last_lr
learning_rate.get_lr = types.MethodType(get_lr, learning_rate)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册