Commit 3dccfbe1, authored by Chris Yann, committed by whs

Deep metric learning based on Fluid (#1146)

* initialize

* add loss

* update

* Update README.md

* Update README.md

* add download

* add split

* Update README.md

* Update README.md

* Update README.md

* update

* update

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* update

* update
Parent 96a1493f
# Deep Metric Learning
Metric learning is a family of methods that learn a discriminative feature for each sample, so that intra-class samples are close while inter-class samples are far apart in the learned space. With the development of deep learning, metric learning has been combined with deep neural networks to boost the performance of tasks such as face recognition/verification, person re-identification and image retrieval. This page introduces how to implement deep metric learning with PaddlePaddle Fluid, covering [data preparation](#data-preparation), [training](#training-metric-learning-models), [finetuning](#finetuning), [evaluation](#evaluation) and [inference](#inference).
---
## Table of Contents
- [Installation](#installation)
- [Data preparation](#data-preparation)
- [Training metric learning models](#training-metric-learning-models)
- [Finetuning](#finetuning)
- [Evaluation](#evaluation)
- [Inference](#inference)
- [Performances](#performances)
## Installation
Running the sample code in this directory requires PaddlePaddle Fluid v0.14.0 or later. If the PaddlePaddle version on your device is lower than this, please follow the instructions in the [installation document](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html) and update it.
## Data preparation
Caltech-UCSD Birds 200 (CUB-200) is an image dataset of 200 bird species, which we use for the metric learning experiments. More details of this dataset can be found on its [official website](http://www.vision.caltech.edu/visipedia/CUB-200.html). First, download and prepare the CUB-200 data as follows:
```
cd data/
sh download_cub200.sh
```
The script ```data/split.py``` splits the dataset into training and validation sets. In our setting, images from the first 100 classes (001-100) are used as training data, while the remaining 100 classes are used as validation data. After splitting, two label files contain the training and validation image labels respectively:
* *CUB200_train.txt*: label file of the CUB-200 training set, with the fields of each line separated by a ```SPACE```, like:
```
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0021_2432168643.jpg 97
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0022_549995638.jpg 97
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0034_2244771004.jpg 97
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0010_2501839798.jpg 97
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0008_491860362.jpg 97
current_path/images/097.Orchard_Oriole/Orchard_Oriole_0015_2545116359.jpg 97
...
```
* *CUB200_val.txt*: label file of the CUB-200 validation set, with the fields of each line separated by a ```SPACE```, like:
```
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0029_59210443.jpg 154
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0021_2693953672.jpg 154
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0016_2917350638.jpg 154
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0027_2503540454.jpg 154
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0026_2502710393.jpg 154
current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0022_2693134681.jpg 154
...
```
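To sanity-check the split, one can count the distinct labels in each file (a minimal sketch; ```count_classes``` is a hypothetical helper, not part of this repository):
```
# count distinct class labels in a label file (hypothetical helper)
def count_classes(label_file):
    labels = set()
    for line in open(label_file):
        labels.add(int(line.strip().split()[-1]))
    return len(labels)

print(count_classes("data/CUB200_train.txt"))  # expected: 100 (classes 001-100)
print(count_classes("data/CUB200_val.txt"))    # expected: 100 (classes 101-200)
```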
## Training metric learning models
To train a metric learning model, one needs to specify the backbone network and the metric loss function to optimize. An example of training with the triplet loss on ResNet-50 is shown below:
```
python train.py \
--model=ResNet50 \
--lr=0.001 \
--num_epochs=120 \
--use_gpu=True \
--train_batch_size=20 \
--test_batch_size=20 \
--loss_name=tripletloss \
--model_save_dir="output_tripletloss"
```
**parameter introduction:**
* **model**: name of the model to use. Default: "SE_ResNeXt50_32x4d".
* **num_epochs**: the number of training epochs. Default: 120.
* **train_batch_size**: the size of each training mini-batch. Default: 80.
* **test_batch_size**: the size of each test mini-batch. Default: 10.
* **use_gpu**: whether to use GPU or not. Default: True.
* **model_save_dir**: the directory in which to save the trained model. Default: "output".
* **lr**: initial learning rate. Default: 0.1.
* **loss_name**: the metric loss to optimize, one of "tripletloss", "quadrupletloss" or "emlloss". Default: "tripletloss".
* **pretrained_model**: path of a pretrained model to load. Default: None.
**training log:** the log of training ResNet-50 with the triplet loss looks like:
```
Pass 0, trainbatch 0, lr 9.99999974738e-05, loss_metric 0.0700866878033, loss_cls 5.23635625839, acc1 0.0, acc5 0.100000008941, time 0.16 sec
Pass 0, trainbatch 10, lr 9.99999974738e-05, loss_metric 0.0752244070172, loss_cls 5.30303478241, acc1 0.0, acc5 0.100000008941, time 0.14 sec
Pass 0, trainbatch 20, lr 9.99999974738e-05, loss_metric 0.0840565115213, loss_cls 5.41880941391, acc1 0.0, acc5 0.0333333350718, time 0.14 sec
Pass 0, trainbatch 30, lr 9.99999974738e-05, loss_metric 0.0698839947581, loss_cls 5.35385560989, acc1 0.0, acc5 0.0333333350718, time 0.14 sec
Pass 0, trainbatch 40, lr 9.99999974738e-05, loss_metric 0.0596057735384, loss_cls 5.34744024277, acc1 0.0, acc5 0.0, time 0.14 sec
Pass 0, trainbatch 50, lr 9.99999974738e-05, loss_metric 0.067836754024, loss_cls 5.37124729156, acc1 0.0, acc5 0.0333333350718, time 0.14 sec
Pass 0, trainbatch 60, lr 9.99999974738e-05, loss_metric 0.0637686774135, loss_cls 5.47412204742, acc1 0.0, acc5 0.0333333350718, time 0.14 sec
Pass 0, trainbatch 70, lr 9.99999974738e-05, loss_metric 0.0772982165217, loss_cls 5.38295936584, acc1 0.0, acc5 0.0, time 0.14 sec
Pass 0, trainbatch 80, lr 9.99999974738e-05, loss_metric 0.0861896127462, loss_cls 5.41250753403, acc1 0.0, acc5 0.0, time 0.14 sec
Pass 0, trainbatch 90, lr 9.99999974738e-05, loss_metric 0.0653102770448, loss_cls 5.53133153915, acc1 0.0, acc5 0.0, time 0.14 sec
...
```
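For reference, the triplet loss reported as ```loss_metric``` above pulls an anchor towards a positive sample of the same class and pushes it away from a negative sample by at least a margin. A minimal NumPy sketch of the computation (the actual Fluid implementation is in ```losses/tripletloss.py```; the toy inputs are illustrative only):
```
import numpy as np

def triplet_loss(anchor, positive, negative, margin=0.1):
    # L2-normalize embeddings, then hinge on the anchor-positive / anchor-negative distance gap
    a, p, n = [x / np.linalg.norm(x, axis=1, keepdims=True)
               for x in (anchor, positive, negative)]
    d_ap = np.sqrt(np.sum((a - p) ** 2, axis=1))
    d_an = np.sqrt(np.sum((a - n) ** 2, axis=1))
    return np.maximum(d_ap - d_an + margin, 0.0).mean()

rng = np.random.RandomState(0)
print(triplet_loss(rng.rand(4, 8), rng.rand(4, 8), rng.rand(4, 8)))  # toy 8-d embeddings
```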
## Finetuning
Finetuning adapts pretrained weights to a specific task. After setting ```path_to_pretrain_model```, one can finetune a model as:
```
python train.py \
--model=ResNet50 \
--pretrained_model=${path_to_pretrain_model} \
--lr=0.001 \
--num_epochs=120 \
--use_gpu=True \
--train_batch_size=20 \
--test_batch_size=20 \
--loss_name=tripletloss \
--model_save_dir="output_tripletloss"
```
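When ```pretrained_model``` is set, ```train.py``` only loads the variables whose parameter files actually exist in that directory, so weights that do not match the current network (for example a classifier head trained with a different class number) are simply skipped. The relevant excerpt from ```train.py```:
```
if pretrained_model:
    def if_exist(var):
        return os.path.exists(os.path.join(pretrained_model, var.name))
    fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
```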
## Evaluation
Evaluation measures the performance of a trained model. One can download a [pretrained model](#performances) and set its path in ```path_to_pretrain_model```. Recall@Rank-1 is then obtained by running the following command:
```
python eval.py \
--model=ResNet50 \
--pretrained_model=${path_to_pretrain_model} \
--batch_size=30 \
--loss_name=tripletloss
```
With the above evaluation configuration, the output log looks like:
```
testbatch 0, loss 17.0384693146, recall 0.133333333333, time 0.08 sec
testbatch 10, loss 15.4248628616, recall 0.2, time 0.07 sec
testbatch 20, loss 19.3986873627, recall 0.0666666666667, time 0.07 sec
testbatch 30, loss 19.8149013519, recall 0.166666666667, time 0.07 sec
testbatch 40, loss 18.7500724792, recall 0.0333333333333, time 0.07 sec
testbatch 50, loss 15.1477527618, recall 0.166666666667, time 0.07 sec
testbatch 60, loss 21.6039619446, recall 0.0666666666667, time 0.07 sec
testbatch 70, loss 16.3203811646, recall 0.1, time 0.08 sec
testbatch 80, loss 17.3300457001, recall 0.133333333333, time 0.14 sec
testbatch 90, loss 17.9943237305, recall 0.0333333333333, time 0.07 sec
testbatch 100, loss 20.4538421631, recall 0.1, time 0.07 sec
End test, test_loss 18.2126255035, test recall 0.573597359736
...
```
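Recall@Rank-1 counts a query image as correct when its nearest neighbor in the feature space (excluding itself) has the same label, which is what ```recall_topk``` in ```losses/metrics.py``` computes. A compact NumPy sketch of the same metric:
```
import numpy as np

def recall_at_1(features, labels):
    # squared L2 distances between L2-normalized features
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sq = np.sum(f ** 2, axis=1)
    d = sq[:, None] + sq[None, :] - 2 * f.dot(f.T)
    np.fill_diagonal(d, 1e8)        # exclude the query itself
    nearest = np.argmin(d, axis=1)  # index of the closest other sample
    return np.mean(np.asarray(labels)[nearest] == np.asarray(labels))
```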
## Inference
Inference extracts prediction scores or image features from a trained model.
```
python infer.py --model=ResNet50 \
--pretrained_model=${path_to_pretrain_model}
```
The output contains the learned feature of each test sample:
```
Test-0-feature: [0.1551965 0.48882252 0.3528545 ... 0.35809007 0.6210782 0.34474897]
Test-1-feature: [0.26215672 0.71406883 0.36118034 ... 0.4711366 0.6783772 0.26591945]
Test-2-feature: [0.26164916 0.46013424 0.38381338 ... 0.47984493 0.5830286 0.22124235]
Test-3-feature: [0.22502825 0.44153655 0.29287377 ... 0.45510024 0.81386226 0.21451607]
Test-4-feature: [0.27748746 0.49068335 0.28269237 ... 0.47356504 0.73254013 0.22317657]
Test-5-feature: [0.17743547 0.5232162 0.35012805 ... 0.38921246 0.80238944 0.26693743]
Test-6-feature: [0.18314484 0.4294481 0.37652573 ... 0.4795592 0.7446839 0.24178651]
Test-7-feature: [0.25836483 0.49866533 0.3469289 ... 0.38316026 0.56015515 0.22388287]
Test-8-feature: [0.30613047 0.5200348 0.2847372 ... 0.5700768 0.76645917 0.26504722]
Test-9-feature: [0.3305695 0.46257797 0.27108437 ... 0.42891273 0.5112956 0.26442713]
Test-10-feature: [0.16024818 0.46871603 0.32608703 ... 0.3341719 0.6876993 0.26097256]
Test-11-feature: [0.37611157 0.6006333 0.3023942 ... 0.4729057 0.53841203 0.19621202]
Test-12-feature: [0.17515017 0.41597834 0.45567667 ... 0.45650777 0.5987687 0.25734115]
...
```
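The printed vectors are the globally pooled backbone features (2048-d for ResNet-50). A hypothetical retrieval step built on top of them, assuming the features have been stacked into one NumPy array, could look like:
```
import numpy as np

def top_k_neighbors(features, query_index, k=5):
    # rank all samples by cosine similarity to the query feature
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sims = f.dot(f[query_index])
    order = np.argsort(-sims)
    return [i for i in order if i != query_index][:k]
```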
## Performances
For comparison, metric learning models with different backbone networks and loss functions are trained with their respective empirical hyperparameters. Recall@Rank-1 is used as the evaluation metric and the results are listed in the table below. Pretrained models can be downloaded by clicking the corresponding model names.
| model               | ResNet50 | SE-ResNeXt-50 |
| ------------------- | -------- | ------------- |
| [triplet loss]()    | 57.36%   | 51.62%        |
| [eml loss]()        | 58.84%   | 52.94%        |
| [quadruplet loss]() | 62.67%   | 56.40%        |
wget http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz
tar zxf images.tgz
find $PWD/images|grep jpg|grep -v "\._" > list.txt
python split.py
rm -rf images.tgz list.txt
input = open("list.txt", "r").readlines()
fout_train = open("CUB200_train.txt", "w")
fout_valid = open("CUB200_val.txt", "w")
for i, item in enumerate(input):
label = item.strip().split("/")[-2].split(".")[0]
label = int(label)
if label <= 100:
fout = fout_train
else:
fout = fout_valid
fout.write(item.strip() + " " + str(label) + "\n")
fout_train.close()
fout_valid.close()
import os
import numpy as np
import time
import sys
import paddle
import paddle.fluid as fluid
import models
import argparse
import functools
from losses import tripletloss
from losses import quadrupletloss
from losses import emlloss
from losses.metrics import recall_topk
from utility import add_arguments, print_arguments
import math
# yapf: disable
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('batch_size', int, 120, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('image_shape', str, "3,224,224", "Input image size.")
add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('loss_name', str, "emlloss", "Loss name.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def eval(args):
# parameters from arguments
model_name = args.model
pretrained_model = args.pretrained_model
with_memory_optimization = args.with_mem_opt
loss_name = args.loss_name
image_shape = [int(m) for m in args.image_shape.split(",")]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[model_name]()
out = model.net(input=image, class_dim=200)
if loss_name == "tripletloss":
metricloss = tripletloss(test_batch_size=args.batch_size, margin=0.1)
cost = metricloss.loss(out[0])
elif loss_name == "quadrupletloss":
metricloss = quadrupletloss(test_batch_size=args.batch_size)
cost = metricloss.loss(out[0])
elif loss_name == "emlloss":
metricloss = emlloss(
test_batch_size=args.batch_size, samples_each_class=2, fea_dim=2048)
cost = metricloss.loss(out[0])
avg_cost = fluid.layers.mean(x=cost)
test_program = fluid.default_main_program().clone(for_test=True)
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
test_reader = metricloss.test_reader
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
fetch_list = [avg_cost.name, out[0].name]
test_info = [[]]
f = []
l = []
for batch_id, (data, label) in enumerate(test_reader()):
t1 = time.time()
loss, feas = exe.run(test_program,
fetch_list=fetch_list,
feed=feeder.feed(data))
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
test_info[0].append(loss)
if batch_id % 10 == 0:
recall = recall_topk(feas, label, k=1)
print("testbatch {0}, loss {1}, recall {2}, time {3}".format( \
batch_id, loss, recall, "%2.2f sec" % period))
sys.stdout.flush()
test_loss = np.array(test_info[0]).mean()
f = np.vstack(f)
l = np.vstack(l)
recall = recall_topk(f, l, k=1)
print("End test, test_loss {0}, test recall {1}".format( \
test_loss, recall))
sys.stdout.flush()
def main():
args = parser.parse_args()
print_arguments(args)
eval(args)
if __name__ == '__main__':
main()
import os
import numpy as np
import time
import sys
import paddle
import paddle.fluid as fluid
import models
import argparse
import functools
from losses import tripletloss
from utility import add_arguments, print_arguments
import math
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('image_shape', str, "3,224,224", "Input image size.")
add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def infer(args):
# parameters from arguments
model_name = args.model
pretrained_model = args.pretrained_model
with_memory_optimization = args.with_mem_opt
image_shape = [int(m) for m in args.image_shape.split(",")]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
# model definition
model = models.__dict__[model_name]()
out = model.net(input=image, class_dim=200)
test_program = fluid.default_main_program().clone(for_test=True)
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
infer_reader = tripletloss(infer_batch_size=args.batch_size).infer_reader
feeder = fluid.DataFeeder(place=place, feed_list=[image])
fetch_list = [out[0].name]
for batch_id, data in enumerate(infer_reader()):
result = exe.run(test_program,
fetch_list=fetch_list,
feed=feeder.feed(data))
result = result[0][0].reshape(-1)
print("Test-{0}-feature: {1}".format(batch_id, result))
sys.stdout.flush()
def main():
args = parser.parse_args()
print_arguments(args)
infer(args)
if __name__ == '__main__':
main()
from .tripletloss import tripletloss
from .quadrupletloss import quadrupletloss
from .emlloss import emlloss
import os
import math
import random
import cPickle
import functools
import numpy as np
#import paddle.v2 as paddle
import paddle
from PIL import Image, ImageEnhance
random.seed(0)
DATA_DIM = 224
THREAD = 8
BUF_SIZE = 1024000
DATA_DIR = "./data/"
TRAIN_LIST = './data/CUB200_train.txt'
TEST_LIST = './data/CUB200_val.txt'
#DATA_DIR = "./thirdparty/paddlemodels/metric_learning/data/"
#TRAIN_LIST = './thirdparty/paddlemodels/metric_learning/data/CUB200_train.txt'
#TEST_LIST = './thirdparty/paddlemodels/metric_learning/data/CUB200_val.txt'
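# train_data / test_data map a zero-based class label to the list of image paths of that class;
# test_image_list holds (path, label) pairs and infer_image_list bare paths, used by the test and infer readers.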
train_data = {}
test_data = {}
train_list = open(TRAIN_LIST, "r").readlines()
for i, item in enumerate(train_list):
path, label = item.strip().split()
label = int(label) - 1
if label not in train_data:
train_data[label] = []
train_data[label].append(path)
test_list = open(TEST_LIST, "r").readlines()
test_image_list = []
infer_image_list = []
for i, item in enumerate(test_list):
path, label = item.strip().split()
label = int(label) - 1
test_image_list.append((path, label))
infer_image_list.append(path)
if label not in test_data:
test_data[label] = []
test_data[label].append(path)
print "train_data size:", len(train_data)
print "test_data size:", len(test_data)
print "test_data image number:", len(test_image_list)
random.shuffle(test_image_list)
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), Image.BILINEAR)
return img
def Scale(img, size):
w, h = img.size
if (w <= h and w == size) or (h <= w and h == size):
return img
if w < h:
ow = size
oh = int(size * h / w)
return img.resize((ow, oh), Image.BILINEAR)
else:
oh = size
ow = int(size * w / h)
return img.resize((ow, oh), Image.BILINEAR)
def CenterCrop(img, size):
w, h = img.size
th, tw = int(size), int(size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return img.crop((x1, y1, x1 + tw, y1 + th))
def crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center == True:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = random.randint(0, width - size)
h_start = random.randint(0, height - size)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img
def RandomResizedCrop(img, size):
for attempt in range(10):
area = img.size[0] * img.size[1]
target_area = random.uniform(0.08, 1.0) * area
aspect_ratio = random.uniform(3. / 4, 4. / 3)
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if random.random() < 0.5:
w, h = h, w
if w <= img.size[0] and h <= img.size[1]:
x1 = random.randint(0, img.size[0] - w)
y1 = random.randint(0, img.size[1] - h)
img = img.crop((x1, y1, x1 + w, y1 + h))
assert(img.size == (w, h))
return img.resize((size, size), Image.BILINEAR)
w = min(img.size[0], img.size[1])
i = (img.size[1] - w) // 2
j = (img.size[0] - w) // 2
img = img.crop((i, j, i+w, j+w))
img = img.resize((size, size), Image.BILINEAR)
return img
def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
bound = min((float(img.size[0]) / img.size[1]) / (w**2),
(float(img.size[1]) / img.size[0]) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img.size[0] * img.size[1] * random.uniform(scale_min,
scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img.size[0] - w)
j = random.randint(0, img.size[1] - h)
img = img.crop((i, j, i + w, j + h))
img = img.resize((size, size), Image.BILINEAR)
return img
def rotate_image(img):
angle = random.randint(-10, 10)
img = img.rotate(angle)
return img
def distort_color(img):
def random_brightness(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Brightness(img).enhance(e)
def random_contrast(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Contrast(img).enhance(e)
def random_color(img, lower=0.8, upper=1.2):
e = random.uniform(lower, upper)
return ImageEnhance.Color(img).enhance(e)
ops = [random_brightness, random_contrast, random_color]
random.shuffle(ops)
img = ops[0](img)
img = ops[1](img)
img = ops[2](img)
return img
def process_image_imagepath(sample, mode, color_jitter, rotate):
imgpath = sample[0]
img = Image.open(imgpath)
if mode == 'train':
if rotate: img = rotate_image(img)
img = RandomResizedCrop(img, DATA_DIM)
else:
img = Scale(img, 256)
img = CenterCrop(img, DATA_DIM)
if mode == 'train':
if color_jitter:
img = distort_color(img)
if random.randint(0, 1) == 1:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'test':
return img, sample[1]
elif mode == 'infer':
return img
def eml_iterator(data,
mode,
batch_size,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
counter = np.zeros(lab_num)
ind = range(0, lab_num)
batchdata = []
assert batch_size % samples_each_class == 0, "batch_size % samples_each_class != 0"
num_class = batch_size/samples_each_class
for i in range(iter_size):
random.shuffle(ind)
for n in range(num_class):
lab_ind = ind[n]
label = labs[lab_ind]
data_list = data[label]
random.shuffle(data_list)
for s in range(samples_each_class):
path = DATA_DIR + data_list[s]
# print("path:", path)
img, _ = process_image_imagepath(sample = [path, label], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img, label])
#print("batch_size:", len(batchdata))
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
return reader
def quadruplet_iterator(data,
mode,
class_num,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
ind = range(0, lab_num)
batchdata = []
for i in range(iter_size):
random.shuffle(ind)
ind_sample = ind[:class_num]
for ind_i in ind_sample:
lab = labs[ind_i]
data_list = data[lab]
data_ind = range(0, len(data_list))
random.shuffle(data_ind)
anchor_ind = data_ind[:samples_each_class]
for anchor_ind_i in anchor_ind:
anchor_path = DATA_DIR + data_list[anchor_ind_i]
anchor_img, _ = process_image_imagepath(sample = [anchor_path, lab], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([anchor_img, lab])
yield batchdata
batchdata = []
return reader
def triplet_iterator(data,
mode,
batch_size,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
counter = np.zeros(lab_num)
ind = range(0, lab_num)
batchdata = []
for i in range(iter_size):
random.shuffle(ind)
ind_pos, ind_neg = ind[:2]
lab_pos = labs[ind_pos]
pos_data_list = data[lab_pos]
data_ind = range(0, len(pos_data_list))
random.shuffle(data_ind)
anchor_ind, pos_ind = data_ind[:2]
lab_neg = labs[ind_neg]
neg_data_list = data[lab_neg]
neg_ind = random.randint(0, len(neg_data_list) - 1)
anchor_path = DATA_DIR + pos_data_list[anchor_ind]
anchor_img, _ = process_image_imagepath(sample = [anchor_path, lab_pos], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
pos_path = DATA_DIR + pos_data_list[pos_ind]
pos_img, _ = process_image_imagepath(sample = [pos_path, lab_pos], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
neg_path = DATA_DIR + neg_data_list[neg_ind]
neg_img, _ = process_image_imagepath(sample = [neg_path, lab_neg], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([anchor_img, lab_pos])
batchdata.append([pos_img, lab_pos])
batchdata.append([neg_img, lab_neg])
#print("batchdata:", len(batchdata))
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
if batchdata:
yield batchdata
return reader
def image_iterator(data,
mode,
batch_size,
shuffle=False,
color_jitter=False,
rotate=False):
def infer_reader():
batchdata = []
for i in range(len(data)):
path = data[i]
path = DATA_DIR + path
img = process_image_imagepath(sample = [path], \
mode = "infer", \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img])
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
def reader():
batchdata = []
batchlabel = []
for i in range(len(data)):
path, label = data[i]
path = DATA_DIR + path
img, label = process_image_imagepath(sample = [path, label], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img, label])
batchlabel.append([label])
if len(batchdata) == batch_size:
yield batchdata, batchlabel
batchdata = []
batchlabel = []
if mode in ["train", "test"]:
return reader
else:
return infer_reader
def eml_train(batch_size, samples_each_class):
return eml_iterator(train_data, 'train', batch_size, samples_each_class, iter_size = 100, \
shuffle=True, color_jitter=False, rotate=False)
def quadruplet_train(batch_size, class_num, samples_each_class):
print('train batch size ', batch_size)
return quadruplet_iterator(train_data, 'train', class_num, samples_each_class, iter_size=100, \
shuffle=True, color_jitter=False, rotate=False)
def triplet_train(batch_size):
assert(batch_size % 3 == 0)
print('train batch size ', batch_size)
return triplet_iterator(train_data, 'train', batch_size, iter_size = batch_size/3 * 100, \
shuffle=True, color_jitter=False, rotate=False)
def test(batch_size):
print('test batch size ', batch_size)
return image_iterator(test_image_list, "test", batch_size, shuffle=False)
def infer(batch_size):
print('inference batch size ', batch_size)
return image_iterator(infer_image_list, "infer", batch_size, shuffle=False)
import datareader as reader
import math
import numpy as np
import paddle.fluid as fluid
class emlloss():
def __init__(self, train_batch_size = 50, test_batch_size = 50, infer_batch_size = 50, samples_each_class=2, fea_dim=2048):
self.train_reader = reader.eml_train(train_batch_size, samples_each_class)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
self.samples_each_class = samples_each_class
self.fea_dim = fea_dim
self.batch_size = train_batch_size
def surrogate_function(self, input):
beta = 100000
output = fluid.layers.log(1+beta*input)/math.log(1+beta)
return output
def surrogate_function_approximate(self, input, bias):
beta = 100000
output = (fluid.layers.log(beta*input)+bias)/math.log(1+beta)
return output
def generate_index(self, batch_size, samples_each_class):
a = np.arange(0, batch_size*batch_size) # N*N x 1
a = a.reshape(-1,batch_size) # N x N
steps = batch_size//samples_each_class
res = []
for i in range(batch_size):
step = i // samples_each_class
start = step * samples_each_class
end = (step + 1) * samples_each_class
p = []
n = []
for j, k in enumerate(a[i]):
if j >= start and j < end:
p.append(k)
else:
n.append(k)
comb = p + n
res += comb
res = np.array(res).astype(np.int32)
return res
def loss(self, input):
samples_each_class = self.samples_each_class
fea_dim = self.fea_dim
batch_size = self.batch_size
feature1 = fluid.layers.reshape(input, shape = [-1, fea_dim])
feature2 = fluid.layers.transpose(feature1, perm = [1,0])
ab = fluid.layers.mul(x=feature1, y=feature2)
a2 = fluid.layers.square(feature1)
a2 = fluid.layers.reduce_sum(a2, dim = 1)
a2 = fluid.layers.reshape(a2, shape = [-1])
b2 = fluid.layers.square(feature2)
b2 = fluid.layers.reduce_sum(b2, dim = 0)
b2 = fluid.layers.reshape(b2, shape = [-1])
d = fluid.layers.elementwise_add(-2*ab, a2, axis = 0)
d = fluid.layers.elementwise_add(d, b2, axis = 1)
d = fluid.layers.reshape(d, shape=[-1, 1])
index = self.generate_index(batch_size, samples_each_class)
index_var = fluid.layers.create_global_var(
shape=[batch_size*batch_size], value=0, dtype='int32', persistable=True)
index_var = fluid.layers.assign(index, index_var)
index_var.stop_gradient = True
d = fluid.layers.gather(d, index=index_var)
d = fluid.layers.reshape(d, shape=[-1, batch_size])
pos, neg = fluid.layers.split(d,
num_or_sections= [samples_each_class,batch_size-samples_each_class],
dim=1)
pos_max = fluid.layers.reduce_max(pos, dim=1)
pos_max = fluid.layers.reshape(pos_max, shape=[-1, 1])
pos = fluid.layers.exp(pos-pos_max)
pos_mean = fluid.layers.reduce_mean(pos, dim=1)
neg_min = fluid.layers.reduce_min(neg, dim=1)
neg_min = fluid.layers.reshape(neg_min, shape=[-1, 1])
neg = fluid.layers.exp(-1*(neg-neg_min))
neg_mean = fluid.layers.reduce_mean(neg, dim=1)
theta = fluid.layers.reshape(neg_mean * pos_mean, shape=[-1,1])
max_gap = fluid.layers.fill_constant([1], dtype='float32', value=20.0)
max_gap.stop_gradient = True
target = pos_max - neg_min
target_max = fluid.layers.elementwise_max(target, max_gap)
target_min = fluid.layers.elementwise_min(target, max_gap)
expadj_min = fluid.layers.exp(target_min)
loss1 = self.surrogate_function(theta*expadj_min)
loss2 = self.surrogate_function_approximate(theta, target_max)
bias = fluid.layers.exp(max_gap)
bias = self.surrogate_function(theta*bias)
loss = loss1 + loss2 - bias
return loss
import numpy as np
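# Recall metric over L2-normalized features: a sample counts as correct when its
# nearest neighbor (itself excluded via the 1e8 diagonal) shares the same label.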
def recall_topk(fea, lab, k = 1):
fea = np.array(fea)
fea = fea.reshape(fea.shape[0], -1)
n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1)
fea = fea/n
a = np.sum(fea ** 2, 1).reshape(-1, 1)
b = a.T
ab = np.dot(fea, fea.T)
d = a + b - 2*ab
d = d + np.eye(len(fea)) * 1e8
sorted_index = np.argsort(d, 1)
res = 0
for i in range(len(fea)):
pred = lab[sorted_index[i][0]]
if lab[i] == pred:
res += 1.0
res = res/len(fea)
return res
import numpy as np
import datareader as reader
import paddle.fluid as fluid
class quadrupletloss():
def __init__(self,
train_batch_size = 80,
test_batch_size = 10,
infer_batch_size = 10,
samples_each_class = 2,
num_gpus=8,
margin=0.1):
self.margin = margin
self.num_gpus = num_gpus
self.samples_each_class = samples_each_class
self.train_batch_size = train_batch_size
assert(train_batch_size % (samples_each_class*num_gpus) == 0)
self.class_num = train_batch_size / self.samples_each_class
self.train_reader = reader.quadruplet_train(train_batch_size, self.class_num, self.samples_each_class)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
def loss(self, input):
margin = self.margin
batch_size = self.train_batch_size / self.num_gpus
fea_dim = input.shape[1] # number of channels
output = fluid.layers.reshape(input, shape = [-1, fea_dim])
output = fluid.layers.l2_normalize(output, axis=1)
#scores = fluid.layers.matmul(output, output, transpose_x=False, transpose_y=True)
output_t = fluid.layers.transpose(output, perm = [1, 0])
scores = fluid.layers.mul(x=output, y=output_t)
mask_np = np.zeros((batch_size, batch_size), dtype=np.float32)
for i in xrange(batch_size):
for j in xrange(batch_size):
if i / self.samples_each_class == j / self.samples_each_class:
mask_np[i, j] = 100.
#mask = fluid.layers.create_tensor(dtype='float32')
mask = fluid.layers.create_global_var(
shape=[batch_size, batch_size], value=0, dtype='float32', persistable=True)
fluid.layers.assign(mask_np, mask)
scores = fluid.layers.scale(x=scores, scale=-1.0) + mask
scores_max = fluid.layers.reduce_max(scores, dim=0, keep_dim=True)
ind_max = fluid.layers.argmax(scores, axis=0)
ind_max = fluid.layers.cast(x=ind_max, dtype='float32')
ind2 = fluid.layers.argmax(scores_max, axis=1)
ind2 = fluid.layers.cast(x=ind2, dtype='int32')
ind1 = fluid.layers.gather(ind_max, ind2)
ind1 = fluid.layers.cast(x=ind1, dtype='int32')
scores_min = fluid.layers.reduce_min(scores, dim=0, keep_dim=True)
ind_min = fluid.layers.argmin(scores, axis=0)
ind_min = fluid.layers.cast(x=ind_min, dtype='float32')
ind4 = fluid.layers.argmin(scores_min, axis=1)
ind4 = fluid.layers.cast(x=ind4, dtype='int32')
ind3 = fluid.layers.gather(ind_min, ind4)
ind3 = fluid.layers.cast(x=ind3, dtype='int32')
f1 = fluid.layers.gather(output, ind1)
f2 = fluid.layers.gather(output, ind2)
f3 = fluid.layers.gather(output, ind3)
f4 = fluid.layers.gather(output, ind4)
ind1.stop_gradient = True
ind2.stop_gradient = True
ind3.stop_gradient = True
ind4.stop_gradient = True
ind_max.stop_gradient = True
ind_min.stop_gradient = True
scores_max.stop_gradient = True
scores_min.stop_gradient = True
scores.stop_gradient = True
mask.stop_gradient = True
output_t.stop_gradient = True
f1_2 = fluid.layers.square(f1 - f2)
f3_4 = fluid.layers.square(f3 - f4)
s1 = fluid.layers.reduce_sum(f1_2, dim = 1)
s2 = fluid.layers.reduce_sum(f3_4, dim = 1)
s1 = fluid.layers.sqrt(s1)
s2 = fluid.layers.sqrt(s2)
loss = fluid.layers.relu(s1 - s2 + margin)
return loss
import datareader as reader
import paddle.fluid as fluid
class tripletloss():
def __init__(self,
train_batch_size = 120,
test_batch_size = 120,
infer_batch_size = 120,
margin=0.1):
self.train_reader = reader.triplet_train(train_batch_size)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
self.margin = margin
def loss(self, input):
margin = self.margin
fea_dim = input.shape[1] # number of channels
output = fluid.layers.reshape(input, shape = [-1, 3, fea_dim])
output = fluid.layers.l2_normalize(output, axis=2)
anchor, positive, negative = fluid.layers.split(output, num_or_sections = 3, dim = 1)
anchor = fluid.layers.reshape(anchor, shape = [-1, fea_dim])
positive = fluid.layers.reshape(positive, shape = [-1, fea_dim])
negative = fluid.layers.reshape(negative, shape = [-1, fea_dim])
a_p = fluid.layers.square(anchor - positive)
a_n = fluid.layers.square(anchor - negative)
a_p = fluid.layers.reduce_sum(a_p, dim = 1)
a_n = fluid.layers.reduce_sum(a_n, dim = 1)
a_p = fluid.layers.sqrt(a_p)
a_n = fluid.layers.sqrt(a_n)
loss = fluid.layers.relu(a_p + margin - a_n)
return loss
from .resnet import ResNet50
from .resnet import ResNet101
from .resnet import ResNet152
from .se_resnext import SE_ResNeXt50_32x4d
from .se_resnext import SE_ResNeXt101_32x4d
from .se_resnext import SE_ResNeXt152_32x4d
import paddle
import paddle.fluid as fluid
import math
__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim=1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(input=pool,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return pool, out
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) / 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
return self.conv_bn_layer(input, ch_out, 1, stride)
else:
return input
def bottleneck_block(self, input, num_filters, stride):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None)
short = self.shortcut(input, num_filters * 4, stride)
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def ResNet50():
model = ResNet(layers=50)
return model
def ResNet101():
model = ResNet(layers=101)
return model
def ResNet152():
model = ResNet(layers=152)
return model
import paddle
import paddle.fluid as fluid
import math
__all__ = ["SE_ResNeXt", "SE_ResNeXt50_32x4d", "SE_ResNeXt101_32x4d", "SE_ResNeXt152_32x4d"]
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class SE_ResNeXt():
def __init__(self, layers = 50):
self.params = train_parameters
self.layers = layers
def net(self, input, class_dim = 1000):
layers = self.layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
cardinality = 32
reduction_ratio = 16
depth = [3, 4, 6, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 101:
cardinality = 32
reduction_ratio = 16
depth = [3, 4, 23, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
elif layers == 152:
cardinality = 64
reduction_ratio = 16
depth = [3, 8, 36, 3]
num_filters = [128, 256, 512, 1024]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=2, act='relu')
conv = self.conv_bn_layer(
input=conv, num_filters=64, filter_size=3, stride=1, act='relu')
conv = self.conv_bn_layer(
input=conv, num_filters=128, filter_size=3, stride=1, act='relu')
conv = fluid.layers.pool2d(
input=conv, pool_size=3, pool_stride=2, pool_padding=1, \
pool_type='max')
for block in range(len(depth)):
for i in range(depth[block]):
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
cardinality=cardinality,
reduction_ratio=reduction_ratio)
pool = fluid.layers.pool2d(
input=conv, pool_size=7, pool_type='avg', global_pooling=True)
drop = fluid.layers.dropout(x=pool, dropout_prob=0.5)
stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)
out = fluid.layers.fc(input=drop,
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
return pool, out
def shortcut(self, input, ch_out, stride):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
filter_size = 1
return self.conv_bn_layer(input, ch_out, filter_size, stride)
else:
return input
def bottleneck_block(self, input, num_filters, stride, cardinality, reduction_ratio):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
groups=cardinality,
act='relu')
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 2, filter_size=1, act=None)
scale = self.squeeze_excitation(
input=conv2,
num_channels=num_filters * 2,
reduction_ratio=reduction_ratio)
short = self.shortcut(input, num_filters * 2, stride)
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) / 2,
groups=groups,
act=None,
bias_attr=False)
return fluid.layers.batch_norm(input=conv, act=act)
def squeeze_excitation(self, input, num_channels, reduction_ratio):
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(input=pool,
size=num_channels / reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv,
stdv)))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv)))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def SE_ResNeXt50_32x4d():
model = SE_ResNeXt(layers = 50)
return model
def SE_ResNeXt101_32x4d():
model = SE_ResNeXt(layers = 101)
return model
def SE_ResNeXt152_32x4d():
model = SE_ResNeXt(layers = 152)
return model
import os
import sys
import math
import time
import argparse
import functools
import numpy as np
import paddle
import paddle.fluid as fluid
import models
from losses import tripletloss
from losses import quadrupletloss
from losses import emlloss
from losses.metrics import recall_topk
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('train_batch_size', int, 80, "Minibatch size.")
add_arg('test_batch_size', int, 10, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 120, "Number of epochs.")
add_arg('image_shape', str, "3,224,224", "Input image size.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.1, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('loss_name', str, "tripletloss", "Set the loss type to use.")
add_arg('samples_each_class', int, 2, "Samples each class.")
add_arg('num_gpus', int, 8, "Number of gpus.")
add_arg('margin', float, 0.1, "Parameter margin.")
add_arg('alpha', float, 0.0, "Parameter alpha.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def optimizer_setting(params):
ls = params["learning_strategy"]
assert ls["name"] == "piecewise_decay", \
"learning rate strategy must be {}, but got {}".format("piecewise_decay", ls["name"])
step = 10000
bd = [step * e for e in ls["epochs"]]
base_lr = params["lr"]
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
return optimizer
def train(args):
# parameters from arguments
model_name = args.model
checkpoint = args.checkpoint
pretrained_model = args.pretrained_model
with_memory_optimization = args.with_mem_opt
model_save_dir = args.model_save_dir
loss_name = args.loss_name
image_shape = [int(m) for m in args.image_shape.split(",")]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[model_name]()
out = model.net(input=image, class_dim=200)
if loss_name == "tripletloss":
metricloss = tripletloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
margin=0.1)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
elif loss_name == "quadrupletloss":
metricloss = quadrupletloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
samples_each_class=args.samples_each_class,
num_gpus=args.num_gpus,
margin=args.margin)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
elif loss_name == "emlloss":
metricloss = emlloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
samples_each_class=2,
fea_dim=2048)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
else:
print("loss name is not supported!")
exit()
cost_cls = fluid.layers.cross_entropy(input=out[1], label=label)
avg_cost_cls = fluid.layers.mean(x=cost_cls)
acc_top1 = fluid.layers.accuracy(input=out[1], label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out[1], label=label, k=5)
avg_cost = avg_cost_metric + args.alpha * avg_cost_cls
test_program = fluid.default_main_program().clone(for_test=True)
# parameters from model and arguments
params = model.params
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"]["batch_size"] = args.train_batch_size
params["learning_strategy"]["name"] = args.lr_strategy
# initialize optimizer
optimizer = optimizer_setting(params)
opts = optimizer.minimize(avg_cost)
global_lr = optimizer._global_learning_rate()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if checkpoint is not None:
fluid.io.load_persistables(exe, checkpoint)
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_reader = metricloss.train_reader
test_reader = metricloss.test_reader
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False, loss_name=avg_cost.name)
fetch_list_train = [
avg_cost_metric.name, avg_cost_cls.name, acc_top1.name, acc_top5.name,
global_lr.name
]
fetch_list_test = [out[0].name]
if with_memory_optimization:
fluid.memory_optimize(
fluid.default_main_program(), skip_opt_set=set(fetch_list_train))
for pass_id in range(params["num_epochs"]):
train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss_metric, loss_cls, acc1, acc5, lr = train_exe.run(
fetch_list_train, feed=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss_metric = np.mean(np.array(loss_metric))
loss_cls = np.mean(np.array(loss_cls))
acc1 = np.mean(np.array(acc1))
acc5 = np.mean(np.array(acc5))
lr = np.mean(np.array(lr))
train_info[0].append(loss_metric)
train_info[1].append(loss_cls)
train_info[2].append(acc1)
train_info[3].append(acc5)
if batch_id % 10 == 0:
print("Pass {0}, trainbatch {1}, lr {2}, loss_metric {3}, loss_cls {4}, acc1 {5}, acc5 {6}, time {7}".format(pass_id, \
batch_id, lr, loss_metric, loss_cls, acc1, acc5, "%2.2f sec" % period))
train_loss_metric = np.array(train_info[0]).mean()
train_loss_cls = np.array(train_info[1]).mean()
train_acc1 = np.array(train_info[2]).mean()
train_acc5 = np.array(train_info[3]).mean()
f = []
l = []
for batch_id, (data, label) in enumerate(test_reader()):
t1 = time.time()
[feas] = exe.run(test_program,
fetch_list=fetch_list_test,
feed=feeder.feed(data))
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 10 == 0:
recall = recall_topk(feas, label, k=1)
print("Pass {0}, testbatch {1}, recall {2}, time {3}".format(pass_id, \
batch_id, recall, "%2.2f sec" % period))
f = np.vstack(f)
l = np.vstack(l)
recall = recall_topk(f, l, k=1)
print("End pass {0}, train_loss_metric {1}, train_loss_cls {2}, train_acc1 {3}, train_acc5 {4}, test_recall {5}".format(pass_id, \
train_loss_metric, train_loss_cls, train_acc1, train_acc5, recall))
sys.stdout.flush()
model_path = os.path.join(model_save_dir + '/' + model_name,
str(pass_id))
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
def main():
args = parser.parse_args()
print_arguments(args)
train(args)
if __name__ == '__main__':
main()
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)