diff --git a/docs/zh_CN/image_recognition_pipeline/feature_extraction.md b/docs/zh_CN/image_recognition_pipeline/feature_extraction.md index a447521f646e91fd940cb15aa16217a06d86cb76..ad278deb2ca4b20bf4bda4777037b837d53a7dd0 100644 --- a/docs/zh_CN/image_recognition_pipeline/feature_extraction.md +++ b/docs/zh_CN/image_recognition_pipeline/feature_extraction.md @@ -12,7 +12,7 @@ - **Loss**: 指定所使用的Loss函数。 我们将Loss设计为组合loss的形式, 可以方便得将Classification Loss和Pair_wise Loss组合在一起。 ## 3. 通用识别模型 -在PP-Shitu中, 我们采用[PP_LCNet_x2_5](../models/PP-LCNet.md)作为骨干网络, Neck部分选用Linear Layer, Head部分选用[ArcMargin](https://arxiv.org/abs/1801.07698), Loss部分选用CELoss,详细的配置文件见[通用商品识别配置文件](../../../ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml)。其中,训练数据为如下7个公开数据集的汇总: +在PP-Shitu中, 我们采用[PP_LCNet_x2_5](../models/PP-LCNet.md)作为骨干网络, Neck部分选用Linear Layer, Head部分选用[ArcMargin](../../../ppcls/arch/gears/arcmargin.py), Loss部分选用CELoss,详细的配置文件见[通用识别配置文件](../../../ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml)。其中,训练数据为如下7个公开数据集的汇总: | 数据集 | 数据量 | 类别数 | 场景 | 数据集地址 | | :------------: | :-------------: | :-------: | :-------: | :--------: | | Aliproduct | 2498771 | 50030 | 商品 | [地址](https://retailvisionworkshop.github.io/recognition_challenge_2020/) | @@ -31,15 +31,42 @@ PP-LCNet-2.5x | 0.839 | 0.888 | 0.861 | 0.841 | 0.793 | 0.892 | 5.0 * 采用的评测指标为:`Recall@1`; * 速度评测机器的CPU具体信息为:`Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz`; * 速度指标的评测条件为: 开启MKLDNN, 线程数设置为10; -* 预训练模型地址:[通用识别预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/general_PPLCNet_x2_5_pretrained_v1.0.pdparams) +* 预训练模型地址:[通用识别预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/general_PPLCNet_x2_5_pretrained_v1.0.pdparams); # 4. 自定义特征提取 自定义特征提取,是指依据自己的任务,重新训练特征提取模型。主要包含如下四个步骤:1)数据准备;2)模型训练;3)模型评估;4)模型推理。 ## 4.1 数据准备 -首先,需要基于任务定制自己的数据集。数据集格式参见[格式说明](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/data_preparation/recognition_dataset.md#%E6%95%B0%E6%8D%AE%E9%9B%86%E6%A0%BC%E5%BC%8F%E8%AF%B4%E6%98%8E) +首先,需要基于任务定制自己的数据集。数据集格式参见[格式说明](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/data_preparation/recognition_dataset.md#%E6%95%B0%E6%8D%AE%E9%9B%86%E6%A0%BC%E5%BC%8F%E8%AF%B4%E6%98%8E)。在启动模型训练之前,需要在配置文件中修改数据配置相关的内容, 主要包括数据集的地址以及类别数量。对应到配置文件中的位置如下所示: +``` + Head: + name: ArcMargin + embedding_size: 512 + class_num: 185341 #此处表示类别数 +``` +``` + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ #此处表示train数据所在的目录 + cls_label_path: ./dataset/train_reg_all_data.txt #此处表示train数据集label文件的地址 +``` +``` + Query: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/. #此处表示query数据集所在的目录 + cls_label_path: ./dataset/Aliproduct/val_list.txt. #此处表示query数据集label文件的地址 +``` +``` + Gallery: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ #此处表示gallery数据集所在的目录 + cls_label_path: ./dataset/Aliproduct/val_list.txt. #此处表示gallery数据集label文件的地址 +``` + ## 4.2 模型训练 -在启动模型训练之前,需要在配置文件中修改数据配置相关的内容, 主要包括数据集的地址以及类别数量。 - 单机单卡训练 ```shell export CUDA_VISIBLE_DEVICES=0 @@ -52,9 +79,10 @@ python -m paddle.distributed.launch \ --gpus="0,1,2,3" tools/train.py \ -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml ``` -**注意:** 配置文件中默认采用`在线评估`的方式,如果你想加快训练速度,去除`在线评估`,只需要在上述命令后面,增加`-o eval_during_train=False`。 -训练完毕后,在output目录下会生成最终模型文件`latest.pd*`,`best_model.pd*`和训练日志文件`train.log`。`best_model`用来存储当前评测指标下 -的最佳模型。`latest`用来存储最新的模型, 方便在任务中断的情况下从断点位置启动训练,断点重训命令如下所示: +**注意:** +配置文件中默认采用`在线评估`的方式,如果你想加快训练速度,去除`在线评估`,只需要在上述命令后面,增加`-o eval_during_train=False`。 +训练完毕后,在output目录下会生成最终模型文件`latest.pd*`,`best_model.pd*`和训练日志文件`train.log`。 +其中,`best_model`用来存储当前评测指标下的最佳模型;`latest`用来存储最新生成的模型, 方便在任务中断的情况下从断点位置启动训练,断点重训命令如下所示: ```shell export CUDA_VISIBLE_DEVICES=0,1,2,3 python -m paddle.distributed.launch \ @@ -101,4 +129,4 @@ python python/predict_rec.py \ 得到的特征输出格式如下图所示: ![](../../images/feature_extraction_output.png) -在实际使用过程中,单纯得到特征往往并不能够满足业务的需求。如果想进一步通过特征来进行图像识别,可以参照文档[图像识别流程]()和[向量检索]()。 +在实际使用过程中,单纯得到特征往往并不能够满足业务的需求。如果想进一步通过特征来进行图像识别,可以参照文档[向量检索](./vector_search.md)。