Commit bba0cf8f authored by D dongshuilong

add CompCars train

Parent d58fd3b7
......@@ -21,7 +21,7 @@ from . import backbone
from . import head
from .backbone import *
from .head import *
from .utils import *
__all__ = ["build_model", "RecModel"]
......@@ -43,20 +43,24 @@ class RecModel(nn.Layer):
backbone_name = backbone_config.pop("name")
self.backbone = eval(backbone_name)(**backbone_config)
assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
assert "Stoplayer" in config, "Stoplayer should be specified in retrieval task \
please specified a Stoplayer config"
stop_layer_config = config["Stoplayer"]
self.backbone.stop_after(stop_layer_config["name"])
if stop_layer_config.get("embedding_size", 0) > 0:
            # self.neck = nn.Linear(stop_layer_config["output_dim"], stop_layer_config["embedding_size"])
            # NOTE: nn.Conv2D requires a kernel_size; a 1x1 conv is assumed here.
            self.neck = nn.Conv2D(stop_layer_config["output_dim"],
                                  stop_layer_config["embedding_size"], 1)
embedding_size = stop_layer_config["embedding_size"]
else:
self.neck = None
embedding_size = stop_layer_config["output_dim"]
assert "Head" in config, "Head should be specified in retrieval task \
assert "Head" in config, "Head should be specified in retrieval task \
please specify a Head config"
config["Head"]["embedding_size"] = embedding_size
self.head = build_head(config["Head"])
......@@ -65,4 +69,4 @@ class RecModel(nn.Layer):
if self.neck is not None:
x = self.neck(x)
y = self.head(x, label)
return {"features":x, "logits":y}
return {"features": x, "logits": y}
......@@ -16,35 +16,44 @@ import paddle
import paddle.nn as nn
import math
class ArcMargin(nn.Layer):
    def __init__(self,
                 embedding_size,
                 class_num,
                 margin=0.5,
                 scale=80.0,
                 easy_margin=False):
super(ArcMargin, self).__init__()
        self.embedding_size = embedding_size
        self.class_num = class_num
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.XavierNormal())
        self.fc = nn.Linear(
            self.embedding_size,
            self.class_num,
            weight_attr=weight_attr,
            bias_attr=False)
def forward(self, input, label):
        input_norm = paddle.sqrt(
            paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm)
weight = self.fc.weight
        weight_norm = paddle.sqrt(
            paddle.sum(paddle.square(weight), axis=0, keepdim=True))
weight = paddle.divide(weight, weight_norm)
        cos = paddle.matmul(input, weight)
        sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
cos_m = math.cos(self.margin)
sin_m = math.sin(self.margin)
        phi = cos * cos_m - sin * sin_m
th = math.cos(self.margin) * (-1)
mm = math.sin(self.margin) * self.margin
......@@ -55,11 +64,12 @@ class ArcMargin(nn.Layer):
one_hot = paddle.nn.functional.one_hot(label, self.class_num)
one_hot = paddle.squeeze(one_hot, axis=[1])
        output = paddle.multiply(one_hot, phi) + paddle.multiply(
            (1.0 - one_hot), cos)
        output = output * self.scale
return output
def _paddle_where_more_than(self, target, limit, x, y):
        mask = paddle.cast(x=(target > limit), dtype='float32')
output = paddle.multiply(mask, x) + paddle.multiply((1.0 - mask), y)
return output
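A hedged smoke test for the head above (margin/scale values taken from the config below; input shapes and values are assumptions, not part of this commit):

import paddle

# label must have shape [N, 1] so the squeeze over axis 1 in forward() applies
head = ArcMargin(embedding_size=512, class_num=431, margin=0.15, scale=32)
feats = paddle.randn([8, 512])
label = paddle.randint(0, 431, [8, 1])
logits = head(feats, label)  # [8, 431] margin-adjusted cosine logits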
......@@ -46,8 +46,8 @@ class CELoss(nn.Layer):
if self.epsilon is not None:
class_num = logits.shape[-1]
label = self._labelsmoothing(label, class_num)
            x = -F.log_softmax(logits, axis=-1)
            loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == logits.shape[-1]:
label = F.softmax(label, axis=-1)
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
class_num: 431
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 160
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
# model architecture
RecModel:
  Backbone:
    name: "ResNet50"
  Stoplayer:
    name: "adaptive_avg_pool2d_0"
    output_dim: 2048  # assumed: ResNet50's channel width at this layer; nesting inferred from RecModel above
    embedding_size: 512
Head:
name: "ArcMargin"
embedding_size: 512
class_num: 431
margin: 0.15
scale: 32
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
- TripletLossV2:
weight: 1.0
margin: 0.5
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: MultiStepDecay
learning_rate: 0.01
decay_epochs: [30, 60, 70, 80, 90, 100, 120, 140]
gamma: 0.5
verbose: False
last_epoch: -1
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "CompCars"
image_root: "/work/dataset/CompCars/image/"
label_root: "/work/dataset/CompCars/label/"
bbox_crop: True
cls_label_path: "/work/dataset/CompCars/train_test_split/classification/train_label.txt"
transform_ops:
- ResizeImage:
size: 224
- RandFlipImage:
flip_code: 1
- AugMix:
prob: 0.5
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0., 0., 0.]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 128
num_instances: 2
drop_last: False
shuffle: True
loader:
num_workers: 6
use_shared_memory: False
Eval:
    # TODO: modify to the latest trainer
dataset:
name: "CompCars"
image_root: "/work/dataset/CompCars/image/"
label_root: "/work/dataset/CompCars/label/"
cls_label_path: "/work/dataset/CompCars/train_test_split/classification/test_label.txt"
bbox_crop: True
transform_ops:
- ResizeImage:
size: 224
- NormalizeImage:
scale: 0.00392157
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 6
use_shared_memory: False
Infer:
infer_imgs: "docs/images/whl/demo.jpg"
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt"
Metric:
Train:
- Topk:
k: [1, 5]
Eval:
- Topk:
k: [1, 5]
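For orientation, a hedged sketch of how the sections above are consumed; the file name and the build_model entry point are assumptions, while build_dataloader follows the definition later in this commit:

import paddle
import yaml

from ppcls.arch import build_model
from ppcls.data import build_dataloader

config = yaml.safe_load(open("ResNet50_CompCars.yaml"))  # hypothetical file name
device = paddle.set_device(config["Global"]["device"])
model = build_model(config)  # assumed signature; builds backbone + neck + ArcMargin head
train_loader = build_dataloader(config["DataLoader"], "Train", device)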
......@@ -25,14 +25,17 @@ from . import samplers
from .dataset.imagenet_dataset import ImageNetDataset
from .dataset.multilabel_dataset import MultiLabelDataset
from .dataset.common_dataset import create_operators
from .dataset.vehicle_dataset import CompCars, VeriWild
# sampler
from .samplers import DistributedRandomIdentitySampler
from .preprocess import transform
def build_dataloader(config, mode, device, seed=None):
    assert mode in ['Train', 'Eval', 'Test'
                    ], "Mode should be Train, Eval or Test."
# build dataset
config_dataset = config[mode]['dataset']
config_dataset = copy.deepcopy(config_dataset)
......@@ -76,7 +79,7 @@ def build_dataloader(config, mode, device, seed=None):
batch_ops = create_operators(batch_transform)
batch_collate_fn = mix_collate_fn
else:
        batch_collate_fn = None
# build dataloader
config_loader = config[mode]['loader']
......@@ -105,9 +108,10 @@ def build_dataloader(config, mode, device, seed=None):
collate_fn=batch_collate_fn)
logger.info("build data_loader({}) success...".format(data_loader))
return data_loader
'''
# TODO: fix the format
def build_dataloader(config, mode, device, seed=None):
......
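A hedged iteration sketch for the loader built above (batch layout follows the CompCars dataset defined later in this commit; batch_size 128 comes from the sampler config):

# train_loader as returned by build_dataloader(config["DataLoader"], "Train", device)
for imgs, labels in train_loader:
    print(imgs.shape, labels.shape)  # e.g. [128, 3, 224, 224], [128]
    break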
......@@ -14,17 +14,10 @@
from __future__ import print_function
import io
import tarfile
import numpy as np
from PIL import Image #all use default backend
import paddle
from paddle.io import Dataset
import pickle
import os
import cv2
import random
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
......@@ -65,7 +58,7 @@ class CommonDataset(Dataset):
self.labels = []
self._load_anno()
    def _load_anno(self):
pass
def __getitem__(self, idx):
......@@ -89,4 +82,3 @@ class CommonDataset(Dataset):
@property
def class_num(self):
return len(set(self.labels))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
from paddle.io import Dataset
import os
import cv2
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators
class CompCars(Dataset):
def __init__(self,
image_root,
cls_label_path,
label_root=None,
transform_ops=None,
bbox_crop=False):
self._img_root = image_root
self._cls_path = cls_label_path
self._label_root = label_root
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self._bbox_crop = bbox_crop
self._dtype = paddle.get_default_dtype()
self._load_anno()
def _load_anno(self):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
if self._bbox_crop:
assert os.path.exists(self._label_root)
self.images = []
self.labels = []
self.bboxes = []
with open(self._cls_path) as fd:
lines = fd.readlines()
for l in lines:
l = l.strip().split()
if not self._bbox_crop:
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1]))
else:
label_path = os.path.join(self._label_root,
l[0].split('.')[0] + '.txt')
assert os.path.exists(label_path)
bbox = open(label_path).readlines()[-1].strip().split()
bbox = [int(x) for x in bbox]
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1]))
self.bboxes.append(bbox)
assert os.path.exists(self.images[-1])
def __getitem__(self, idx):
img = cv2.imread(self.images[idx])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if self._bbox_crop:
bbox = self.bboxes[idx]
img = img[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
return (img, self.labels[idx])
def __len__(self):
return len(self.images)
@property
def class_num(self):
return len(set(self.labels))
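A hedged usage sketch of CompCars; the paths match the training config above but are assumptions about the local dataset layout:

# With bbox_crop=True each image is cropped to the last bbox line of its label file.
dataset = CompCars(
    image_root="/work/dataset/CompCars/image/",
    cls_label_path="/work/dataset/CompCars/train_test_split/classification/train_label.txt",
    label_root="/work/dataset/CompCars/label/",
    bbox_crop=True)
img, label = dataset[0]  # img is CHW after the final transpose in __getitem__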
class VeriWild(Dataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None, ):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self._dtype = paddle.get_default_dtype()
self._load_anno()
def _load_anno(self):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
self.images = []
self.labels = []
self.cameras = []
with open(self._cls_path) as fd:
lines = fd.readlines()
for l in lines:
l = l.strip().split()
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(int(l[1]))
self.cameras.append(int(l[2]))
assert os.path.exists(self.images[-1])
def __getitem__(self, idx):
try:
img = cv2.imread(self.images[idx])
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
return (img, self.labels[idx], self.cameras[idx])
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
format(self.images[idx], ex))
rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx)
def __len__(self):
return len(self.images)
@property
def class_num(self):
return len(set(self.labels))
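A hedged usage sketch of VeriWild (paths are hypothetical); per _load_anno, each annotation line holds "<image_path> <label> <camera_id>":

dataset = VeriWild(
    image_root="/work/dataset/VeriWild/images/",
    cls_label_path="/work/dataset/VeriWild/train_list.txt")
img, label, cam = dataset[0]  # on a read failure, a random other index is retried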
......@@ -29,11 +29,13 @@ from PIL import Image
from .autoaugment import ImageNetPolicy
from .functional import augmentations
class OperatorParamError(ValueError):
""" OperatorParamError
"""
pass
class DecodeImage(object):
""" decode image """
......@@ -235,7 +237,12 @@ class AugMix(object):
""" Perform AugMix augmentation and compute mixture.
"""
    def __init__(self,
                 prob=0.5,
                 aug_prob_coeff=0.1,
                 mixture_width=3,
                 mixture_depth=1,
                 aug_severity=1):
"""
Args:
prob: Probability of taking augmix
......@@ -264,14 +271,16 @@ class AugMix(object):
ws = np.float32(
np.random.dirichlet([self.aug_prob_coeff] * self.mixture_width))
        m = np.float32(
            np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff))
# image = Image.fromarray(image)
mix = np.zeros([image.shape[1], image.shape[0], 3])
for i in range(self.mixture_width):
image_aug = image.copy()
image_aug = Image.fromarray(image_aug)
            depth = self.mixture_depth if self.mixture_depth > 0 else np.random.randint(
                1, 4)
for _ in range(depth):
op = np.random.choice(self.augmentations)
image_aug = op(image_aug, self.aug_severity)
......
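A hedged usage sketch, assuming AugMix is invoked like the other transform ops in this file (a callable on an HWC uint8 numpy image); prob=0.5 matches the training config above:

import numpy as np

aug = AugMix(prob=0.5)
img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # dummy image
out = aug(img)  # per the docstring, with probability 1 - prob the image passes through unchanged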
......@@ -30,7 +30,7 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.losses import build_loss
from ppcls.arch.loss_metrics import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
......
......@@ -5,17 +5,20 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
class TripletLossV2(nn.Layer):
"""Triplet loss with hard positive/negative mining.
Args:
margin (float): margin for triplet.
"""
    def __init__(self, margin=0.5, normalize_feature=True):
super(TripletLossV2, self).__init__()
self.margin = margin
self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
self.normalize_feature = normalize_feature
    def forward(self, input, target):
"""
Args:
inputs: feature matrix with shape (batch_size, feat_dim)
......@@ -23,28 +26,25 @@ class TripletLossV2(nn.Layer):
"""
inputs = input["features"]
        if self.normalize_feature:
inputs = 1. * inputs / (paddle.expand_as(
                paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12)
bs = inputs.shape[0]
# compute distance
dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
dist = dist + dist.t()
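        # addmm evaluates beta * dist + alpha * (inputs @ inputs.t()), i.e. the
        # ||a||^2 + ||b||^2 - 2*a.b expansion of the squared Euclidean distance.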
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
dist = paddle.clip(dist, min=1e-12).sqrt()
# hard negative mining
        is_pos = paddle.expand(target, (
            bs, bs)).equal(paddle.expand(target, (bs, bs)).t())
        is_neg = paddle.expand(target, (
            bs, bs)).not_equal(paddle.expand(target, (bs, bs)).t())
# `dist_ap` means distance(anchor, positive)
## both `dist_ap` and `relative_p_inds` with shape [N, 1]
......@@ -56,14 +56,14 @@ class TripletLossV2(nn.Layer):
dist_an, relative_n_inds = paddle.min(
paddle.reshape(dist[is_neg], (bs, -1)), axis=1, keepdim=True)
'''
dist_ap = paddle.max(paddle.reshape(
paddle.masked_select(dist, is_pos), (bs, -1)),
axis=1,
keepdim=True)
# `dist_an` means distance(anchor, negative)
# both `dist_an` and `relative_n_inds` with shape [N, 1]
dist_an = paddle.min(paddle.reshape(
paddle.masked_select(dist, is_neg), (bs, -1)),
axis=1,
keepdim=True)
# shape [N]
......@@ -84,6 +84,7 @@ class TripletLoss(nn.Layer):
Args:
margin (float): margin for triplet.
"""
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
......@@ -101,15 +102,12 @@ class TripletLoss(nn.Layer):
# Compute pairwise distance, replace by the official when merged
dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
dist = dist + dist.t()
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
dist = paddle.clip(dist, min=1e-12).sqrt()
        mask = paddle.equal(
            target.expand([bs, bs]), target.expand([bs, bs]).t())
mask_numpy_idx = mask.numpy()
dist_ap, dist_an = [], []
for i in range(bs):
......@@ -118,18 +116,16 @@ class TripletLoss(nn.Layer):
# dist_ap.append(dist_ap_i)
dist_ap.append(
max([
                    dist[i][j] if mask_numpy_idx[i][j] == True else float(
                        "-inf") for j in range(bs)
]).unsqueeze(0))
# dist_an_i = paddle.to_tensor(dist[i].numpy()[mask_numpy_idx[i] == False].min(), dtype='float64').unsqueeze(0)
# dist_an_i.stop_gradient = False
# dist_an.append(dist_an_i)
dist_an.append(
min([
                    dist[i][k] if mask_numpy_idx[i][k] == False else float(
                        "inf") for k in range(bs)
]).unsqueeze(0))
dist_ap = paddle.concat(dist_ap, axis=0)
......@@ -139,4 +135,3 @@ class TripletLoss(nn.Layer):
y = paddle.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return {"TripletLoss": loss}