Unverified commit 8d52387a, authored by Bin Lu, committed by GitHub

Merge pull request #1702 from Intsigstephon/develop

add deephash configure files and dch algorithm
......@@ -17,13 +17,14 @@ from .cosmargin import CosMargin
from .circlemargin import CircleMargin
from .fc import FC
from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
__all__ = ['build_gear']
def build_gear(config):
    support_dict = [
        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh'
    ]
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
......
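With 'Tanh' now registered, a RecModel neck can be selected purely from config. A minimal usage sketch follows; the import path and the eager call are assumptions for illustration, not part of this diff.

import paddle
from ppcls.arch.gears import build_gear  # import path assumed from the diff context

neck = build_gear({"name": "Tanh"})   # assumed to resolve to paddle.nn.Tanh(); no extra args
codes = neck(paddle.randn([4, 48]))   # squashes 48-d features into (-1, 1)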
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for train/eval process
Loss:
  Train:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10
  Eval:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10

Optimizer:
  name: SGD
  lr:
    name: Piecewise
    learning_rate: 0.005
    decay_epochs: [200]
    values: [0.005, 0.0005]
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/CIFAR10/
      cls_label_path: ./dataset/CIFAR10/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/test_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/train_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
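In this config the 48-dimensional backbone output acts as the hash code; with feature_binarize set to "sign", the retrieval features are presumably thresholded into binary codes before matching. A minimal sketch of that postprocess (illustrative only, not the exact PaddleClas implementation):

import paddle

def binarize_features(features):
    # features: [N, 48] float embeddings produced by the model above
    # sign() maps entries to {-1, 0, +1}; 0 occurs only for exactly-zero values
    return paddle.sign(features)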
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48
  Neck:
    name: Tanh
  Head:
    name: FC
    class_num: 10
    embedding_size: 48

# loss function config for train/eval process
Loss:
  Train:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05
  Eval:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/CIFAR10/
      cls_label_path: ./dataset/CIFAR10/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/test_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/train_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
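For orientation, this config wires RecModel as backbone -> Tanh neck -> FC head, and the DSHSD loss consumes both the neck output ("features") and the head output ("logits"). A rough sketch of that assumed data flow (simplified stand-in, not the actual RecModel code):

import paddle
import paddle.nn as nn

class RecModelSketch(nn.Layer):
    # Simplified stand-in for RecModel with Backbone=AlexNet(class_num=48),
    # Neck=Tanh, Head=FC(embedding_size=48, class_num=10); names assumed.
    def __init__(self, backbone, neck, head):
        super().__init__()
        self.backbone, self.neck, self.head = backbone, neck, head

    def forward(self, x):
        x = self.backbone(x)          # [N, 48] raw hash activations
        features = self.neck(x)       # Tanh squashes to (-1, 1), ready for sign()
        logits = self.head(features)  # [N, 10] class logits used by DSHSDLoss
        return {"features": features, "logits": logits}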
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False

  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False

  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for train/eval process
Loss:
  Train:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10
  Eval:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/CIFAR10/
      cls_label_path: ./dataset/CIFAR10/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/test_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: ImageNetDataset
        image_root: ./dataset/CIFAR10/
        cls_label_path: ./dataset/CIFAR10/train_list.txt
        transform_ops:
          - DecodeImage:
              to_rgb: True
              channel_first: False
          - ResizeImage:
              size: 224
          - NormalizeImage:
              scale: 1.0/255.0
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
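All three configs point ImageNetDataset at a CIFAR10 directory. That dataset type is assumed to read a plain label list: one line per image, an image path relative to image_root, a space, and an integer class id. The file names below are hypothetical examples of what train_list.txt would contain:

train/airplane_0001.png 0
train/automobile_0002.png 1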
......@@ -26,7 +26,9 @@ from .distillationloss import DistillationKLDivLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss
from .deephashloss import DCHLoss
class CombinedLoss(nn.Layer):
......
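As a hedged sketch of how the Loss sections in the configs above are assumed to be consumed: CombinedLoss instantiates each named loss, pops its "weight", and sums the weighted terms. Roughly, for one entry "- DCHLoss: {weight: 1.0, gamma: 20.0, _lambda: 0.1, n_class: 10}":

from ppcls.loss.deephashloss import DCHLoss  # import path as added in this PR

weight = 1.0
loss_fn = DCHLoss(gamma=20.0, _lambda=0.1, n_class=10)
# total_loss = weight * loss_fn(model_output, labels)["dchloss"]  (assumed usage)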
......@@ -15,78 +15,132 @@
import paddle
import paddle.nn as nn
class DSHSDLoss(nn.Layer):
    """
    # DSHSD(IEEE ACCESS 2019)
    # paper [Deep Supervised Hashing Based on Stable Distribution](https://ieeexplore.ieee.org/document/8648432/)
    # [DSHSD] epoch:70, bit:48, dataset:cifar10-1, MAP:0.809, Best MAP: 0.809
    # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
    # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
    """

    def __init__(self, alpha, multi_label=False):
        super(DSHSDLoss, self).__init__()
        self.alpha = alpha
        self.multi_label = multi_label

    def forward(self, input, label):
        features = input["features"]
        logits = input["logits"]

        # pairwise squared Euclidean distances between features
        features_temp1 = paddle.unsqueeze(features, 1)
        features_temp2 = paddle.unsqueeze(features, 0)
        dist = features_temp1 - features_temp2
        dist = paddle.square(dist)
        dist = paddle.sum(dist, axis=2)

        # label to one-hot
        n_class = logits.shape[1]
        labels = paddle.nn.functional.one_hot(label, n_class)
        labels = labels.squeeze().astype("float32")

        # s_ij = 1 for pairs that share no label, 0 otherwise
        s = paddle.matmul(labels, labels, transpose_y=True)
        s = (s == 0).astype("float32")
        margin = 2 * features.shape[1]
        Ld = (1 - s) / 2 * dist + s / 2 * (margin - dist).clip(min=0)
        Ld = Ld.mean()

        if self.multi_label:
            # multi-label classification loss
            Lc_temp = (1 + (-logits).exp()).log()
            Lc = (logits - labels * logits + Lc_temp).sum(axis=1)
        else:
            # single-label classification loss
            probs = paddle.nn.functional.softmax(logits)
            Lc = (-probs.log() * labels).sum(axis=1)
        Lc = Lc.mean()

        loss = Lc + Ld * self.alpha
        return {"dshsdloss": loss}
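In equation form, the loss implemented above is: for $K$-bit features $h_i$, pairwise squared distances $d_{ij}=\lVert h_i-h_j\rVert_2^2$, and $s_{ij}=1$ when samples $i,j$ share no label (else $0$),

$$L_d = \operatorname{mean}_{i,j}\Big[\tfrac{1-s_{ij}}{2}\,d_{ij} + \tfrac{s_{ij}}{2}\max\big(2K - d_{ij},\,0\big)\Big],\qquad L = L_c + \alpha\,L_d,$$

where $L_c$ is the (softmax or multi-label) classification loss on the logits and $\alpha$ is the configured alpha.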
class LCDSHLoss(nn.Layer):
    """
    # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)
    # [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798
    # [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834
    """

    def __init__(self, n_class, _lambda):
        super(LCDSHLoss, self).__init__()
        self._lambda = _lambda
        self.n_class = n_class

    def forward(self, input, label):
        features = input["features"]

        # label to one-hot
        labels = paddle.nn.functional.one_hot(label, self.n_class)
        labels = labels.squeeze().astype("float32")

        # s_ij = +1 for pairs sharing a label, -1 otherwise
        s = paddle.matmul(labels, labels, transpose_y=True)
        s = 2 * (s > 0).astype("float32") - 1

        inner_product = paddle.matmul(features, features, transpose_y=True)
        inner_product = inner_product * 0.5
        inner_product = inner_product.clip(min=-50, max=50)
        L1 = paddle.log(1 + paddle.exp(-s * inner_product))
        L1 = L1.mean()

        binary_features = features.sign()
        inner_product_ = paddle.matmul(
            binary_features, binary_features, transpose_y=True)
        inner_product_ = inner_product_ * 0.5

        sigmoid = paddle.nn.Sigmoid()
        L2 = (sigmoid(inner_product) - sigmoid(inner_product_)).pow(2)
        L2 = L2.mean()

        loss = L1 + self._lambda * L2
        return {"lcdshloss": loss}
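Equivalently, with $s_{ij}\in\{+1,-1\}$ for same/different class, $\Theta_{ij}=\tfrac12 h_i^\top h_j$ (clipped to $[-50,50]$), and binary codes $b_i=\operatorname{sign}(h_i)$:

$$L_1 = \operatorname{mean}_{i,j}\,\log\big(1+e^{-s_{ij}\Theta_{ij}}\big),\quad L_2 = \operatorname{mean}_{i,j}\big(\sigma(\Theta_{ij})-\sigma(\tfrac12 b_i^\top b_j)\big)^2,\quad L = L_1 + \lambda\,L_2.$$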
class DCHLoss(paddle.nn.Layer):
    """
    # paper [Deep Cauchy Hashing for Hamming Space Retrieval](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-cauchy-hashing-cvpr18.pdf)
    """

    def __init__(self, gamma, _lambda, n_class):
        super(DCHLoss, self).__init__()
        self.gamma = gamma
        self._lambda = _lambda
        self.n_class = n_class

    def distance(self, feature_i, feature_j):
        assert feature_i.shape[1] == feature_j.shape[
            1], "feature len of feature_i and feature_j is different, please check whether the features are right"
        K = feature_i.shape[1]
        inner_product = paddle.matmul(feature_i, feature_j, transpose_y=True)

        len_i = feature_i.pow(2).sum(axis=1, keepdim=True).pow(0.5)
        len_j = feature_j.pow(2).sum(axis=1, keepdim=True).pow(0.5)
        norm = paddle.matmul(len_i, len_j, transpose_y=True)
        cos = inner_product / norm.clip(min=0.0001)
        # cosine distance rescaled to the range [0, K]
        dist = (1 - cos.clip(max=0.99)) * K / 2
        return dist

    def forward(self, input, label):
        features = input["features"]
        labels = paddle.nn.functional.one_hot(label, self.n_class)
        labels = labels.squeeze().astype("float32")

        s = paddle.matmul(labels, labels, transpose_y=True).astype("float32")
        if (1 - s).sum() != 0 and s.sum() != 0:
            # re-weight positive and negative pairs so both contribute comparably
            positive_w = s * s.numel() / s.sum()
            negative_w = (1 - s) * s.numel() / (1 - s).sum()
            w = positive_w + negative_w
        else:
            w = 1

        dist_matrix = self.distance(features, features)
        cauchy_loss = w * (s * paddle.log(dist_matrix / self.gamma) +
                           paddle.log(1 + self.gamma / dist_matrix))

        all_one = paddle.ones_like(features, dtype="float32")
        dist_to_one = self.distance(features.abs(), all_one)
        quantization_loss = paddle.log(1 + dist_to_one / self.gamma)

        loss = cauchy_loss.mean() + self._lambda * quantization_loss.mean()
        return {"dchloss": loss}
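The DCHLoss above computes the paper's Cauchy cross-entropy plus a quantization term. With cosine-based distance $d(h_i,h_j)=\tfrac{K}{2}\,(1-\cos(h_i,h_j))$, similarity $s_{ij}\in\{0,1\}$, class-balance weights $w_{ij}$, and Cauchy scale $\gamma$:

$$L = \operatorname{mean}_{i,j}\, w_{ij}\Big[s_{ij}\log\tfrac{d(h_i,h_j)}{\gamma} + \log\Big(1+\tfrac{\gamma}{d(h_i,h_j)}\Big)\Big] + \lambda\,\operatorname{mean}_{i}\,\log\Big(1+\tfrac{d(\lvert h_i\rvert,\mathbf{1})}{\gamma}\Big),$$

where $\mathbf{1}$ is the all-ones code, so the second term pushes each $\lvert h_i\rvert$ toward binary values.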
......@@ -22,6 +22,51 @@ import paddle
from paddle import optimizer as optim  # needed by optim.SGD below
from ppcls.utils import logger
class SGD(object):
    """
    Args:
        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode.
            The default value is None in static mode, at which point all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has already set a regularizer using :ref:`api_fluid_ParamAttr`,
            the regularization setting here in the optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in the optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, an instance of
            some derived class of ``GradientClipBase``. There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): The default value is None. Normally there is no need for the user
            to set this property.
    """

    def __init__(self,
                 learning_rate=0.001,
                 weight_decay=None,
                 grad_clip=None,
                 name=None):
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.name = name

    def __call__(self, model_list):
        # model_list is None in static graph
        parameters = sum([m.parameters() for m in model_list],
                         []) if model_list else None
        opt = optim.SGD(learning_rate=self.learning_rate,
                        parameters=parameters,
                        weight_decay=self.weight_decay,
                        grad_clip=self.grad_clip,
                        name=self.name)
        return opt
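A brief usage sketch for the wrapper above (values are illustrative; model stands for any paddle.nn.Layer being trained):

sgd_builder = SGD(learning_rate=0.005, weight_decay=0.00001)
optimizer = sgd_builder([model])  # collects model.parameters() and returns a paddle.optimizer.SGD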
class Momentum(object):
"""
Simple Momentum optimizer with velocity state.
......