未验证 提交 cdd87759 编写于 作者: L littletomatodonkey 提交者: GitHub

fix rec doc (#829)

* fix recognition doc
上级 e170ad52
Global:
infer_imgs: "./dataset/logo_demo_data_v1.0/query/logo_AKG.jpg"
infer_imgs: "./dataset/logo_demo_data_v1.0/query/logo_auxx-1.jpg"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/"
batch_size: 1
......
......@@ -28,6 +28,8 @@ from vector_search import Graph_Index
from utils import logger
from utils import config
from utils.get_image_list import get_image_list
from utils.draw_bbox import draw_bbox_results
class SystemPredictor(object):
def __init__(self, config):
......@@ -40,7 +42,8 @@ class SystemPredictor(object):
self.return_k = self.config['IndexProcess']['return_k']
self.search_budget = self.config['IndexProcess']['search_budget']
self.Searcher = Graph_Index(dist_type=config['IndexProcess']['dist_type'])
self.Searcher = Graph_Index(
dist_type=config['IndexProcess']['dist_type'])
self.Searcher.load(config['IndexProcess']['index_path'])
def predict(self, img):
......@@ -53,13 +56,16 @@ class SystemPredictor(object):
rec_results = self.rec_predictor.predict(crop_img)
#preds["feature"] = rec_results
preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(query=rec_results, return_k=self.return_k, search_budget=self.search_budget)
scores, docs = self.Searcher.search(
query=rec_results,
return_k=self.return_k,
search_budget=self.search_budget)
preds["rec_docs"] = docs
preds["rec_scores"] = scores
output.append(preds)
return output
def main(config):
system_predictor = SystemPredictor(config)
......@@ -69,6 +75,7 @@ def main(config):
for idx, image_file in enumerate(image_list):
img = cv2.imread(image_file)[:, :, ::-1]
output = system_predictor.predict(img)
draw_bbox_results(img[:, :, ::-1], output, image_file)
print(output)
return
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import cv2
def draw_bbox_results(image, results, input_path, save_dir=None):
for result in results:
[xmin, ymin, xmax, ymax] = result["bbox"]
image = cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0),
2)
image_name = os.path.basename(input_path)
if save_dir is None:
save_dir = "output"
os.makedirs(save_dir, exist_ok=True)
output_path = os.path.join(save_dir, image_name)
cv2.imwrite(output_path, image)
# Quick Start of Recognition
This tutorial contains 3 parts: Environment Preparation, Image Recognition Experience, and Unknown Category Image Recognition Experience.
If the image category already exists in the image index database, then you can take a reference to chapter [Image Recognition Experience](#image_recognition_experience),to complete the progress of image recognition;If you wish to recognize unknow category image, which is not included in the index database,you can take a reference to chapter [Unknown Category Image Recognition Experience](#unkonw_category_image_recognition_experience),to complete the process of creating an index to recognize it。
## Catalogue
* [1. Enviroment Preparation](#enviroment_preperation )
* [2. Image Recognition Experience](#image_recognition_experience)
* [2.1 Download and Unzip the Inference Model and Demo Data](#download_and_unzip_the_inference_model_and_demo_data)
* [2.2 Logo Recognition and Retrieval](#Logo_recognition_and_retrival)
* [2.2.1 Single Image Recognition](#recognition_of_single_image)
* [2.2.2 Folder-based Batch Recognition](#folder_based_batch_recognition)
* [3. Unknown Category Image Recognition Experience](#unkonw_category_image_recognition_experience)
* [3.1 Build the Base Library Based on Our Own Dataset](#build_the_base_library_based_on_your_own_dataset)
* [3.2 ecognize the Unknown Category Images](#Image_differentiation_based_on_the_new_index_library)
<a name="enviroment_preparation"></a>
## 1. Enviroment Preparation
* Installation:Please take a reference to [Quick Installation ](./installation.md)to configure the PaddleClas environment.
* Using the following command to enter Folder `deploy`. All content and commands in this section need to be run in folder `deploy`.
```
cd deploy
```
<a name="image_recognition_experience"></a>
## 2. Image Recognition Experience
The detection model with the recognition inference model for the 4 directions (Logo, Cartoon Face, Vehicle, Product), the address for downloading the test data and the address of the corresponding configuration file are as follows.
| Models Introduction | Recommended Scenarios | Test Data Address | inference Model | Predict Config File | Config File to Build Index Database |
| ------------ | ------------- | ------- | -------- | ------- | -------- |
| Generic mainbody detection model | General Scenarios | - |[Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar) | - | - |
| Logo Recognition Model | Logo Scenario | [Data Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/logo_demo_data_v1.0.tar) | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) | [inference_logo.yaml](../../../deploy/configs/inference_logo.yaml) | [build_logo.yaml](../../../deploy/configs/build_logo.yaml) |
| Cartoon Face Recognition Model| Cartoon Face Scenario | [Data Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/cartoon_demo_data_v1.0.tar) | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) |
| Vehicle Subclassification Model | Vehicle Scenario | [Data Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/vehicle_demo_data_v1.0.tar) | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) |
| Product Recignition Model | Product Scenario | [Data Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/product_demo_data_v1.0.tar) | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) | [inference_inshop.yaml](../../../deploy/configs/) | [build_inshop.yaml](../../../deploy/configs/build_inshop.yaml) |
**Attention**:If you do not have wget installed on Windows, you can download the model by copying the link into your browser and unzipping it in the appropriate folder; for Linux or macOS users, you can right-click and copy the download link to download it via the `wget` command.
* You can download and unzip the data and models by following the command below
```shell
mkdir dataset
cd dataset
# Download the demo data and unzip
wget {Data download link} && tar -xf {Name of the tar archive}
cd ..
mkdir models
cd models
# Download and unzip the inference model
wget {Models download link} && tar -xf {Name of the tar archive}
cd ..
```
<a name="download_and_unzip_the_inference_model_and_demo_data"></a>
### 2.1 Download and Unzip the Inference Model and Demo Data
Take the Logo recognition as an example, download the detection model, recognition model and Logo recognition demo data with the following commands.
```shell
mkdir models
cd models
# Download the generic detection inference model and unzip it
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar && tar -xf ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar
# Download and unpack the inference model
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar && tar -xf logo_rec_ResNet50_Logo3K_v1.0_infer.tar
cd ..
mkdir dataset
cd dataset
# Download the demo data and unzip it
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/logo_demo_data_v1.0.tar && tar -xf logo_demo_data_v1.0.tar
cd ..
```
Once unpacked, the `dataset` folder should have the following file structure.
```
├── logo_demo_data_v1.0
│ ├── data_file.txt
│ ├── gallery
│ ├── index
│ └── query
├── ...
```
The `data_file.txt` is images list used to build the index database, the `gallery` folder contains all the original images used to build the index database, the `index` folder contains the index files generated by building the index database, and the `query` is the demo image used to test the recognition effect.
The `models` folder should have the following file structure.
```
├── logo_rec_ResNet50_Logo3K_v1.0_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
├── ppyolov2_r50vd_dcn_mainbody_v1.0_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
<a name="Logo_recognition_and_retrival"></a>
### 2.2 Logo Recognition and Retrival
Take the Logo recognition demo as an example to show the recognition and retrieval process (if you wish to try other scenarios of recognition and retrieval, replace the corresponding configuration file after downloading and unzipping the corresponding demo data and model to complete the prediction)。
<a name="recognition_of_single_image"></a>
#### 2.2.1 Single Image Recognition
Run the following command to identify and retrieve the image `. /dataset/logo_demo_data_v1.0/query/logo_auxx-1.jpg` for recognition and retrieval
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml
```
The image to be retrieved is shown below.
<div align="center">
<img src="../../images/recognition/logo_demo/query/logo_auxx-1.jpg" width = "400" />
</div>
The final output is shown below.
```
[{'bbox': [129, 219, 230, 253], 'rec_docs': ['auxx-2', 'auxx-1', 'auxx-2', 'auxx-1', 'auxx-2'], 'rec_scores': array([3.09635019, 3.09635019, 2.83965826, 2.83965826, 2.64057827])}]
```
where bbox indicates the location of the detected subject, rec_docs indicates the labels corresponding to a number of images in the index dabase that are most similar to the detected subject, and rec_scores indicates the corresponding similarity.
<a name="folder_based_batch_recognition"></a>
#### 2.2.2 Folder-based Batch Recognition
If you want to predict the images in the folder, you can directly modify the `Global.infer_imgs` field in the configuration file, or you can also modify the corresponding configuration through the following `-o` parameter.
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query"
```
Furthermore, the recognition inference model path can be changed by modifying the `Global.rec_inference_model_dir` field, and the path of the index to the index databass can be changed by modifying the `IndexProcess.index_path` field.
<a name="unkonw_category_image_recognition_experience"></a>
## 3. Recognize Images of Unknown Category
To recognize the image `./dataset/logo_demo_data_v1.0/query/logo_cola.jpg`, run the command as follows:
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query/logo_cola.jpg"
```
The image to be retrieved is shown below.
<div align="center">
<img src="../../images/recognition/logo_demo/query/logo_cola.jpg" width = "400" />
</div>
The output is as follows:
```
[{'bbox': [635, 0, 1382, 1043], 'rec_docs': ['Arcam', 'univox', 'univox', 'Arecont Vision', 'univox'], 'rec_scores': array([0.47730467, 0.47625482, 0.46496609, 0.46296868, 0.45239362])}]
```
Since the index infomation is not included in the corresponding index databse, the recognition results are not proper. At this time, we can complete the image recognition of unknown categories by constructing a new index database.
When the index database cannot cover the scenes we actually recognise, i.e. when predicting images of unknown categories, we need to add similar images of the corresponding categories to the index databasey, thus completing the recognition of images of unknown categories ,which does not require retraining.
<a name="build_the_base_library_based_on_your_own_dataset"></a>
### 3.1 Build the Base Library Based on Your Own Dataset
First, you need to obtain the original images to store in the database (store in `./dataset/logo_demo_data_v1.0/gallery`) and the corresponding label infomation,recording the category of the original images and the label information)store in the text file `./dataset/logo_demo_data_v1.0/data_file_update.txt`
Then use the following command to build the index to accelerate the retrieval process after recognition.
```shell
python3.7 python/build_gallery.py -c configs/build_logo.yaml -o IndexProcess.data_file="./dataset/logo_demo_data_v1.0/data_file_update.txt" -o IndexProcess.index_path="./dataset/logo_demo_data_v1.0/index_update"
```
Finally, the new index information is stored in the folder`./dataset/logo_demo_data_v1.0/index_update`. Use the new index database for the above index.
<a name="Image_differentiation_based_on_the_new_index_library"></a>
### 3.2 Recognize the Unknown Category Images
To recognize the image `./dataset/logo_demo_data_v1.0/query/logo_cola.jpg`, run the command as follows.
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query/logo_cola.jpg" -o IndexProcess.index_path="./dataset/logo_demo_data_v1.0/index_update"
```
The output is as follows:
```
[{'bbox': [635, 0, 1382, 1043], 'rec_docs': ['coca cola', 'coca cola', 'coca cola', 'coca cola', 'coca cola'], 'rec_scores': array([0.57111013, 0.56019932, 0.55656564, 0.54122502, 0.48266801])}]
```
The recognition result is correct.
# 图像识别快速开始
图像识别主要包含3个部分:主体检测得到检测框、识别提取特征、根据特征进行检索
本文档包含3个部分:环境配置、图像识别体验、未知类别的图像识别体验
## 1. 环境配置
如果图像类别已经存在于图像索引库中,那么可以直接参考[图像识别体验](#图像识别体验)章节,完成图像识别过程;如果希望识别未知类别的图像,即图像类别之前不存在于索引库中,那么可以参考[未知类别的图像识别体验](#未知类别的图像识别体验)章节,完成建立索引并识别的过程。
* 请先参考[快速安装](./installation.md)配置PaddleClas运行环境。
## 目录
* [1. 环境配置](#环境配置)
* [2. 图像识别体验](#图像识别体验)
* [2.1 下载、解压inference 模型与demo数据](#下载、解压inference_模型与demo数据)
* [2.2 Logo识别与检索](#Logo识别与检索)
* [2.2.1 识别单张图像](#识别单张图像)
* [2.2.2 基于文件夹的批量识别](#基于文件夹的批量识别)
* [3. 未知类别的图像识别体验](#未知类别的图像识别体验)
* [3.1 基于自己的数据集构建索引库](#基于自己的数据集构建索引库)
* [3.2 基于新的索引库的图像识别](#基于新的索引库的图像识别)
注意:
**本部分内容需要在`deploy`文件夹下运行,在PaddleClas代码的根目录下,可以通过以下方法进入该文件夹**
<a name="环境配置"></a>
## 1. 环境配置
```shell
cd deploy
```
* 安装:请先参考[快速安装](./installation.md)配置PaddleClas运行环境。
* 进入`deploy`运行目录。本部分所有内容与命令均需要在`deploy`目录下运行,可以通过下面的命令进入`deploy`目录。
## 2. inference 模型和数据下载
```
cd deploy
```
检测模型与4个方向(Logo、动漫人物、车辆、商品)的识别inference模型以及测试数据下载方法如下。。
<a name="图像识别体验"></a>
## 2. 图像识别体验
| 模型简介 | 推荐场景 | 测试数据地址 | inference模型 |
| ------------ | ------------- | ------- | -------- |
| 通用主体检测模型 | 通用场景 | - |[下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar) |
| Logo识别模型 | Logo场景 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/logo_demo_data_v1.0.tar) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) |
| 动漫人物识别模型 | 动漫人物场景 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/cartoon_demo_data_v1.0.tar) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) |
| 车辆细分类模型 | 车辆场景 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/vehicle_demo_data_v1.0.tar) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) |
| 商品识别模型 | 商品场景 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/product_demo_data_v1.0.tar) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) |
检测模型与4个方向(Logo、动漫人物、车辆、商品)的识别inference模型、测试数据下载地址以及对应的配置文件地址如下。
| 模型简介 | 推荐场景 | 测试数据地址 | inference模型 | 预测配置文件 | 构建索引库的配置文件 |
| ------------ | ------------- | ------- | -------- | ------- | -------- |
| 通用主体检测模型 | 通用场景 | - |[模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar) | - | - |
| Logo识别模型 | Logo场景 | [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/logo_demo_data_v1.0.tar) | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) | [inference_logo.yaml](../../../deploy/configs/inference_logo.yaml) | [build_logo.yaml](../../../deploy/configs/build_logo.yaml) |
| 动漫人物识别模型 | 动漫人物场景 | [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/cartoon_demo_data_v1.0.tar) | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) |
| 车辆细分类模型 | 车辆场景 | [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/vehicle_demo_data_v1.0.tar) | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) |
| 商品识别模型 | 商品场景 | [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/product_demo_data_v1.0.tar) | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) | [inference_inshop.yaml](../../../deploy/configs/) | [build_inshop.yaml](../../../deploy/configs/build_inshop.yaml) |
**注意**:windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下
**注意**:windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下;linux或者macOS用户可以右键点击,然后复制下载链接,即可通过`wget`命令下载。
* 下载并解压数据与模型
* 可以按照下面的命令下载并解压数据与模型
```shell
mkdir dataset
cd dataset
# 下载demo数据并解压
wget {url/of/data} && tar -xf {name/of/data/package}
wget {数据下载链接地址} && tar -xf {压缩包的名称}
cd ..
mkdir models
cd models
# 下载识别inference模型并解压
wget {url/of/inference model} && tar -xf {name/of/inference model/package}
wget {模型下载链接地址} && tar -xf {压缩包的名称}
cd ..
```
### 2.1 下载通用检测模型
<a name="下载、解压inference_模型与demo数据"></a>
### 2.1 下载、解压inference 模型与demo数据
Logo识别为例,下载通用检测、识别模型以及Logo识别demo数据,命令如下。
```shell
mkdir models
cd models
# 下载通用检测inference模型并解压
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar && tar -xf ppyolov2_r50vd_dcn_mainbody_v1.0_infer.tar
cd ..
```
### 2.1 Logo识别
以Logo识别demo为例,按照下面的命令下载demo数据与模型。
# 下载识别inference模型并解压
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar && tar -xf logo_rec_ResNet50_Logo3K_v1.0_infer.tar
```shell
cd ..
mkdir dataset
cd dataset
# 下载demo数据并解压
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/logo_demo_data_v1.0.tar && tar -xf logo_demo_data_v1.0.tar
cd ..
mkdir models
cd models
# 下载识别inference模型并解压
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar && tar -xf logo_rec_ResNet50_Logo3K_v1.0_infer.tar
cd ..
```
解压完毕后,`dataset`文件夹下应有如下文件结构:
......@@ -89,6 +94,8 @@ cd ..
├── ...
```
其中`data_file.txt`是用于构建索引库的图像列表文件,`gallery`文件夹中是所有用于构建索引库的图像原始文件,`index`文件夹中是构建索引库生成的索引文件,`query`是用来测试识别效果的demo图像。
`models`文件夹下应有如下文件结构:
```
......@@ -102,80 +109,102 @@ cd ..
│ └── inference.pdmodel
```
按照下面的方式可以完成对于图片的检索
<a name="Logo识别与检索"></a>
### 2.2 Logo识别与检索
以Logo识别demo为例,展示识别与检索过程(如果希望尝试其他方向的识别与检索效果,在下载解压好对应的demo数据与模型之后,替换对应的配置文件即可完成预测)。
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml
```
配置文件中,部分关键字段解释如下
<a name="识别单张图像"></a>
#### 2.2.1 识别单张图像
```yaml
Global:
infer_imgs: "./dataset/logo_demo_data_v1.0/query/" # 预测图像
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/" # 检测inference模型文件夹
rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/" # 识别inference模型文件夹
batch_size: 1 # 预测的批大小
image_shape: [3, 640, 640] # 检测的图像尺寸
threshold: 0.5 # 检测的阈值,得分超过该阈值的检测框才会被检出
max_det_results: 1 # 用于图像识别的检测框数量,符合阈值条件的检测框中,根据得分,最多对其中的max_det_results个检测框做后续的识别
运行下面的命令,对图像`./dataset/logo_demo_data_v1.0/query/logo_auxx-1.jpg`进行识别与检索
# indexing engine config
IndexProcess:
index_path: "./dataset/logo_demo_data_v1.0/index/" # 索引文件夹,用于识别特征提取之后的索引
search_budget: 100
return_k: 5 # 从底库中反馈return_k个数量的最相似内容
dist_type: "IP"
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml
```
待检索图像如下所示。
<div align="center">
<img src="../../images/recognition/logo_demo/query/logo_auxx-1.jpg" width = "400" />
</div>
最终输出结果如下
最终输出结果如下。
```
[{'bbox': [25, 21, 483, 382], 'rec_docs': ['AKG', 'AKG', 'AKG', 'AKG', 'AKG'], 'rec_scores': array([2.32288337, 2.31903863, 2.28398442, 2.16804123, 2.10190272])}]
[{'bbox': [129, 219, 230, 253], 'rec_docs': ['auxx-2', 'auxx-1', 'auxx-2', 'auxx-1', 'auxx-2'], 'rec_scores': array([3.09635019, 3.09635019, 2.83965826, 2.83965826, 2.64057827])}]
```
其中bbox表示检测出的主体所在位置,rec_docs表示底库中与检出主体最相近的若干张图像对应的标签,rec_scores表示对应的相似度。
其中bbox表示检测出的主体所在位置,rec_docs表示索引库中与检出主体最相近的若干张图像对应的标签,rec_scores表示对应的相似度。由rec_docs字段可以看出,返回的若干个结果均为aux,识别正确。
如果希望预测文件夹内的图像,可以直接修改配置文件,也可以通过下面的`-o`参数修改对应的配置。
<a name="基于文件夹的批量识别"></a>
#### 2.2.2 基于文件夹的批量识别
如果希望预测文件夹内的图像,可以直接修改配置文件中的`Global.infer_imgs`字段,也可以通过下面的`-o`参数修改对应的配置。
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query"
```
如果希望在底库中新增图像,重新构建idnex,可以使用下面的命令重新构建index。
更多地,可以通过修改`Global.rec_inference_model_dir`字段来更改识别inference模型的路径,通过修改`IndexProcess.index_path`字段来更改索引库索引的路径。
<a name="未知类别的图像识别体验"></a>
## 3. 未知类别的图像识别体验
对图像`./dataset/logo_demo_data_v1.0/query/logo_cola.jpg`进行识别,命令如下
```shell
python3.7 python/build_gallery.py -c configs/build_logo.yaml
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query/logo_cola.jpg"
```
待检索图像如下所示。
<div align="center">
<img src="../../images/recognition/logo_demo/query/logo_cola.jpg" width = "400" />
</div>
输出结果如下
```
[{'bbox': [635, 0, 1382, 1043], 'rec_docs': ['Arcam', 'univox', 'univox', 'Arecont Vision', 'univox'], 'rec_scores': array([0.47730467, 0.47625482, 0.46496609, 0.46296868, 0.45239362])}]
```
其中index相关配置如下。
由于默认的索引库中不包含对应的索引信息,所以这里的识别结果有误,此时我们可以通过构建新的索引库的方式,完成未知类别的图像识别。
当索引库中的图像无法覆盖我们实际识别的场景时,即在预测未知类别的图像时,我们需要将对应类别的相似图像添加到索引库中,从而完成对未知类别的图像识别,这一过程是不需要重新训练的。
<a name="基于自己的数据集构建索引库"></a>
### 3.1 基于自己的数据集构建索引库
```yaml
# indexing engine config
IndexProcess:
index_path: "./dataset/logo_demo_data_v1.0/index/" # 保存的索引地址
image_root: "./dataset/logo_demo_data_v1.0/" # 图像的根目录
data_file: "./dataset/logo_demo_data_v1.0/data_file.txt" # 图像的数据list文本,每一行包含图像的文件名与标签信息
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512 # 特征维度
首先需要获取待入库的原始图像文件(保存在`./dataset/logo_demo_data_v1.0/gallery`文件夹中)以及对应的标签信息,记录原始图像文件的文件名与标签信息)保存在文本文件`./dataset/logo_demo_data_v1.0/data_file_update.txt`中)。
然后使用下面的命令构建index索引,加速识别后的检索过程。
```shell
python3.7 python/build_gallery.py -c configs/build_logo.yaml -o IndexProcess.data_file="./dataset/logo_demo_data_v1.0/data_file_update.txt" -o IndexProcess.index_path="./dataset/logo_demo_data_v1.0/index_update"
```
需要改动的内容为:
1. 在图像根目录下面添加对应的图像内容(也可以在其子文件夹下面,保证最终根目录与数据list文本中添加的文件名合并之后,图像存在即可)
2. 图像的数据list文本中添加图像新的内容,每行包含图像文件名以及对应的标签信息。
最终新的索引信息保存在文件夹`./dataset/logo_demo_data_v1.0/index_update`中。
### 2.2 其他任务的识别
<a name="基于新的索引库的图像识别"></a>
### 3.2 基于新的索引库的图像识别
如果希望尝试其他方向的识别与检索效果,在下载解压好对应的demo数据与模型之后,替换对应的配置文件即可完成预测
使用新的索引库,对上述图像进行识别,运行命令如下
```shell
python3.7 python/predict_system.py -c configs/inference_logo.yaml -o Global.infer_imgs="./dataset/logo_demo_data_v1.0/query/logo_cola.jpg" -o IndexProcess.index_path="./dataset/logo_demo_data_v1.0/index_update"
```
输出结果如下。
```
[{'bbox': [635, 0, 1382, 1043], 'rec_docs': ['coca cola', 'coca cola', 'coca cola', 'coca cola', 'coca cola'], 'rec_scores': array([0.57111013, 0.56019932, 0.55656564, 0.54122502, 0.48266801])}]
```
| 场景 | 预测配置文件 | 构建底库的配置文件 |
| ---- | ----- | ----- |
| 动漫人物 | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) |
| 车辆 | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) |
| 商品 | [inference_inshop.yaml](../../../deploy/configs/) | [build_inshop.yaml](../../../deploy/configs/build_inshop.yaml) |
识别结果正确。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册