Commit 2507be1a authored by lubin

add deephash configs and DCH algorithm

Parent: a25d37ea
......@@ -17,13 +17,14 @@ from .cosmargin import CosMargin
from .circlemargin import CircleMargin
from .fc import FC
from .vehicle_neck import VehicleNeck
from paddle.nn import Tanh
__all__ = ['build_gear']

def build_gear(config):
    support_dict = [
        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck'
        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh'
    ]
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
......
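For orientation, here is a minimal sketch (not code from this commit; the helper name `build_neck_sketch` and the literal config dict are illustrative assumptions) of how the newly registered `Tanh` entry is consumed: `build_gear` pops the `name` key from a Neck config and instantiates the matching class, so `Neck: {name: Tanh}` in a YAML config resolves to `paddle.nn.Tanh`.

```python
# Illustrative sketch only: mimics how a Neck config resolves now that 'Tanh'
# is in the support list of build_gear.
from paddle.nn import Tanh


def build_neck_sketch(config):
    support = {"Tanh": Tanh}              # subset of the real support_dict
    module_name = config.pop("name")      # remaining keys become constructor kwargs
    return support[module_name](**config)


neck = build_neck_sketch({"name": "Tanh"})  # -> paddle.nn.Tanh instance
```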
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False
  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False
  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for training/eval process
Loss:
  Train:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10
  Eval:
    - DCHLoss:
        weight: 1.0
        gamma: 20.0
        _lambda: 0.1
        n_class: 10

Optimizer:
  name: SGD
  lr:
    name: Piecewise
    learning_rate: 0.005
    decay_epochs: [200]
    values: [0.005, 0.0005]
  regularizer:
    name: 'L2'
    coeff: 0.00001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
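As a hedged illustration of what the `feature_binarize: "sign"` setting in the Global section implies at retrieval time: the continuous hash features are mapped to ±1 codes with the sign function before indexing. The tensors below are random placeholders, not the engine's actual post-processing code.

```python
import paddle

# Sketch: binarize continuous 48-dim hash features with sign(), as implied by
# feature_binarize: "sign" (illustrative tensors only).
features = paddle.randn([4, 48])       # stand-in for backbone outputs
binary_codes = paddle.sign(features)   # values in {-1, 0, +1}
print(binary_codes.shape)              # [4, 48]
```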
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False
  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False
  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48
  Neck:
    name: Tanh
  Head:
    name: FC
    class_num: 10
    embedding_size: 48

# loss function config for training/eval process
Loss:
  Train:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05
  Eval:
    - DSHSDLoss:
        weight: 1.0
        alpha: 0.05

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
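To make the Arch section above concrete, here is a rough sketch (assumptions: the class name `DeepHashSketch` and the use of `paddle.vision.models.alexnet` are stand-ins, not the real RecModel wiring in ppcls): the AlexNet backbone with `class_num: 48` emits a 48-dim hash feature, the Tanh neck squashes it into (-1, 1), and the FC head maps the 48-dim embedding to the 10 CIFAR-10 classes for DSHSD's classification term.

```python
import paddle
import paddle.nn as nn


# Rough sketch of the configured pipeline; the real RecModel composition lives in ppcls.
class DeepHashSketch(nn.Layer):
    def __init__(self, bit=48, n_class=10):
        super().__init__()
        self.backbone = paddle.vision.models.alexnet(num_classes=bit)  # stand-in for the PaddleClas AlexNet
        self.neck = nn.Tanh()                                          # Neck: Tanh
        self.head = nn.Linear(bit, n_class)                            # Head: FC (embedding_size=48, class_num=10)

    def forward(self, x):
        features = self.neck(self.backbone(x))     # hash features in (-1, 1)
        logits = self.head(features)               # classification logits for DSHSDLoss
        return {"features": features, "logits": logits}


model = DeepHashSketch()
out = model(paddle.randn([2, 3, 224, 224]))
```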
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output
  device: gpu
  save_interval: 15
  eval_during_train: True
  eval_interval: 15
  epochs: 150
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  eval_mode: retrieval
  use_dali: False
  to_static: False
  # feature postprocess
  feature_normalize: False
  feature_binarize: "sign"

# model architecture
Arch:
  name: RecModel
  infer_output_key: features
  infer_add_softmax: False
  Backbone:
    name: AlexNet
    pretrained: True
    class_num: 48

# loss function config for training/eval process
Loss:
  Train:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10
  Eval:
    - LCDSHLoss:
        weight: 1.0
        _lambda: 3
        n_class: 10

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    learning_rate: 0.00001
    decay_epochs: [200]
    values: [0.00001, 0.000001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: CustomizedCifar10
      mode: 'train'
    sampler:
      batch_size: 128
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: CustomizedCifar10
        mode: 'test'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

    Gallery:
      dataset:
        name: CustomizedCifar10
        mode: 'train'
      sampler:
        batch_size: 128
        drop_last: False
        shuffle: False
      loader:
        num_workers: 4
        use_shared_memory: True

Metric:
  Eval:
    - mAP: {}
    - Recallk:
        topk: [1, 5]
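As a quick illustration of what the `n_class` entry in this config feeds into (see the LCDSHLoss diff further down), the loss first turns integer labels into a ±1 pairwise agreement matrix; the tiny standalone sketch below uses made-up labels purely for illustration.

```python
import paddle

# Sketch: build the +/-1 pairwise similarity matrix that LCDSHLoss derives from labels.
labels = paddle.to_tensor([0, 1, 0, 2])
one_hot = paddle.nn.functional.one_hot(labels, num_classes=10).astype("float32")
s = 2 * (paddle.matmul(one_hot, one_hot, transpose_y=True) > 0).astype("float32") - 1
print(s)   # +1 where two samples share a class, -1 otherwise
```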
......@@ -28,6 +28,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
......@@ -4,6 +4,7 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.vision.datasets import Cifar10
from paddle.vision import transforms
from paddle.dataset.common import _check_exists_and_download
import numpy as np
import os
from PIL import Image

class CustomizedCifar10(Cifar10):
    def __init__(self,
                 data_file=None,
                 mode='train',
                 download=True,
                 backend=None):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()

        if backend is None:
            backend = paddle.vision.get_image_backend()
        if backend not in ['pil', 'cv2']:
            raise ValueError(
                "Expected backend are one of ['pil', 'cv2'], but got {}"
                .format(backend))
        self.backend = backend

        self._init_url_md5_flag()

        self.data_file = data_file
        if self.data_file is None:
            assert download, "data_file is not set and downloading automatically is disabled"
            self.data_file = _check_exists_and_download(
                data_file, self.data_url, self.data_md5, 'cifar', download)

        self.transform = transforms.Compose([
            transforms.Resize(224), transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self._load_data()
        self.dtype = paddle.get_default_dtype()

    def __getitem__(self, index):
        img, target = self.data[index]
        img = np.reshape(img, [3, 32, 32])
        img = img.transpose([1, 2, 0]).astype("uint8")
        img = Image.fromarray(img)
        img = self.transform(img)
        return (img, target)
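A small usage sketch (assuming a working paddle installation and network access for the CIFAR-10 download) showing how this dataset plugs into a standard `paddle.io.DataLoader`; the PaddleClas engine normally builds this loader from the YAML configs above rather than calling it by hand like this.

```python
import paddle
from ppcls.data.dataloader.customized_cifar10 import CustomizedCifar10

# Sketch: iterate the dataset the way the configs above consume it.
train_set = CustomizedCifar10(mode='train', download=True)
loader = paddle.io.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)

for images, labels in loader:
    print(images.shape)   # [128, 3, 224, 224] after Resize(224) + Normalize
    break
```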
......@@ -24,7 +24,9 @@ from .distillationloss import DistillationDistanceLoss
from .distillationloss import DistillationRKDLoss
from .multilabelloss import MultiLabelLoss
from .deephashloss import DSHSDLoss, LCDSHLoss
from .deephashloss import DSHSDLoss
from .deephashloss import LCDSHLoss
from .deephashloss import DCHLoss
class CombinedLoss(nn.Layer):
......
......@@ -15,6 +15,7 @@
import paddle
import paddle.nn as nn
class DSHSDLoss(nn.Layer):
"""
# DSHSD(IEEE ACCESS 2019)
......@@ -23,6 +24,7 @@ class DSHSDLoss(nn.Layer):
    # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815
    # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647
    """

    def __init__(self, alpha, multi_label=False):
        super(DSHSDLoss, self).__init__()
        self.alpha = alpha
......@@ -37,9 +39,9 @@ class DSHSDLoss(nn.Layer):
            axis=2)

        # label to one-hot
        label = paddle.flatten(label)
        n_class = logits.shape[1]
        label = paddle.nn.functional.one_hot(label, n_class).astype("float32")
        label = paddle.nn.functional.one_hot(
            label, n_class).astype("float32").squeeze()

        s = (paddle.matmul(
            label, label, transpose_y=True) == 0).astype("float32")
......@@ -65,6 +67,7 @@ class LCDSHLoss(nn.Layer):
    # [LCDSH] epoch:145, bit:48, dataset:cifar10-1, MAP:0.798, Best MAP: 0.798
    # [LCDSH] epoch:183, bit:48, dataset:nuswide_21, MAP:0.833, Best MAP: 0.834
    """

    def __init__(self, n_class, _lambda):
        super(LCDSHLoss, self).__init__()
        self._lambda = _lambda
......@@ -73,11 +76,11 @@ class LCDSHLoss(nn.Layer):
    def forward(self, input, label):
        feature = input["features"]

        # label to one-hot
        label = paddle.flatten(label)
        label = paddle.nn.functional.one_hot(label, self.n_class).astype("float32")
        s = 2 * (paddle.matmul(label, label, transpose_y=True) > 0).astype("float32") - 1
        label = paddle.nn.functional.one_hot(
            label, self.n_class).astype("float32").squeeze()
        s = 2 * (paddle.matmul(
            label, label, transpose_y=True) > 0).astype("float32") - 1

        inner_product = paddle.matmul(feature, feature, transpose_y=True) * 0.5
        inner_product = inner_product.clip(min=-50, max=50)
......@@ -90,3 +93,58 @@ class LCDSHLoss(nn.Layer):
return {"lcdshloss": L1 + self._lambda * L2}
class DCHLoss(paddle.nn.Layer):
"""
# paper [Deep Cauchy Hashing for Hamming Space Retrieval]
URL:(http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-cauchy-hashing-cvpr18.pdf)
# [DCH] epoch:150, bit:48, dataset:cifar10-1, MAP:0.768, Best MAP: 0.810
# [DCH] epoch:150, bit:48, dataset:coco, MAP:0.665, Best MAP: 0.670
# [DCH] epoch:150, bit:48, dataset:imagenet, MAP:0.586, Best MAP: 0.586
# [DCH] epoch:150, bit:48, dataset:nuswide_21, MAP:0.778, Best MAP: 0.794
"""
def __init__(self, gamma, _lambda, n_class):
super(DCHLoss, self).__init__()
self.gamma = gamma
self._lambda = _lambda
self.n_class = n_class
def d(self, hi, hj):
assert hi.shape[1] == hj.shape[
1], "feature len of hi and hj is different, please check whether the featurs are right"
K = hi.shape[1]
inner_product = paddle.matmul(hi, hj, transpose_y=True)
len_i = hi.pow(2).sum(axis=1, keepdim=True).pow(0.5)
len_j = hj.pow(2).sum(axis=1, keepdim=True).pow(0.5)
norm = paddle.matmul(len_i, len_j, transpose_y=True)
cos = inner_product / norm.clip(min=0.0001)
return (1 - cos.clip(max=0.99)) * K / 2
def forward(self, input, label):
u = input["features"]
y = paddle.nn.functional.one_hot(
label, self.n_class).astype("float32").squeeze()
s = paddle.matmul(y, y, transpose_y=True).astype("float32")
if (1 - s).sum() != 0 and s.sum() != 0:
positive_w = s * s.numel() / s.sum()
negative_w = (1 - s) * s.numel() / (1 - s).sum()
w = positive_w + negative_w
else:
w = 1
d_hi_hj = self.d(u, u)
cauchy_loss = w * (s * paddle.log(d_hi_hj / self.gamma) +
paddle.log(1 + self.gamma / d_hi_hj))
all_one = paddle.ones_like(u, dtype="float32")
quantization_loss = paddle.log(1 + self.d(u.abs(), all_one) /
self.gamma)
loss = cauchy_loss.mean() + self._lambda * quantization_loss.mean()
return {"dchloss": loss}
......@@ -22,6 +22,53 @@ import paddle
from ppcls.utils import logger
class SGD(object):
    """
    Args:
        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
            This parameter is required in dygraph mode.
            The default value is None in static mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
            It can be a float value as coeff of L2 regularization or
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already,
            the regularization setting here in optimizer will be ignored for this parameter.
            Otherwise, the regularization setting here in optimizer will take effect.
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        name (str, optional): The default value is None. Normally there is no need for user
            to set this property.
    """

    def __init__(self,
                 learning_rate=0.001,
                 parameters=None,
                 weight_decay=None,
                 grad_clip=None,
                 name=None):
        self.learning_rate = learning_rate
        self.parameters = parameters
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.name = name

    def __call__(self, model_list):
        # model_list is None in static graph
        parameters = sum([m.parameters() for m in model_list],
                         []) if model_list else None
        opt = optim.SGD(learning_rate=self.learning_rate,
                        parameters=parameters,
                        weight_decay=self.weight_decay,
                        grad_clip=self.grad_clip,
                        name=self.name)
        return opt
class Momentum(object):
"""
Simple Momentum optimizer with velocity state.
......
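A brief, hedged sketch of how this wrapper style is typically invoked: the YAML Optimizer section is turned into keyword arguments, the learning-rate object built from the `lr` sub-config is passed in, and the instance is then called with the model list to obtain the underlying paddle optimizer. The import path and the direct `PiecewiseDecay`/`L2Decay` construction below are assumptions mirroring the DCH config, not the actual builder code.

```python
import paddle
import paddle.nn as nn
from ppcls.optimizer.optimizer import SGD   # assumed module path for the wrapper above

# Sketch: mirror the DCH config (Piecewise lr 0.005 -> 0.0005, 'L2' regularizer, coeff 0.00001).
model = nn.Linear(48, 10)                    # stand-in for the real RecModel
lr = paddle.optimizer.lr.PiecewiseDecay(boundaries=[200], values=[0.005, 0.0005])
sgd_builder = SGD(learning_rate=lr,
                  weight_decay=paddle.regularizer.L2Decay(coeff=1e-5))
optimizer = sgd_builder([model])             # -> paddle.optimizer.SGD instance
```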