Unverified commit aecbf40b, authored by Wei Shengyu, committed by GitHub

Merge pull request #762 from Intsigstephon/develop_reg

add RecModel for retrieval
......
@@ -18,11 +18,14 @@ import importlib
import paddle.nn as nn

from . import backbone
from . import head
from .backbone import *
from ppcls.arch.loss_metrics.loss import *
from .head import *
from .utils import *

__all__ = ["build_model", "RecModel"]


def build_model(config):
    config = copy.deepcopy(config)
......
@@ -35,31 +38,31 @@ def build_model(config):
 class RecModel(nn.Layer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
         backbone_name = backbone_config.pop("name")
-        self.backbone = getattr(backbone_name)(**backbone_config)
-        if "backbone_stop_layer" in config:
-            backbone_stop_layer = config["backbone_stop_layer"]
-            self.backbone.stop_layer(backbone_stop_layer)
+        self.backbone = eval(backbone_name)(**backbone_config)
 
-        if "Neck" in config:
-            neck_config = config["Neck"]
-            neck_name = neck_config.pop("name")
-            self.neck = getattr(neck_name)(**neck_config)
+        assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task, \
+            please specify a Stoplayer config"
+        stop_layer_config = config["Stoplayer"]
+        self.backbone.stop_after(stop_layer_config["name"])
+
+        if stop_layer_config.get("embedding_size", 0) > 0:
+            self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
+            embedding_size = stop_layer_config["embedding_size"]
         else:
             self.neck = None
+            embedding_size = stop_layer_config["output_dim"]
 
-        if "Head" in config:
-            head_config = config["Head"]
-            head_name = head_config.pop("name")
-            self.head = getattr(head_name)(**head_config)
-        else:
-            self.head = None
+        assert "Head" in config, "Head should be specified in retrieval task, \
+            please specify a Head config"
+        config["Head"]["embedding_size"] = embedding_size
+        self.head = build_head(config["Head"])
 
-    def forward(self, x):
-        y = self.backbone(x)
+    def forward(self, x, label):
+        x = self.backbone(x)
         if self.neck is not None:
-            y = self.neck(y)
-        if self.head is not None:
-            y = self.head(y)
-        return y
+            x = self.neck(x)
+        y = self.head(x, label)
+        return {"features": x, "logits": y}
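For readers skimming the diff, here is a minimal, self-contained sketch of the wiring the new RecModel performs: backbone features, an optional Linear neck, then a head that takes (embedding, label) and returns logits, with the forward pass returning both features and logits. TinyBackbone and TinyHead are hypothetical stand-ins, not classes from this repository; a real backbone must expose stop_after(), and the real head here is ArcMargin.

import paddle
import paddle.nn as nn

class TinyBackbone(nn.Layer):
    """Stand-in for a backbone truncated after a layer such as flatten_0."""
    def __init__(self, output_dim=2048):
        super().__init__()
        self.proj = nn.Conv2D(3, output_dim, 1)
        self.pool = nn.AdaptiveAvgPool2D(1)

    def forward(self, x):
        return paddle.flatten(self.pool(self.proj(x)), start_axis=1)

class TinyHead(nn.Layer):
    """Stand-in for ArcMargin: any head mapping (embedding, label) -> logits."""
    def __init__(self, embedding_size=512, class_num=1000):
        super().__init__()
        self.fc = nn.Linear(embedding_size, class_num)

    def forward(self, x, label):
        # A real ArcMargin would apply the angular margin using `label` here.
        return self.fc(x)

backbone = TinyBackbone(output_dim=2048)
neck = nn.Linear(2048, 512)   # built by RecModel when embedding_size > 0
head = TinyHead(embedding_size=512)

x = paddle.rand([4, 3, 224, 224])
label = paddle.randint(0, 1000, [4])
feat = neck(backbone(x))
out = {"features": feat, "logits": head(feat, label)}
print(out["features"].shape, out["logits"].shape)  # [4, 512] [4, 1000]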
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fc import FC  # assumed relative path: FC is the module shown below; FPN is referenced but not part of this diff

__all__ = ['build_neck']


def build_neck(config):
    support_dict = ['FPN', 'FC']
    module_name = config.pop('name')
    assert module_name in support_dict, Exception(
        'neck only supports {}'.format(support_dict))
    module_class = eval(module_name)(**config)
    return module_class
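A quick usage sketch for the factory above (assuming FC is importable in this module, as noted). Since build_neck pops the "name" key, pass a copy if the config dict will be reused:

neck_config = {"name": "FC", "input_dim": 2048, "embedding_size": 512}
neck = build_neck(dict(neck_config))  # -> FC(input_dim=2048, embedding_size=512)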
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
class FC(nn.Layer):
    def __init__(self, input_dim, embedding_size):
        super(FC, self).__init__()
        self.input_dim = input_dim
        self.embedding_size = embedding_size
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.XavierNormal())
        self.fc = paddle.nn.Linear(
            self.input_dim, self.embedding_size, weight_attr=weight_attr)

    def forward(self, x):
        x = self.fc(x)
        return x
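And a shape check for the FC neck itself (a sketch): it maps 2048-d backbone features to a 512-d embedding through an Xavier-initialized linear layer.

fc = FC(input_dim=2048, embedding_size=512)
feat = paddle.rand([8, 2048])
print(fc(feat).shape)  # [8, 512]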
......
@@ -18,7 +18,7 @@ Global:
# model architecture
Arch:
  name: "ResNet50"

# loss function config for training/eval process
Loss:
  Train:
......
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: "./output/"
  device: "gpu"
  class_num: 1000
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  image_shape: [3, 224, 224]
  infer_imgs:

# model architecture
Arch:
  name: "RecModel"
  Backbone:
    name: "ResNet50"
  Stoplayer:
    name: "flatten_0"
    output_dim: 2048
    embedding_size: 512
  Head:
    name: "ArcMargin"
    margin: 0.5
    scale: 80

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]
  regularizer:
    name: 'L2'
    coeff: 0.0001

# data loader for train and eval
DataLoader:
  Train:
    # Dataset:
    # Sampler:
    # Loader:
    batch_size: 256
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/train_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
      - DecodeImage:
          to_rgb: True
          channel_first: False
      - RandCropImage:
          size: 224
      - RandFlipImage:
          flip_code: 1
      - NormalizeImage:
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: ''
      - ToCHWImage:

  Eval:
    # TODO: modify to the latest trainer
    # Dataset:
    # Sampler:
    # Loader:
    batch_size: 128
    num_workers: 4
    file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
    shuffle_seed: 0
    transforms:
      - DecodeImage:
          to_rgb: True
          channel_first: False
      - ResizeImage:
          resize_short: 256
      - CropImage:
          size: 224
      - NormalizeImage:
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: ''
      - ToCHWImage:

Metric:
  Train:
    - Topk:
        k: [1, 5]
  Eval:
    - Topk:
        k: [1, 5]
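Tracing the Arch section above through the new RecModel (a hedged sketch: it assumes build_model follows the usual pop-the-"name"-and-dispatch pattern, and the config file name here is hypothetical):

import yaml
from ppcls.arch import build_model

with open("ResNet50_retrieval.yaml") as f:  # hypothetical config path
    config = yaml.safe_load(f)

model = build_model(config["Arch"])
# Internally RecModel then:
#   - builds ResNet50 and truncates it after "flatten_0" (2048-d output),
#   - adds nn.Linear(2048, 512) as the neck, since embedding_size (512) > 0,
#   - injects embedding_size=512 into the Head config and builds ArcMargin.
# forward(x, label) returns {"features": N x 512, "logits": head output}.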