提交 df03b4a0 编写于 作者: S stephon

add Binary general reocg configure

上级 cb96c82b
......@@ -61,6 +61,7 @@ from ppcls.arch.backbone.model_zoo.hardnet import HarDNet68, HarDNet85, HarDNet3
from ppcls.arch.backbone.model_zoo.cspnet import CSPDarkNet53
from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1
from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid
from ppcls.arch.backbone.variant_models.lcnet_variant import PPLCNet_x2_5Tanh
def get_apis():
......
from .resnet_variant import ResNet50_last_stage_stride1
from .vgg_variant import VGG19Sigmoid
from .lcnet_variant import PPLCNet_x2_5Tanh
import paddle
from paddle.nn import Sigmoid
from paddle.nn import Tanh
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x2_5
__all__ = ["PPLCNet_x2_5Tanh"]
class TanhSuffix(paddle.nn.Layer):
def __init__(self, origin_layer):
super(SigmoidSuffix, self).__init__()
self.origin_layer = origin_layer
self.tanh = Tanh()
def forward(self, input, res_dict=None, **kwargs):
x = self.origin_layer(input)
x = self.tanh(x)
return x
def PPLCNet_x2_5Tanh(pretrained=False, use_ssld=False, **kwargs):
def replace_function(origin_layer):
new_layer = TanhSuffix(origin_layer)
return new_layer
match_re = "linear_0"
model = PPLCNet_x2_5(pretrained=pretrained, use_ssld=use_ssld, **kwargs)
model.replace_sub(match_re, replace_function, True)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
eval_mode: retrieval
use_dali: False
to_static: False
#feature postprocess
feature_normalize: False
feature_binarize: "sign"
# model architecture
Arch:
name: RecModel
infer_output_key: features
infer_add_softmax: False
Backbone:
name: PPLCNet_x2_5Tanh
pretrained: True
use_ssld: True
class_num: 512
Head:
name: FC
embedding_size: &embedding_size 512
class_num: &n_class 185341
# loss function config for traing/eval process
Loss:
Train:
- DSHSDLoss:
weight: 1.0
n_class: *n_class
bit: *embedding_size
alpha: 0.1
Eval:
- DSHSDLoss:
weight: 1.0
n_class: *n_class
bit: *embedding_size
alpha: 0.1
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.04
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00001
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/all_data
cls_label_path: ./dataset/all_data/train_reg_all_data.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: VeriWild
image_root: ./dataset/Aliproduct/
cls_label_path: ./dataset/Aliproduct/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: VeriWild
image_root: ./dataset/Aliproduct/
cls_label_path: ./dataset/Aliproduct/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
......@@ -22,6 +22,8 @@ from .distillationloss import DistillationGTCELoss
from .distillationloss import DistillationDMLLoss
from .multilabelloss import MultiLabelLoss
from .deephashloss import DSHSDLoss, LCDSHLoss
class CombinedLoss(nn.Layer):
def __init__(self, config_list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册