提交 50f25470 编写于 作者: D dongshuilong

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleClas into slim

......@@ -28,11 +28,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
image_root: "./recognition_demo_data_v1.0/gallery_cartoon/"
data_file: "./recognition_demo_data_v1.0/gallery_cartoon/data_file.txt"
append_index: False
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/"
image_root: "./recognition_demo_data_v1.1/gallery_cartoon/"
data_file: "./recognition_demo_data_v1.1/gallery_cartoon/data_file.txt"
index_operation: "new" # supported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 2048
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
image_root: "./recognition_demo_data_v1.0/gallery_logo/"
data_file: "./recognition_demo_data_v1.0/gallery_logo/data_file.txt"
append_index: False
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/"
image_root: "./recognition_demo_data_v1.1/gallery_logo/"
data_file: "./recognition_demo_data_v1.1/gallery_logo/data_file.txt"
index_operation: "new" # supported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
image_root: "./recognition_demo_data_v1.0/gallery_product/"
data_file: "./recognition_demo_data_v1.0/gallery_product/data_file.txt"
append_index: False
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_product/index"
image_root: "./recognition_demo_data_v1.1/gallery_product/"
data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt"
index_operation: "new" # supported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
......@@ -26,11 +26,11 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
image_root: "./recognition_demo_data_v1.0/gallery_vehicle/"
data_file: "./recognition_demo_data_v1.0/gallery_vehicle/data_file.txt"
append_index: False
index_method: "HNSW32" # supported: HNSW32, IVF, Flat
index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/"
image_root: "./recognition_demo_data_v1.1/gallery_vehicle/"
data_file: "./recognition_demo_data_v1.1/gallery_vehicle/data_file.txt"
index_operation: "new" # supported: "append", "remove", "new"
delimiter: "\t"
dist_type: "IP"
pq_size: 100
embedding_size: 512
Global:
infer_imgs: "./recognition_demo_data_v1.0/test_cartoon"
infer_imgs: "./recognition_demo_data_v1.1/test_cartoon"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/"
rec_nms_thresold: 0.05
......@@ -51,8 +51,6 @@ RecPreProcess:
RecPostProcess: null
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
Global:
infer_imgs: "./recognition_demo_data_v1.0/test_logo"
infer_imgs: "./recognition_demo_data_v1.1/test_logo"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/"
rec_nms_thresold: 0.05
......@@ -50,8 +50,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
Global:
infer_imgs: "./recognition_demo_data_v1.0/test_product/daoxiangcunjinzhubing_6.jpg"
infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer"
rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
rec_nms_thresold: 0.05
......@@ -50,8 +50,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_product/index"
search_budget: 100
index_dir: "./recognition_demo_data_v1.1/gallery_product/index"
return_k: 5
dist_type: "IP"
score_thres: 0.5
Global:
infer_imgs: "./recognition_demo_data_v1.0/test_vehicle/"
infer_imgs: "./recognition_demo_data_v1.1/test_vehicle/"
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/"
rec_nms_thresold: 0.05
......@@ -52,8 +52,6 @@ RecPostProcess: null
# indexing engine config
IndexProcess:
index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
search_budget: 100
index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/"
return_k: 5
dist_type: "IP"
score_thres: 0.5
......@@ -17,13 +17,13 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import copy
import cv2
import faiss
import numpy as np
from tqdm import tqdm
import pickle
from python.predict_rec import RecPredictor
from vector_search import Graph_Index
from utils import logger
from utils import config
......@@ -31,9 +31,9 @@ from utils import config
def split_datafile(data_file, image_root, delimiter="\t"):
'''
data_file: image path and info, which can be splitted by spacer
data_file: image path and info, which can be splitted by spacer
image_root: image path root
delimiter: delimiter
delimiter: delimiter
'''
gallery_images = []
gallery_docs = []
......@@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"):
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
image_file = os.path.join(image_root, line[0])
image_doc = line[1]
gallery_images.append(image_file)
gallery_docs.append(image_doc)
gallery_docs.append(ori_line.strip())
return gallery_images, gallery_docs
......@@ -64,9 +63,91 @@ class GalleryBuilder(object):
'''
build index from scratch
'''
operation_method = config.get("index_operation", "new").lower()
gallery_images, gallery_docs = split_datafile(
config['data_file'], config['image_root'], config['delimiter'])
# when remove data in index, do not need extract fatures
if operation_method != "remove":
gallery_features = self._extract_features(gallery_images, config)
assert operation_method in [
"new", "remove", "append"
], "Only append, remove and new operation are supported"
# vector.index: faiss index file
# id_map.pkl: use this file to map id to image_doc
if operation_method in ["remove", "append"]:
# if remove or append, vector.index and id_map.pkl must exist
assert os.path.join(
config["index_dir"], "vector.index"
), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
assert os.path.join(
config["index_dir"], "id_map.pkl"
), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
config["index_dir"])
index = faiss.read_index(
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"),
'rb') as fd:
ids = pickle.load(fd)
assert index.ntotal == len(ids.keys(
)), "data number in index is not equal in in id_map"
else:
if not os.path.exists(config["index_dir"]):
os.makedirs(config["index_dir"], exist_ok=True)
index_method = config.get("index_method", "HNSW32")
# if IVF method, cal ivf number automaticlly
if index_method == "IVF":
index_method = index_method + str(
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
dist_type = faiss.METRIC_INNER_PRODUCT if config[
"dist_type"] == "IP" else faiss.METRIC_L2
index = faiss.index_factory(config["embedding_size"], index_method,
dist_type)
index = faiss.IndexIDMap2(index)
ids = {}
if config["index_method"] == "HNSW32":
logger.warning(
"The HNSW32 method dose not support 'remove' operation")
if operation_method != "remove":
# calculate id for new data
start_id = max(ids.keys()) + 1 if ids else 0
ids_now = (
np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
# only train when new index file
if operation_method == "new":
index.train(gallery_features)
index.add_with_ids(gallery_features, ids_now)
for i, d in zip(list(ids_now), gallery_docs):
ids[i] = d
else:
if config["index_method"] == "HNSW32":
raise RuntimeError(
"The index_method: HNSW32 dose not support 'remove' operation"
)
# remove ids in id_map, remove index data in faiss index
remove_ids = list(
filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
remove_ids = np.asarray(remove_ids)
index.remove_ids(remove_ids)
for k in remove_ids:
del ids[k]
# store faiss index file and id_map file
faiss.write_index(index,
os.path.join(config["index_dir"], "vector.index"))
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
pickle.dump(ids, fd)
def _extract_features(self, gallery_images, config):
# extract gallery features
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
......@@ -91,19 +172,11 @@ class GalleryBuilder(object):
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index
self.Searcher = Graph_Index(dist_type=config['dist_type'])
self.Searcher.build(
gallery_vectors=gallery_features,
gallery_docs=gallery_docs,
pq_size=config['pq_size'],
index_path=config['index_path'],
append_index=config["append_index"])
return gallery_features
def main(config):
    """Entry point: construct a ``GalleryBuilder`` from ``config``.

    The builder is instantiated exactly once. Previously it was constructed
    twice in a row (once bound to an unused local), which repeated whatever
    work the constructor performs.

    Args:
        config: configuration object forwarded unchanged to
            ``GalleryBuilder``.

    Returns:
        None.
    """
    # NOTE(review): presumably the constructor itself triggers the index
    # build, which is why the instance is not kept -- confirm against
    # GalleryBuilder.__init__.
    GalleryBuilder(config)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
......@@ -20,10 +20,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import copy
import cv2
import numpy as np
import faiss
import pickle
from python.predict_rec import RecPredictor
from python.predict_det import DetPredictor
from vector_search import Graph_Index
from utils import logger
from utils import config
......@@ -40,11 +41,16 @@ class SystemPredictor(object):
assert 'IndexProcess' in config.keys(), "Index config not found ... "
self.return_k = self.config['IndexProcess']['return_k']
self.search_budget = self.config['IndexProcess']['search_budget']
self.Searcher = Graph_Index(
dist_type=config['IndexProcess']['dist_type'])
self.Searcher.load(config['IndexProcess']['index_path'])
index_dir = self.config["IndexProcess"]["index_dir"]
assert os.path.exists(os.path.join(
index_dir, "vector.index")), "vector.index not found ..."
assert os.path.exists(os.path.join(
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
self.Searcher = faiss.read_index(
os.path.join(index_dir, "vector.index"))
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
self.id_map = pickle.load(fd)
def append_self(self, results, shape):
results.append({
......@@ -98,14 +104,11 @@ class SystemPredictor(object):
crop_img = img[ymin:ymax, xmin:xmax, :].copy()
rec_results = self.rec_predictor.predict(crop_img)
preds["bbox"] = [xmin, ymin, xmax, ymax]
scores, docs = self.Searcher.search(
query=rec_results,
return_k=self.return_k,
search_budget=self.search_budget)
scores, docs = self.Searcher.search(rec_results, self.return_k)
# just top-1 result will be returned for the final
if scores[0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = docs[0]
preds["rec_scores"] = scores[0]
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
preds["rec_scores"] = scores[0][0]
output.append(preds)
# st5: nms to the final results to avoid fetching duplicate results
......
# 向量检索
**注意**:由于系统适配性问题,在新版本中,此检索算法将被废弃。新版本中将使用[faiss](https://github.com/facebookresearch/faiss),整体检索的过程保持不变,但建立索引及检索时的yaml文件有所修改。
## 1. 简介
一些垂域识别任务(如车辆、商品等)需要识别的类别数较大,往往采用基于检索的方式,通过查询向量与底库向量进行快速的最近邻搜索,获得匹配的预测类别。向量检索模块提供基础的近似最近邻搜索算法,基于百度自研的Möbius算法,一种基于图的近似最近邻搜索算法,用于最大内积搜索 (MIPS)。 该模块提供python接口,支持numpy和 tensor类型向量,支持L2和Inner Product距离计算。
......
# Vector search
**Attention**: Due to system compatibility issues, this retrieval algorithm will be deprecated in the new version, and [faiss](https://github.com/facebookresearch/faiss) will be used instead. The overall workflow of the retrieval system remains unchanged, but the yaml files used for building indexes and for retrieval will be modified.
## 1. Introduction
Some vertical domain recognition tasks (e.g., vehicles, commodities, etc.) require a large number of recognized categories, and often use a retrieval-based approach to obtain matching predicted categories by performing a fast nearest neighbor search with query vectors and underlying library vectors. The vector search module provides the basic approximate nearest neighbor search algorithm based on Baidu's self-developed Möbius algorithm, a graph-based approximate nearest neighbor search algorithm for maximum inner product search (MIPS). This module provides python interface, supports numpy and tensor type vectors, and supports L2 and Inner Product distance calculation.
......@@ -57,7 +59,7 @@ brew install gcc
1. If prompted with `Error: Running Homebrew as root is extremely dangerous and no longer supported... `, refer to this [link](https://jingyan.baidu.com/article/e52e3615057a2840c60c519c.html)
2. If prompted with `Error: Failure while executing; tar --extract --no-same-owner --file... `, refer to this [link](https://blog.csdn.net/Dawn510/article/details/117787358).
After installation the compiled executable is copied under /usr/local/bin, look at the gcc in this folder:
After installation the compiled executable is copied under /usr/local/bin, look at the gcc in this folder:
```
ls /usr/local/bin/gcc*
......
......@@ -40,11 +40,11 @@ The detection model with the recognition inference model for the 4 directions (L
| Logo Recognition Model | Logo Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) | [inference_logo.yaml](../../../deploy/configs/inference_logo.yaml) | [build_logo.yaml](../../../deploy/configs/build_logo.yaml) |
| Cartoon Face Recognition Model| Cartoon Face Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) |
| Vehicle Fine-Grained Classification Model | Vehicle Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) |
| Product Recognition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) |
| Product Recognition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_aliproduct_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) |
| Vehicle ReID Model | Vehicle ReID Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - |
Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.0.tar).
Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.1.tar).
**Attention**
......
docs/images/wx_group.png

199.6 KB | W: | H:

docs/images/wx_group.png

57.6 KB | W: | H:

docs/images/wx_group.png
docs/images/wx_group.png
docs/images/wx_group.png
docs/images/wx_group.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -44,7 +44,7 @@
| 车辆ReID模型 | 车辆ReID场景 | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - |
本章节demo数据下载地址如下: [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.0.tar)
本章节demo数据下载地址如下: [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.1.tar)
**注意**
......
......@@ -41,10 +41,15 @@ class _SysPathG(object):
self.path)
with _SysPathG(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'ppcls', 'arch')):
import backbone
with _SysPathG(os.path.dirname(os.path.abspath(__file__)), ):
import ppcls
import ppcls.arch.backbone as backbone
def ppclas_init():
    """Call ``ppcls.utils.logger.init_logger()`` if no logger exists yet.

    Does nothing when the module-level ``ppcls.utils.logger._logger`` is
    already set, so an existing logger is left untouched.
    """
    ppcls_logger = ppcls.utils.logger
    if ppcls_logger._logger is not None:
        # A logger is already present; nothing to do.
        return
    ppcls_logger.init_logger()
ppclas_init()
def _load_pretrained_parameters(model, name):
url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format(
......@@ -63,9 +68,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `AlexNet` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.AlexNet(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'AlexNet')
return model
......@@ -80,9 +84,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `VGG11` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.VGG11(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'VGG11')
return model
......@@ -97,9 +100,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `VGG13` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.VGG13(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'VGG13')
return model
......@@ -114,9 +116,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `VGG16` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.VGG16(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'VGG16')
return model
......@@ -131,9 +132,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `VGG19` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.VGG19(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'VGG19')
return model
......@@ -149,9 +149,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNet18` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNet18(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet18')
return model
......@@ -167,9 +166,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNet34` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNet34(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet34')
return model
......@@ -185,9 +183,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNet50` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNet50(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet50')
return model
......@@ -203,9 +200,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNet101` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNet101(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet101')
return model
......@@ -221,9 +217,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNet152` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNet152(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNet152')
return model
......@@ -237,9 +232,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `SqueezeNet1_0` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.SqueezeNet1_0(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'SqueezeNet1_0')
return model
......@@ -253,9 +247,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `SqueezeNet1_1` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.SqueezeNet1_1(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'SqueezeNet1_1')
return model
......@@ -271,9 +264,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `DenseNet121` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DenseNet121(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DenseNet121')
return model
......@@ -289,9 +281,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `DenseNet161` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DenseNet161(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DenseNet161')
return model
......@@ -307,9 +298,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `DenseNet169` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DenseNet169(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DenseNet169')
return model
......@@ -325,9 +315,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `DenseNet201` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DenseNet201(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DenseNet201')
return model
......@@ -343,9 +332,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `DenseNet264` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DenseNet264(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DenseNet264')
return model
......@@ -359,9 +347,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `InceptionV3` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.InceptionV3(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'InceptionV3')
return model
......@@ -375,9 +362,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `InceptionV4` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.InceptionV4(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'InceptionV4')
return model
......@@ -391,9 +377,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `GoogLeNet` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.GoogLeNet(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'GoogLeNet')
return model
......@@ -407,9 +392,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ShuffleNetV2_x0_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ShuffleNetV2_x0_25')
return model
......@@ -423,9 +407,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV1` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV1(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV1')
return model
......@@ -439,9 +422,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV1_x0_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_25')
return model
......@@ -455,9 +437,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV1_x0_5(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_5')
return model
......@@ -471,9 +452,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV1_x0_75(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_75')
return model
......@@ -487,9 +467,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV2_x0_25` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV2_x0_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_25')
return model
......@@ -503,9 +482,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV2_x0_5` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV2_x0_5(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_5')
return model
......@@ -519,9 +497,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV2_x0_75` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV2_x0_75(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_75')
return model
......@@ -535,9 +512,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV2_x1_5` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV2_x1_5(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV2_x1_5')
return model
......@@ -551,9 +527,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV2_x2_0` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV2_x2_0(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'MobileNetV2_x2_0')
return model
......@@ -567,10 +542,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_large_x0_35(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_large_x0_35')
return model
......@@ -584,10 +557,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_large_x0_5(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_large_x0_5')
return model
......@@ -601,10 +572,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_large_x0_75(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_large_x0_75')
return model
......@@ -618,10 +587,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_large_x1_0(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_large_x1_0')
return model
......@@ -635,10 +602,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_large_x1_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_large_x1_25')
return model
......@@ -652,10 +617,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_small_x0_35(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_small_x0_35')
return model
......@@ -669,10 +632,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_small_x0_5(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_small_x0_5')
return model
......@@ -686,10 +647,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_small_x0_75(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_small_x0_75')
return model
......@@ -703,10 +662,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_small_x1_0(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_small_x1_0')
return model
......@@ -720,10 +677,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.MobileNetV3_small_x1_25(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model,
'MobileNetV3_small_x1_25')
return model
......@@ -737,9 +692,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt101_32x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt101_32x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt101_32x4d')
return model
......@@ -753,9 +707,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt101_64x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt101_64x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt101_64x4d')
return model
......@@ -769,9 +722,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt152_32x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt152_32x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt152_32x4d')
return model
......@@ -785,9 +737,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt152_64x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt152_64x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt152_64x4d')
return model
......@@ -801,9 +752,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt50_32x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt50_32x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt50_32x4d')
return model
......@@ -817,9 +767,8 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.ResNeXt50_64x4d(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'ResNeXt50_64x4d')
return model
......@@ -833,8 +782,7 @@ with _SysPathG(
Returns:
model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args.
"""
kwargs.update({'pretrained': pretrained})
model = backbone.DarkNet53(**kwargs)
if pretrained:
model = _load_pretrained_parameters(model, 'DarkNet53')
return model
......@@ -58,6 +58,7 @@ from ppcls.arch.backbone.model_zoo.rednet import RedNet26, RedNet38, RedNet50, R
from ppcls.arch.backbone.model_zoo.tnt import TNT_small
from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet39_ds, HarDNet68_ds
from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
def get_apis():
......
from .resnet_variant import ResNet50_last_stage_stride1
from .vgg_variant import VGG19Sigmoid
import paddle
from paddle.nn import Sigmoid
from ppcls.arch.backbone.legendary_models.vgg import VGG19
__all__ = ["VGG19Sigmoid"]
class SigmoidSuffix(paddle.nn.Layer):
    """Wrapper layer that applies a sigmoid to the output of another layer.

    Args:
        origin_layer: the layer being wrapped; it is called with the forward
            input and its result is fed to ``paddle.nn.Sigmoid``.
    """

    def __init__(self, origin_layer):
        super(SigmoidSuffix, self).__init__()
        # Keep the wrapped layer first, then the activation appended after it.
        self.origin_layer = origin_layer
        self.sigmoid = Sigmoid()

    def forward(self, input, res_dict=None, **kwargs):
        """Run the wrapped layer, then squash its output with sigmoid.

        ``res_dict`` and any extra keyword arguments are accepted but not
        used by this wrapper.
        """
        return self.sigmoid(self.origin_layer(input))
def VGG19Sigmoid(pretrained=False, use_ssld=False, **kwargs):
    """Build a VGG19 whose "linear_2" sub-layer is wrapped by SigmoidSuffix.

    Args:
        pretrained: forwarded unchanged to ``VGG19``.
        use_ssld: forwarded unchanged to ``VGG19``.
        **kwargs: any further keyword arguments for ``VGG19``.

    Returns:
        The ``VGG19`` model after ``replace_sub`` has been invoked on it.
    """

    def _append_sigmoid(origin_layer):
        # Callback handed to replace_sub: wrap the matched layer.
        return SigmoidSuffix(origin_layer)

    vgg = VGG19(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
    # The third positional argument (True) is passed through as-is; its
    # meaning is defined by the model's replace_sub implementation.
    vgg.replace_sub("linear_2", _append_sigmoid, True)
    return vgg
......@@ -28,7 +28,7 @@ class CircleMargin(nn.Layer):
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal())
self.fc0 = paddle.nn.Linear(
self.fc = paddle.nn.Linear(
self.embedding_size, self.class_num, weight_attr=weight_attr)
def forward(self, input, label):
......@@ -36,19 +36,22 @@ class CircleMargin(nn.Layer):
paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, feat_norm)
weight = self.fc0.weight
weight = self.fc.weight
weight_norm = paddle.sqrt(
paddle.sum(paddle.square(weight), axis=0, keepdim=True))
weight = paddle.divide(weight, weight_norm)
logits = paddle.matmul(input, weight)
if not self.training or label is None:
return logits
alpha_p = paddle.clip(-logits.detach() + 1 + self.margin, min=0.)
alpha_n = paddle.clip(logits.detach() + self.margin, min=0.)
delta_p = 1 - self.margin
delta_n = self.margin
index = paddle.fluid.layers.where(label != -1).reshape([-1])
m_hot = F.one_hot(label.reshape([-1]), num_classes=logits.shape[1])
logits_p = alpha_p * (logits - delta_p)
logits_n = alpha_n * (logits - delta_n)
pre_logits = logits_p * m_hot + logits_n * (1 - m_hot)
......
......@@ -46,6 +46,9 @@ class CosMargin(paddle.nn.Layer):
weight = paddle.divide(weight, weight_norm)
cos = paddle.matmul(input, weight)
if not self.training or label is None:
return cos
cos_m = cos - self.margin
one_hot = paddle.nn.functional.one_hot(label, self.class_num)
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output_dlbhc/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
#eval_mode: "retrieval"
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
#feature postprocess
feature_normalize: False
feature_binarize: "round"
# model architecture
Arch:
name: "RecModel"
Backbone:
name: "MobileNetV3_large_x1_0"
pretrained: True
class_num: 512
Head:
name: "FC"
class_num: 50030
embedding_size: 512
infer_output_key: "features"
infer_add_softmax: "false"
# loss function config for train/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [50, 150]
values: [0.1, 0.01, 0.001]
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/Aliproduct/
cls_label_path: ./dataset/Aliproduct/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 256
- RandCropImage:
size: 227
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/Aliproduct/
cls_label_path: ./dataset/Aliproduct/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 227
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 227
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# switch to the metrics below when evaluating by retrieval
# - Recallk:
# topk: [1]
# - mAP:
# - Precisionk:
# topk: [1]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
eval_mode: "retrieval"
epochs: 128
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
#feature postprocess
feature_normalize: False
feature_binarize: "round"
# model architecture
Arch:
name: "RecModel"
Backbone:
name: "VGG19Sigmoid"
pretrained: True
class_num: 48
Head:
name: "FC"
class_num: 10
embedding_size: 48
infer_output_key: "features"
infer_add_softmax: "false"
# loss function config for train/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Piecewise
learning_rate: 0.01
decay_epochs: [200]
values: [0.01, 0.001]
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/cifar10/
cls_label_path: ./dataset/cifar10/cifar10-2/train.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 256
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: ImageNetDataset
image_root: ./dataset/cifar10/
cls_label_path: ./dataset/cifar10/cifar10-2/test.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: ImageNetDataset
image_root: ./dataset/cifar10/
cls_label_path: ./dataset/cifar10/cifar10-2/database.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.4914, 0.4822, 0.4465]
std: [0.2023, 0.1994, 0.2010]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 512
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- mAP:
- Precisionk:
topk: [1, 5]
......@@ -52,6 +52,11 @@ class Engine(object):
self.config = config
self.eval_mode = self.config["Global"].get("eval_mode",
"classification")
if "Head" in self.config["Arch"]:
self.is_rec = True
else:
self.is_rec = False
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
......
......@@ -124,6 +124,13 @@ def cal_feature(evaler, name='gallery'):
feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)
# do binarize
if evaler.config["Global"].get("feature_binarize") == "round":
batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0
if evaler.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32")
if all_feas is None:
all_feas = batch_feas
......@@ -135,8 +142,10 @@ def cal_feature(evaler, name='gallery'):
all_image_id = paddle.concat([all_image_id, batch[1]])
if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]])
if evaler.use_dali:
dataloader_tmp.reset()
if paddle.distributed.get_world_size() > 1:
feat_list = []
img_id_list = []
......
......@@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step):
def forward(trainer, batch):
    """Run the trainer's model on one batch.

    Dispatch is driven by ``trainer.is_rec`` only; the value of
    ``trainer.eval_mode`` is deliberately not consulted here.

    Args:
        trainer: object exposing ``is_rec`` (bool) and a callable ``model``.
        batch: indexable batch; ``batch[0]`` is always passed to the model and
            ``batch[1]`` is passed as a second argument when ``is_rec`` is
            true (presumably images and labels -- confirm against the
            data loader).

    Returns:
        Whatever ``trainer.model`` returns.
    """
    if not trainer.is_rec:
        return trainer.model(batch[0])
    # Models flagged is_rec also receive the second batch element.
    return trainer.model(batch[0], batch[1])
......@@ -29,7 +29,7 @@ class CELoss(nn.Layer):
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
if target.ndim == 1 or target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
class DSHSDLoss(nn.Layer):
    """
    # DSHSD(IEEE ACCESS 2019)
    # paper [Deep Supervised Hashing Based on Stable Distribution](https://ieeexplore.ieee.org/document/8648432/)
    # [DSHSD] epoch:70, bit:48, dataset:cifar10-1, MAP:0.809, Best MAP: 0.809
    # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
    # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647

    Combines a pairwise distance term on tanh-squashed features with a
    classification term computed by an internal bias-free linear layer.

    Args:
        n_class: number of classes; used for one-hot encoding of the labels
            and as the output width of the internal linear layer.
        bit: feature (hash code) length; the pairwise margin is ``2 * bit``.
        alpha: weight applied to the pairwise distance term.
        multi_label: selects the sigmoid-style classification term instead of
            the softmax cross-entropy term.
    """

    def __init__(self, n_class, bit, alpha, multi_label=False):
        super(DSHSDLoss, self).__init__()
        self.m = 2 * bit
        self.alpha = alpha
        self.multi_label = multi_label
        self.n_class = n_class

        self.fc = paddle.nn.Linear(bit, n_class, bias_attr=False)

    def forward(self, input, label):
        """Compute the loss.

        Args:
            input: mapping whose ``"features"`` entry holds the batch features.
            label: integer class labels; flattened and one-hot encoded here.

        Returns:
            dict with the single key ``"dshsdloss"``.
        """
        feature = input["features"]
        feature = feature.tanh().astype("float32")

        # Squared Euclidean distance between every pair of samples in the batch.
        dist = paddle.sum(
            paddle.square(
                paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0)),
            axis=2)

        # label to one-hot
        label = paddle.flatten(label)
        label = paddle.nn.functional.one_hot(label,
                                             self.n_class).astype("float32")

        # s[i, j] == 1 when samples i and j share no class, else 0.
        s = (paddle.matmul(
            label, label, transpose_y=True) == 0).astype("float32")
        # Pull same-class pairs together; push other pairs apart up to margin m.
        Ld = (1 - s) / 2 * dist + s / 2 * (self.m - dist).clip(min=0)
        Ld = Ld.mean()

        logits = self.fc(feature)
        if self.multi_label:
            # multiple labels classification loss
            # (sigmoid cross-entropy written out on the raw logits)
            Lc = (logits - label * logits + (
                (1 + (-logits).exp()).log())).sum(axis=1).mean()
        else:
            # single labels classification loss
            # log_softmax avoids log(0) = -inf (and the resulting NaN after
            # multiplying by zero label entries) when softmax underflows.
            log_prob = paddle.nn.functional.log_softmax(logits, axis=-1)
            Lc = (-log_prob * label).sum(axis=1).mean()

        return {"dshsdloss": Lc + Ld * self.alpha}
class LCDSHLoss(nn.Layer):
    """
    # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
    # [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798
    # [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834

    Sum of a pairwise logistic term on feature inner products and a term
    penalising the gap between real-valued and sign-binarized features.

    Args:
        n_class: number of classes used to one-hot encode the labels.
        _lambda: weight applied to the second (binarization gap) term.
    """

    def __init__(self, n_class, _lambda):
        super(LCDSHLoss, self).__init__()
        self._lambda = _lambda
        self.n_class = n_class

    def forward(self, input, label):
        """Return ``{"lcdshloss": L1 + _lambda * L2}`` for the given batch.

        Args:
            input: mapping whose ``"features"`` entry holds the batch features.
            label: integer class labels; flattened and one-hot encoded here.
        """
        feature = input["features"]

        # label to one-hot
        onehot = paddle.nn.functional.one_hot(
            paddle.flatten(label), self.n_class).astype("float32")

        # +1 for pairs sharing a class, -1 otherwise.
        shares_class = (paddle.matmul(
            onehot, onehot, transpose_y=True) > 0).astype("float32")
        pair_sign = 2 * shares_class - 1

        # Halved pairwise inner products, clipped so exp() stays finite.
        half_inner = paddle.matmul(feature, feature, transpose_y=True) * 0.5
        half_inner = half_inner.clip(min=-50, max=50)
        pairwise_term = paddle.log(1 + paddle.exp(-pair_sign *
                                                  half_inner)).mean()

        # Same inner products computed on the sign-binarized features.
        codes = feature.sign()
        half_inner_codes = paddle.matmul(codes, codes, transpose_y=True) * 0.5

        squash = paddle.nn.Sigmoid()
        gap_term = (
            squash(half_inner) - squash(half_inner_codes)).pow(2).mean()

        return {"lcdshloss": pairwise_term + self._lambda * gap_term}
......@@ -16,7 +16,7 @@ from paddle import nn
import copy
from collections import OrderedDict
from .metrics import TopkAcc, mAP, mINP, Recallk
from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from .metrics import DistillationTopkAcc
from .metrics import GoogLeNetTopkAcc
......
......@@ -168,6 +168,47 @@ class Recallk(nn.Layer):
return metric_dict
class Precisionk(nn.Layer):
    """Retrieval precision@k metric.

    Args:
        topk: an int, list or tuple of cut-off ranks; a bare int is wrapped
            into a one-element list.
    """

    def __init__(self, topk=(1, 5)):
        super().__init__()
        assert isinstance(topk, (int, list, tuple))
        self.topk = [topk] if isinstance(topk, int) else topk

    def forward(self, similarities_matrix, query_img_id, gallery_img_id,
                keep_mask):
        """Compute precision at every configured k.

        Args:
            similarities_matrix: query-vs-gallery similarity scores, ranked
                per row in descending order.
            query_img_id: labels of the queries, compared against the ranked
                gallery labels.
            gallery_img_id: 2-D tensor of gallery labels (presumably
                [num_gallery, 1] -- confirm against the caller).
            keep_mask: optional mask; when given, only hits whose mask entry
                is true are counted.

        Returns:
            dict mapping ``"precision@k"`` to the value for each k in topk.
        """
        # Gallery indices for each query, best match first.
        ranking = paddle.argsort(similarities_matrix, axis=1, descending=True)

        # Lay the gallery labels out as one row per query so that they can
        # be gathered in ranked order.
        gallery_row = paddle.transpose(gallery_img_id, [1, 0])
        gallery_grid = paddle.broadcast_to(
            gallery_row, shape=[ranking.shape[0], gallery_row.shape[1]])
        ranked_labels = paddle.index_sample(gallery_grid, ranking)

        hits = paddle.equal(ranked_labels, query_img_id)
        if keep_mask is not None:
            keep_mask = paddle.index_sample(
                keep_mask.astype('float32'), ranking)
            hits = paddle.logical_and(hits, keep_mask.astype('bool'))
        hits = paddle.cast(hits, 'float32')

        # Cumulative hit count at each rank, averaged over queries and
        # divided by the rank (1-based).
        ranks = paddle.arange(gallery_img_id.shape[0]) + 1
        cumulative_hits = paddle.cumsum(hits, axis=1)
        precision_curve = (
            paddle.mean(cumulative_hits, axis=0) / ranks).numpy()

        return {
            "precision@{}".format(k): precision_curve[k - 1]
            for k in self.topk
        }
class DistillationTopkAcc(TopkAcc):
def __init__(self, model_key, feature_key=None, topk=(1, 5)):
super().__init__(topk=topk)
......
......@@ -8,3 +8,4 @@ visualdl >= 2.0.0b
scipy
scikit-learn==0.23.2
gast==0.3.3
faiss-cpu==1.7.1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册