From dcea1e9f839d1374f8aa942edd39fb9914c4bff8 Mon Sep 17 00:00:00 2001 From: kbChen Date: Tue, 27 Nov 2018 11:11:09 +0800 Subject: [PATCH] Metric (#1467) * Add arcmargin loss * Update README * Update dataset --- fluid/PaddleCV/metric_learning/README.md | 149 +++----- fluid/PaddleCV/metric_learning/__init__.py | 0 fluid/PaddleCV/metric_learning/_ce.py | 58 +++ .../metric_learning/data/download_cub200.sh | 5 - .../metric_learning/data/download_sop.sh | 2 + fluid/PaddleCV/metric_learning/data/split.py | 14 - fluid/PaddleCV/metric_learning/eval.py | 75 ++-- fluid/PaddleCV/metric_learning/imgtool.py | 115 ++++++ fluid/PaddleCV/metric_learning/infer.py | 40 +- .../metric_learning/losses/__init__.py | 11 +- .../metric_learning/losses/arcmarginloss.py | 59 +++ .../losses/{metrics.py => commonfunc.py} | 42 +-- .../metric_learning/losses/datareader.py | 354 ------------------ .../metric_learning/losses/emlloss.py | 22 +- .../metric_learning/losses/quadrupletloss.py | 30 +- .../metric_learning/losses/softmaxloss.py | 25 ++ .../metric_learning/losses/tripletloss.py | 21 +- .../metric_learning/models/__init__.py | 12 +- .../models/{resnet.py => resnet_embedding.py} | 66 ++-- .../metric_learning/models/se_resnext.py | 166 -------- fluid/PaddleCV/metric_learning/reader.py | 175 +++++++++ fluid/PaddleCV/metric_learning/train.py | 211 ----------- fluid/PaddleCV/metric_learning/train_elem.py | 290 ++++++++++++++ fluid/PaddleCV/metric_learning/train_pair.py | 274 ++++++++++++++ fluid/PaddleCV/metric_learning/utility.py | 42 ++- 25 files changed, 1237 insertions(+), 1021 deletions(-) delete mode 100644 fluid/PaddleCV/metric_learning/__init__.py create mode 100644 fluid/PaddleCV/metric_learning/_ce.py delete mode 100644 fluid/PaddleCV/metric_learning/data/download_cub200.sh create mode 100644 fluid/PaddleCV/metric_learning/data/download_sop.sh delete mode 100644 fluid/PaddleCV/metric_learning/data/split.py create mode 100644 fluid/PaddleCV/metric_learning/imgtool.py create mode 100644 fluid/PaddleCV/metric_learning/losses/arcmarginloss.py rename fluid/PaddleCV/metric_learning/losses/{metrics.py => commonfunc.py} (62%) delete mode 100644 fluid/PaddleCV/metric_learning/losses/datareader.py create mode 100644 fluid/PaddleCV/metric_learning/losses/softmaxloss.py rename fluid/PaddleCV/metric_learning/models/{resnet.py => resnet_embedding.py} (58%) delete mode 100644 fluid/PaddleCV/metric_learning/models/se_resnext.py create mode 100644 fluid/PaddleCV/metric_learning/reader.py delete mode 100644 fluid/PaddleCV/metric_learning/train.py create mode 100644 fluid/PaddleCV/metric_learning/train_elem.py create mode 100644 fluid/PaddleCV/metric_learning/train_pair.py diff --git a/fluid/PaddleCV/metric_learning/README.md b/fluid/PaddleCV/metric_learning/README.md index f1eb080a..c961bf48 100644 --- a/fluid/PaddleCV/metric_learning/README.md +++ b/fluid/PaddleCV/metric_learning/README.md @@ -17,87 +17,62 @@ Running sample code in this directory requires PaddelPaddle Fluid v0.14.0 and la ## Data preparation -Caltech-UCSD Birds 200 (CUB-200) is an image dataset including 200 bird species. We use it to conduct the metric learning experiments. More details of this dataset can be found from its [official website](http://www.vision.caltech.edu/visipedia/CUB-200.html). First of all, preparation of CUB-200 data can be done as: +Stanford Online Product(SOP) dataset contains 120,053 images of 22,634 products downloaded from eBay.com. We use it to conduct the metric learning experiments. For training, 59,5511 out of 11,318 classes are used, and 11,316 classes(60,502 images) are held out for testing. First of all, preparation of SOP data can be done as: ``` cd data/ -sh download_cub200.sh -``` -The script ```data/split.py``` is used to split train/valid set. In our settings, we use images from first 100 classes(001-100) as training data while the other 100 classes are validation data. After the splitting, there are two label files which contain train and validation image labels respectively: - -* *CUB200_train.txt*: label file of CUB-200 training set, with each line seperated by ```SPACE```, like: -``` -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0021_2432168643.jpg 97 -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0022_549995638.jpg 97 -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0034_2244771004.jpg 97 -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0010_2501839798.jpg 97 -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0008_491860362.jpg 97 -current_path/images/097.Orchard_Oriole/Orchard_Oriole_0015_2545116359.jpg 97 -... -``` -* *CUB200_val.txt*: label file of CUB-200 validation set, with each line seperated by ```SPACE```, like. -``` -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0029_59210443.jpg 154 -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0021_2693953672.jpg 154 -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0016_2917350638.jpg 154 -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0027_2503540454.jpg 154 -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0026_2502710393.jpg 154 -current_path/images/154.Red_eyed_Vireo/Red_eyed_Vireo_0022_2693134681.jpg 154 -... +sh download_sop.sh ``` ## Training metric learning models -To train a metric learning model, one need to set the neural network as backbone and the metric loss function to optimize. One example of training triplet loss using ResNet-50 is shown below: +To train a metric learning model, one need to set the neural network as backbone and the metric loss function to optimize. We train meiric learning model using softmax or [arcmargin](https://arxiv.org/abs/1801.07698) loss firstly, and then fine-turned the model using other metric learning loss, such as triplet, [quadruplet](https://arxiv.org/abs/1710.00478) and [eml](https://arxiv.org/abs/1212.6094) loss. One example of training using arcmargin loss is shown below: + ``` -python train.py \ +python train_elem.py \ --model=ResNet50 \ - --lr=0.001 \ - --num_epochs=120 \ + --train_batch_size=256 \ + --test_batch_size=50 \ + --lr=0.01 \ + --total_iter_num=30000 \ --use_gpu=True \ - --train_batch_size=20 \ - --test_batch_size=20 \ - --loss_name=tripletloss \ - --model_save_dir="output_tripletloss" + --pretrained_model=${path_to_pretrain_imagenet_model} \ + --model_save_dir=${output_model_path} \ + --loss_name=arcmargin \ + --arc_scale=80.0 \ + --arc_margin=0.15 \ + --arc_easy_margin=False ``` **parameter introduction:** -* **model**: name model to use. Default: "SE_ResNeXt50_32x4d". -* **num_epochs**: the number of epochs. Default: 120. -* **batch_size**: the size of each mini-batch. Default: 256. +* **model**: name model to use. Default: "ResNet50". +* **train_batch_size**: the size of each training mini-batch. Default: 256. +* **test_batch_size**: the size of each testing mini-batch. Default: 50. +* **lr**: initialized learning rate. Default: 0.01. +* **total_iter_num**: total number of training iterations. Default: 30000. * **use_gpu**: whether to use GPU or not. Default: True. -* **model_save_dir**: the directory to save trained model. Default: "output". -* **lr**: initialized learning rate. Default: 0.1. * **pretrained_model**: model path for pretraining. Default: None. - -**training log:** the log from training ResNet-50 based triplet loss is like: -``` -Pass 0, trainbatch 0, lr 9.99999974738e-05, loss_metric 0.0700866878033, loss_cls 5.23635625839, acc1 0.0, acc5 0.100000008941, time 0.16 sec -Pass 0, trainbatch 10, lr 9.99999974738e-05, loss_metric 0.0752244070172, loss_cls 5.30303478241, acc1 0.0, acc5 0.100000008941, time 0.14 sec -Pass 0, trainbatch 20, lr 9.99999974738e-05, loss_metric 0.0840565115213, loss_cls 5.41880941391, acc1 0.0, acc5 0.0333333350718, time 0.14 sec -Pass 0, trainbatch 30, lr 9.99999974738e-05, loss_metric 0.0698839947581, loss_cls 5.35385560989, acc1 0.0, acc5 0.0333333350718, time 0.14 sec -Pass 0, trainbatch 40, lr 9.99999974738e-05, loss_metric 0.0596057735384, loss_cls 5.34744024277, acc1 0.0, acc5 0.0, time 0.14 sec -Pass 0, trainbatch 50, lr 9.99999974738e-05, loss_metric 0.067836754024, loss_cls 5.37124729156, acc1 0.0, acc5 0.0333333350718, time 0.14 sec -Pass 0, trainbatch 60, lr 9.99999974738e-05, loss_metric 0.0637686774135, loss_cls 5.47412204742, acc1 0.0, acc5 0.0333333350718, time 0.14 sec -Pass 0, trainbatch 70, lr 9.99999974738e-05, loss_metric 0.0772982165217, loss_cls 5.38295936584, acc1 0.0, acc5 0.0, time 0.14 sec -Pass 0, trainbatch 80, lr 9.99999974738e-05, loss_metric 0.0861896127462, loss_cls 5.41250753403, acc1 0.0, acc5 0.0, time 0.14 sec -Pass 0, trainbatch 90, lr 9.99999974738e-05, loss_metric 0.0653102770448, loss_cls 5.53133153915, acc1 0.0, acc5 0.0, time 0.14 sec -... -``` +* **model_save_dir**: the directory to save trained model. Default: "output". +* **loss_name**: loss fortraining model. Default: "softmax". +* **arc_scale**: parameter of arcmargin loss. Default: 80.0. +* **arc_margin**: parameter of arcmargin loss. Default: 0.15. +* **arc_easy_margin**: parameter of arcmargin loss. Default: False. ## Finetuning -Finetuning is to finetune model weights in a specific task by loading pretrained weights. After initializing ```path_to_pretrain_model```, one can finetune a model as: +Finetuning is to finetune model weights in a specific task by loading pretrained weights. After training model using softmax or arcmargin loss, one can finetune the model using triplet, quadruplet or eml loss. One example of fine-turned using eml loss is shown below: + ``` -python train.py \ +python train_pair.py \ --model=ResNet50 \ - --pretrained_model=${path_to_pretrain_model} \ - --lr=0.001 \ - --num_epochs=120 \ + --train_batch_size=160 \ + --test_batch_size=50 \ + --lr=0.0001 \ + --total_iter_num=100000 \ --use_gpu=True \ - --train_batch_size=20 \ - --test_batch_size=20 \ - --loss_name=tripletloss \ - --model_save_dir="output_tripletloss" + --pretrained_model=${path_to_pretrain_arcmargin_model} \ + --model_save_dir=${output_model_path} \ + --loss_name=eml \ + --samples_each_class=2 ``` ## Evaluation @@ -105,58 +80,26 @@ Evaluation is to evaluate the performance of a trained model. One can download [ ``` python eval.py \ --model=ResNet50 \ + --batch_size=50 \ --pretrained_model=${path_to_pretrain_model} \ - --batch_size=30 \ - --loss_name=tripletloss -``` - -According to the congfiguration of evaluation, the output log is like: -``` -testbatch 0, loss 17.0384693146, recall 0.133333333333, time 0.08 sec -testbatch 10, loss 15.4248628616, recall 0.2, time 0.07 sec -testbatch 20, loss 19.3986873627, recall 0.0666666666667, time 0.07 sec -testbatch 30, loss 19.8149013519, recall 0.166666666667, time 0.07 sec -testbatch 40, loss 18.7500724792, recall 0.0333333333333, time 0.07 sec -testbatch 50, loss 15.1477527618, recall 0.166666666667, time 0.07 sec -testbatch 60, loss 21.6039619446, recall 0.0666666666667, time 0.07 sec -testbatch 70, loss 16.3203811646, recall 0.1, time 0.08 sec -testbatch 80, loss 17.3300457001, recall 0.133333333333, time 0.14 sec -testbatch 90, loss 17.9943237305, recall 0.0333333333333, time 0.07 sec -testbatch 100, loss 20.4538421631, recall 0.1, time 0.07 sec -End test, test_loss 18.2126255035, test recall 0.573597359736 -... ``` ## Inference Inference is used to get prediction score or image features based on trained models. ``` -python infer.py --model=ResNet50 \ - --pretrained_model=${path_to_pretrain_model} -``` -The output contains learned feature for each test sample: -``` -Test-0-feature: [0.1551965 0.48882252 0.3528545 ... 0.35809007 0.6210782 0.34474897] -Test-1-feature: [0.26215672 0.71406883 0.36118034 ... 0.4711366 0.6783772 0.26591945] -Test-2-feature: [0.26164916 0.46013424 0.38381338 ... 0.47984493 0.5830286 0.22124235] -Test-3-feature: [0.22502825 0.44153655 0.29287377 ... 0.45510024 0.81386226 0.21451607] -Test-4-feature: [0.27748746 0.49068335 0.28269237 ... 0.47356504 0.73254013 0.22317657] -Test-5-feature: [0.17743547 0.5232162 0.35012805 ... 0.38921246 0.80238944 0.26693743] -Test-6-feature: [0.18314484 0.4294481 0.37652573 ... 0.4795592 0.7446839 0.24178651] -Test-7-feature: [0.25836483 0.49866533 0.3469289 ... 0.38316026 0.56015515 0.22388287] -Test-8-feature: [0.30613047 0.5200348 0.2847372 ... 0.5700768 0.76645917 0.26504722] -Test-9-feature: [0.3305695 0.46257797 0.27108437 ... 0.42891273 0.5112956 0.26442713] -Test-10-feature: [0.16024818 0.46871603 0.32608703 ... 0.3341719 0.6876993 0.26097256] -Test-11-feature: [0.37611157 0.6006333 0.3023942 ... 0.4729057 0.53841203 0.19621202] -Test-12-feature: [0.17515017 0.41597834 0.45567667 ... 0.45650777 0.5987687 0.25734115] -... +python infer.py \ + --model=ResNet50 \ + --batch_size=1 \ + --pretrained_model=${path_to_pretrain_model} ``` ## Performances For comparation, many metric learning models with different neural networks and loss functions are trained using corresponding experiential parameters. Recall@Rank-1 is used as evaluation metric and the performance is listed in the table. Pretrained models can be downloaded by clicking related model names. -|model | ResNet50 | SE-ResNeXt-50 +|pretrain model | softmax | arcmargin |- | - | -: -|[triplet loss]() | 57.36% | 51.62% -|[eml loss]() | 58.84% | 52.94% -|[quadruplet loss]() | 62.67% | 56.40% +|without fine-tuned | 77.42% | 78.11% +|fine-tuned with triplet | 78.37% | 79.21% +|fine-tuned with quadruplet | 78.10% | 79.59% +|fine-tuned with eml | 79.32% | 80.11% diff --git a/fluid/PaddleCV/metric_learning/__init__.py b/fluid/PaddleCV/metric_learning/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fluid/PaddleCV/metric_learning/_ce.py b/fluid/PaddleCV/metric_learning/_ce.py new file mode 100644 index 00000000..ad1d8e4b --- /dev/null +++ b/fluid/PaddleCV/metric_learning/_ce.py @@ -0,0 +1,58 @@ +# this file is only used for continuous evaluation test! + +import os +import sys +sys.path.append(os.environ['ceroot']) +from kpi import CostKpi, DurationKpi, AccKpi + +# NOTE kpi.py should shared in models in some way!!!! + +train_cost_kpi = CostKpi('train_cost', 0.02 0, actived=True) +test_recall_kpi = AccKpi('test_recall', 0.02, 0, actived=True) + +tracking_kpis = [ + train_cost_kpi, + test_recall_kpi, +] + + +def parse_log(log): + ''' + This method should be implemented by model developers. + + The suggestion: + + each line in the log should be key, value, for example: + + " + train_cost\t1.0 + test_cost\t1.0 + train_cost\t1.0 + train_cost\t1.0 + train_acc\t1.2 + " + ''' + for line in log.split('\n'): + fs = line.strip().split('\t') + print(fs) + if len(fs) == 3 and fs[0] == 'kpis': + kpi_name = fs[1] + kpi_value = float(fs[2]) + yield kpi_name, kpi_value + + +def log_to_ce(log): + kpi_tracker = {} + for kpi in tracking_kpis: + kpi_tracker[kpi.name] = kpi + + for (kpi_name, kpi_value) in parse_log(log): + print(kpi_name, kpi_value) + kpi_tracker[kpi_name].add_record(kpi_value) + kpi_tracker[kpi_name].persist() + + +if __name__ == '__main__': + log = sys.stdin.read() + log_to_ce(log) + diff --git a/fluid/PaddleCV/metric_learning/data/download_cub200.sh b/fluid/PaddleCV/metric_learning/data/download_cub200.sh deleted file mode 100644 index 07cf863d..00000000 --- a/fluid/PaddleCV/metric_learning/data/download_cub200.sh +++ /dev/null @@ -1,5 +0,0 @@ -wget http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz -tar zxf images.tgz -find images|grep jpg|grep -v "\._" > list.txt -python split.py -rm -rf images.tgz list.txt diff --git a/fluid/PaddleCV/metric_learning/data/download_sop.sh b/fluid/PaddleCV/metric_learning/data/download_sop.sh new file mode 100644 index 00000000..8de8fdd0 --- /dev/null +++ b/fluid/PaddleCV/metric_learning/data/download_sop.sh @@ -0,0 +1,2 @@ +wget ftp://cs.stanford.edu/cs/cvgl/Stanford_Online_Products.zip +unzip Stanford_Online_Products.zip diff --git a/fluid/PaddleCV/metric_learning/data/split.py b/fluid/PaddleCV/metric_learning/data/split.py deleted file mode 100644 index 5eb91012..00000000 --- a/fluid/PaddleCV/metric_learning/data/split.py +++ /dev/null @@ -1,14 +0,0 @@ -input = open("list.txt", "r").readlines() -fout_train = open("CUB200_train.txt", "w") -fout_valid = open("CUB200_val.txt", "w") -for i, item in enumerate(input): - label = item.strip().split("/")[-2].split(".")[0] - label = int(label) - if label <= 100: - fout = fout_train - else: - fout = fout_valid - fout.write(item.strip() + " " + str(label) + "\n") - -fout_train.close() -fout_valid.close() diff --git a/fluid/PaddleCV/metric_learning/eval.py b/fluid/PaddleCV/metric_learning/eval.py index 3d802ea1..9922038b 100644 --- a/fluid/PaddleCV/metric_learning/eval.py +++ b/fluid/PaddleCV/metric_learning/eval.py @@ -1,29 +1,31 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import os -import numpy as np -import time import sys +import math +import time +import argparse +import functools +import numpy as np import paddle import paddle.fluid as fluid import models -import argparse -import functools -from losses import tripletloss -from losses import quadrupletloss -from losses import emlloss -from losses.metrics import recall_topk +import reader from utility import add_arguments, print_arguments -import math +from utility import fmt_time, recall_topk # yapf: disable parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) -add_arg('batch_size', int, 120, "Minibatch size.") -add_arg('use_gpu', bool, True, "Whether to use GPU or not.") -add_arg('image_shape', str, "3,224,224", "Input image size.") -add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.") -add_arg('pretrained_model', str, None, "Whether to use pretrained model.") -add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") -add_arg('loss_name', str, "emlloss", "Loss name.") +add_arg('model', str, "ResNet50", "Set the network to use.") +add_arg('embedding_size', int, 0, "Embedding size.") +add_arg('batch_size', int, 10, "Minibatch size.") +add_arg('image_shape', str, "3,224,224", "Input image size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") # yapf: enable model_list = [m for m in dir(models) if "__" not in m] @@ -34,8 +36,6 @@ def eval(args): model_name = args.model pretrained_model = args.pretrained_model with_memory_optimization = args.with_mem_opt - loss_name = args.loss_name - image_shape = [int(m) for m in args.image_shape.split(",")] assert model_name in model_list, "{} is not in lists: {}".format(args.model, @@ -46,19 +46,8 @@ def eval(args): # model definition model = models.__dict__[model_name]() - out = model.net(input=image, class_dim=200) - - if loss_name == "tripletloss": - metricloss = tripletloss() - cost = metricloss.loss(out[0]) - elif loss_name == "quadrupletloss": - metricloss = quadrupletloss() - cost = metricloss.loss(out[0]) - elif loss_name == "emlloss": - metricloss = emlloss() - cost = metricloss.loss(out[0]) - - avg_cost = fluid.layers.mean(x=cost) + out = model.net(input=image, embedding_size=args.embedding_size) + test_program = fluid.default_main_program().clone(for_test=True) if with_memory_optimization: @@ -75,39 +64,29 @@ def eval(args): fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) - test_reader = paddle.batch(metricloss.test_reader, batch_size=args.batch_size) + test_reader = paddle.batch(reader.test(args), batch_size=args.batch_size, drop_last=False) feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) - fetch_list = [avg_cost.name, out[0].name] + fetch_list = [out.name] - test_info = [[]] - f = [] - l = [] + f, l = [], [] for batch_id, data in enumerate(test_reader()): - if len(data) < args.batch_size: - continue t1 = time.time() - loss, feas = exe.run(test_program, - fetch_list=fetch_list, - feed=feeder.feed(data)) + [feas] = exe.run(test_program, fetch_list=fetch_list, feed=feeder.feed(data)) label = np.asarray([x[1] for x in data]) f.append(feas) l.append(label) t2 = time.time() period = t2 - t1 - loss = np.mean(np.array(loss)) - test_info[0].append(loss) if batch_id % 20 == 0: - print("testbatch {0}, loss {1}, time {2}".format( \ - batch_id, loss, "%2.2f sec" % period)) + print("[%s] testbatch %d, time %2.2f sec" % \ + (fmt_time(), batch_id, period)) - test_loss = np.array(test_info[0]).mean() f = np.vstack(f) l = np.hstack(l) recall = recall_topk(f, l, k=1) - print("End test, test_loss {0}, test recall {1}".format( \ - test_loss, recall)) + print("[%s] End test %d, test_recall %.5f" % (fmt_time(), len(f), recall)) sys.stdout.flush() diff --git a/fluid/PaddleCV/metric_learning/imgtool.py b/fluid/PaddleCV/metric_learning/imgtool.py new file mode 100644 index 00000000..43be0c21 --- /dev/null +++ b/fluid/PaddleCV/metric_learning/imgtool.py @@ -0,0 +1,115 @@ +""" tools for processing images +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import cv2 +import math +import random +import functools +import numpy as np + +#random.seed(0) + +def rotate_image(img): + """ rotate_image """ + (h, w) = img.shape[:2] + center = (w // 2, h // 2) + angle = random.randint(-10, 10) + M = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated = cv2.warpAffine(img, M, (w, h)) + return rotated + +def random_crop(img, size, scale=None, ratio=None): + """ random_crop """ + scale = [0.08, 1.0] if scale is None else scale + ratio = [3. / 4., 4. / 3.] if ratio is None else ratio + + aspect_ratio = math.sqrt(random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. / aspect_ratio + + bound = min((float(img.shape[1]) / img.shape[0]) / (w ** 2), + (float(img.shape[0]) / img.shape[1]) / (h ** 2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.shape[0] * img.shape[1] * random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = random.randint(0, img.shape[0] - h) + j = random.randint(0, img.shape[1] - w) + + img = img[i:i+h, j:j+w, :] + resized = cv2.resize(img, (size, size), interpolation=cv2.INTER_LANCZOS4) + return resized + +def distort_color(img): + return img + +def resize_short(img, target_size): + """ resize_short """ + percent = float(target_size) / min(img.shape[0], img.shape[1]) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + resized = cv2.resize(img, (resized_width, resized_height), interpolation=cv2.INTER_LANCZOS4) + return resized + +def crop_image(img, target_size, center): + """ crop_image """ + height, width = img.shape[:2] + size = target_size + if center == True: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = random.randint(0, width - size) + h_start = random.randint(0, height - size) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + +def process_image(sample, mode, color_jitter, rotate, + crop_size=224, mean=None, std=None): + """ process_image """ + + mean = [0.485, 0.456, 0.406] if mean is None else mean + std = [0.229, 0.224, 0.225] if std is None else std + + image_name = sample[0] + img = cv2.imread(image_name) # BGR mode, but need RGB mode + + if mode == 'train': + if rotate: + img = rotate_image(img) + if crop_size > 0: + img = random_crop(img, crop_size) + if color_jitter: + img = distort_color(img) + if random.randint(0, 1) == 1: + img = img[:, ::-1, :] + else: + if crop_size > 0: + img = resize_short(img, crop_size) + img = crop_image(img, target_size=crop_size, center=True) + + img = img[:, :, ::-1].astype('float32').transpose((2, 0, 1)) / 255 + + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + + if mode == 'train' or mode == 'val': + return (img, sample[1]) + elif mode == 'test': + return (img, ) + +def image_mapper(**kwargs): + """ image_mapper """ + return functools.partial(process_image, **kwargs) diff --git a/fluid/PaddleCV/metric_learning/infer.py b/fluid/PaddleCV/metric_learning/infer.py index 3b16889d..a189ccc3 100644 --- a/fluid/PaddleCV/metric_learning/infer.py +++ b/fluid/PaddleCV/metric_learning/infer.py @@ -1,25 +1,30 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import os -import numpy as np -import time import sys +import math +import time +import argparse +import functools +import numpy as np import paddle import paddle.fluid as fluid import models -import argparse -import functools -from losses import tripletloss +import reader from utility import add_arguments, print_arguments -import math parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('batch_size', int, 1, "Minibatch size.") -add_arg('use_gpu', bool, True, "Whether to use GPU or not.") -add_arg('image_shape', str, "3,224,224", "Input image size.") -add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.") -add_arg('pretrained_model', str, None, "Whether to use pretrained model.") -add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") +add_arg('model', str, "ResNet50", "Set the network to use.") +add_arg('embedding_size', int, 0, "Embedding size.") +add_arg('batch_size', int, 1, "Minibatch size.") +add_arg('image_shape', str, "3,224,224", "Input image size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") # yapf: enable model_list = [m for m in dir(models) if "__" not in m] @@ -39,7 +44,8 @@ def infer(args): # model definition model = models.__dict__[model_name]() - out = model.net(input=image, class_dim=200) + out = model.net(input=image, embedding_size=args.embedding_size) + test_program = fluid.default_main_program().clone(for_test=True) if with_memory_optimization: @@ -56,15 +62,13 @@ def infer(args): fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) - infer_reader = paddle.batch(tripletloss().infer_reader, batch_size=args.batch_size) + infer_reader = paddle.batch(reader.infer(args), batch_size=args.batch_size, drop_last=False) feeder = fluid.DataFeeder(place=place, feed_list=[image]) - fetch_list = [out[0].name] + fetch_list = [out.name] for batch_id, data in enumerate(infer_reader()): - result = exe.run(test_program, - fetch_list=fetch_list, - feed=feeder.feed(data)) + result = exe.run(test_program, fetch_list=fetch_list, feed=feeder.feed(data)) result = result[0][0].reshape(-1) print("Test-{0}-feature: {1}".format(batch_id, result)) sys.stdout.flush() diff --git a/fluid/PaddleCV/metric_learning/losses/__init__.py b/fluid/PaddleCV/metric_learning/losses/__init__.py index 7dce48c6..5cc82214 100644 --- a/fluid/PaddleCV/metric_learning/losses/__init__.py +++ b/fluid/PaddleCV/metric_learning/losses/__init__.py @@ -1,3 +1,8 @@ -from .tripletloss import tripletloss -from .quadrupletloss import quadrupletloss -from .emlloss import emlloss +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from .softmaxloss import SoftmaxLoss +from .arcmarginloss import ArcMarginLoss +from .tripletloss import TripletLoss +from .quadrupletloss import QuadrupletLoss +from .emlloss import EmlLoss diff --git a/fluid/PaddleCV/metric_learning/losses/arcmarginloss.py b/fluid/PaddleCV/metric_learning/losses/arcmarginloss.py new file mode 100644 index 00000000..b166f7bb --- /dev/null +++ b/fluid/PaddleCV/metric_learning/losses/arcmarginloss.py @@ -0,0 +1,59 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle.fluid as fluid + +class ArcMarginLoss(): + def __init__(self, class_dim, margin=0.15, scale=80.0, easy_margin=False): + self.class_dim = class_dim + self.margin = margin + self.scale = scale + self.easy_margin = easy_margin + + def loss(self, input, label): + out = self.arc_margin_product(input, label, self.class_dim, self.margin, self.scale, self.easy_margin) + #loss = fluid.layers.softmax_with_cross_entropy(logits=out, label=label) + out = fluid.layers.softmax(input=out) + loss = fluid.layers.cross_entropy(input=out, label=label) + return loss, out + + def arc_margin_product(self, input, label, out_dim, m, s, easy_margin=False): + #input = fluid.layers.l2_normalize(input, axis=1) + input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) + input = fluid.layers.elementwise_div(input, input_norm, axis=0) + + weight = fluid.layers.create_parameter( + shape=[out_dim, input.shape[1]], + dtype='float32', + name='weight_norm', + attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Xavier())) + #weight = fluid.layers.l2_normalize(weight, axis=1) + weight_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(weight), dim=1)) + weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0) + weight = fluid.layers.transpose(weight, perm = [1, 0]) + cosine = fluid.layers.mul(input, weight) + sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine) + 1e-6) + + cos_m = math.cos(m) + sin_m = math.sin(m) + phi = cosine * cos_m - sine * sin_m + + th = math.cos(math.pi - m) + mm = math.sin(math.pi - m) * m + if easy_margin: + phi = self.paddle_where_more_than(cosine, 0, phi, cosine) + else: + phi = self.paddle_where_more_than(cosine, th, phi, cosine-mm) + + one_hot = fluid.layers.one_hot(input=label, depth=out_dim) + output = fluid.layers.elementwise_mul(one_hot, phi) + fluid.layers.elementwise_mul((1.0 - one_hot), cosine) + output = output * s + return output + + def paddle_where_more_than(self, target, limit, x, y): + mask = fluid.layers.cast(x=(target>limit), dtype='float32') + output = fluid.layers.elementwise_mul(mask, x) + fluid.layers.elementwise_mul((1.0 - mask), y) + return output diff --git a/fluid/PaddleCV/metric_learning/losses/metrics.py b/fluid/PaddleCV/metric_learning/losses/commonfunc.py similarity index 62% rename from fluid/PaddleCV/metric_learning/losses/metrics.py rename to fluid/PaddleCV/metric_learning/losses/commonfunc.py index 51805bde..98af1502 100644 --- a/fluid/PaddleCV/metric_learning/losses/metrics.py +++ b/fluid/PaddleCV/metric_learning/losses/commonfunc.py @@ -1,41 +1,15 @@ -import numpy as np -def recall_topk(fea, lab, k = 1): - fea = np.array(fea) - fea = fea.reshape(fea.shape[0], -1) - n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1) - fea = fea/n - a = np.sum(fea ** 2, 1).reshape(-1, 1) - b = a.T - ab = np.dot(fea, fea.T) - d = a + b - 2*ab - d = d + np.eye(len(fea)) * 1e8 - sorted_index = np.argsort(d, 1) - res = 0 - for i in range(len(fea)): - pred = lab[sorted_index[i][0]] - if lab[i] == pred: - res += 1.0 - res = res/len(fea) - return res +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -import subprocess import os - -def get_gpu_num(): - visibledevice = os.getenv('CUDA_VISIBLE_DEVICES') - if visibledevice: - devicenum = len(visibledevice.split(',')) - else: - devicenum = subprocess.check_output( - [str.encode('nvidia-smi'), str.encode('-L')]).decode('utf-8').count('\n') - return devicenum - +import numpy as np import paddle as paddle import paddle.fluid as fluid def generate_index(batch_size, samples_each_class): - a = np.arange(0, batch_size * batch_size) - a = a.reshape(-1, batch_size) + a = np.arange(0, batch_size * batch_size) # N*N x 1 + a = a.reshape(-1, batch_size) # N x N steps = batch_size // samples_each_class res = [] for i in range(batch_size): @@ -72,7 +46,3 @@ def calculate_order_dist_matrix(feature, batch_size, samples_each_class): d = fluid.layers.gather(d, index=index_var) d = fluid.layers.reshape(d, shape=[-1, batch_size]) return d - - - - diff --git a/fluid/PaddleCV/metric_learning/losses/datareader.py b/fluid/PaddleCV/metric_learning/losses/datareader.py deleted file mode 100644 index 286bf950..00000000 --- a/fluid/PaddleCV/metric_learning/losses/datareader.py +++ /dev/null @@ -1,354 +0,0 @@ -import os -import math -import random -import functools -import numpy as np -import paddle -from PIL import Image, ImageEnhance - -random.seed(0) - -DATA_DIM = 224 - -THREAD = 8 -BUF_SIZE = 1024000 - -DATA_DIR = "./data/" -TRAIN_LIST = './data/CUB200_train.txt' -TEST_LIST = './data/CUB200_val.txt' -#DATA_DIR = "./data/CUB200/" -#TRAIN_LIST = './data/CUB200/CUB200_train.txt' -#TEST_LIST = './data/CUB200/CUB200_val.txt' -train_data = {} -test_data = {} -train_list = open(TRAIN_LIST, "r").readlines() -train_image_list = [] -for i, item in enumerate(train_list): - path, label = item.strip().split() - label = int(label) - 1 - train_image_list.append((path, label)) - if label not in train_data: - train_data[label] = [] - train_data[label].append(path) - -test_list = open(TEST_LIST, "r").readlines() -test_image_list = [] -infer_image_list = [] -for i, item in enumerate(test_list): - path, label = item.strip().split() - label = int(label) - 1 - test_image_list.append((path, label)) - infer_image_list.append(path) - if label not in test_data: - test_data[label] = [] - test_data[label].append(path) - -print("train_data size:", len(train_data)) -print("test_data size:", len(test_data)) -print("test_data image number:", len(test_image_list)) -random.shuffle(test_image_list) - - -img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) -img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) - - -def resize_short(img, target_size): - percent = float(target_size) / min(img.size[0], img.size[1]) - resized_width = int(round(img.size[0] * percent)) - resized_height = int(round(img.size[1] * percent)) - img = img.resize((resized_width, resized_height), Image.BILINEAR) - return img - -def Scale(img, size): - w, h = img.size - if (w <= h and w == size) or (h <= w and h == size): - return img - if w < h: - ow = size - oh = int(size * h / w) - return img.resize((ow, oh), Image.BILINEAR) - else: - oh = size - ow = int(size * w / h) - return img.resize((ow, oh), Image.BILINEAR) - -def CenterCrop(img, size): - w, h = img.size - th, tw = int(size), int(size) - x1 = int(round((w - tw) / 2.)) - y1 = int(round((h - th) / 2.)) - return img.crop((x1, y1, x1 + tw, y1 + th)) - -def crop_image(img, target_size, center): - width, height = img.size - size = target_size - if center == True: - w_start = (width - size) / 2 - h_start = (height - size) / 2 - else: - w_start = random.randint(0, width - size) - h_start = random.randint(0, height - size) - w_end = w_start + size - h_end = h_start + size - img = img.crop((w_start, h_start, w_end, h_end)) - return img - -def RandomResizedCrop(img, size): - for attempt in range(10): - area = img.size[0] * img.size[1] - target_area = random.uniform(0.08, 1.0) * area - aspect_ratio = random.uniform(3. / 4, 4. / 3) - - w = int(round(math.sqrt(target_area * aspect_ratio))) - h = int(round(math.sqrt(target_area / aspect_ratio))) - - if random.random() < 0.5: - w, h = h, w - - if w <= img.size[0] and h <= img.size[1]: - x1 = random.randint(0, img.size[0] - w) - y1 = random.randint(0, img.size[1] - h) - - img = img.crop((x1, y1, x1 + w, y1 + h)) - assert(img.size == (w, h)) - - return img.resize((size, size), Image.BILINEAR) - - w = min(img.size[0], img.size[1]) - i = (img.size[1] - w) // 2 - j = (img.size[0] - w) // 2 - img = img.crop((i, j, i+w, j+w)) - img = img.resize((size, size), Image.BILINEAR) - return img - - -def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]): - aspect_ratio = math.sqrt(random.uniform(*ratio)) - w = 1. * aspect_ratio - h = 1. / aspect_ratio - - bound = min((float(img.size[0]) / img.size[1]) / (w**2), - (float(img.size[1]) / img.size[0]) / (h**2)) - scale_max = min(scale[1], bound) - scale_min = min(scale[0], bound) - - target_area = img.size[0] * img.size[1] * random.uniform(scale_min, - scale_max) - target_size = math.sqrt(target_area) - w = int(target_size * w) - h = int(target_size * h) - - i = random.randint(0, img.size[0] - w) - j = random.randint(0, img.size[1] - h) - - img = img.crop((i, j, i + w, j + h)) - img = img.resize((size, size), Image.BILINEAR) - return img - - -def rotate_image(img): - angle = random.randint(-10, 10) - img = img.rotate(angle) - return img - - -def distort_color(img): - def random_brightness(img, lower=0.8, upper=1.2): - e = random.uniform(lower, upper) - return ImageEnhance.Brightness(img).enhance(e) - - def random_contrast(img, lower=0.8, upper=1.2): - e = random.uniform(lower, upper) - return ImageEnhance.Contrast(img).enhance(e) - - def random_color(img, lower=0.8, upper=1.2): - e = random.uniform(lower, upper) - return ImageEnhance.Color(img).enhance(e) - - ops = [random_brightness, random_contrast, random_color] - random.shuffle(ops) - - img = ops[0](img) - img = ops[1](img) - img = ops[2](img) - - return img - -def process_image_imagepath(sample, mode, color_jitter, rotate): - imgpath = sample[0] - img = Image.open(imgpath) - if mode == 'train': - if rotate: img = rotate_image(img) - img = RandomResizedCrop(img, DATA_DIM) - else: - img = Scale(img, 256) - img = CenterCrop(img, DATA_DIM) - if mode == 'train': - if color_jitter: - img = distort_color(img) - if random.randint(0, 1) == 1: - img = img.transpose(Image.FLIP_LEFT_RIGHT) - - if img.mode != 'RGB': - img = img.convert('RGB') - - img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 - img -= img_mean - img /= img_std - - if mode in ['train', 'test']: - return img, sample[1] - elif mode == 'infer': - return [img] - - -def eml_iterator(data, - mode, - batch_size, - samples_each_class, - iter_size, - shuffle=False, - color_jitter=False, - rotate=False): - def reader(): - labs = list(data.keys()) - lab_num = len(labs) - ind = list(range(0, lab_num)) - assert batch_size % samples_each_class == 0, "batch_size % samples_each_class != 0" - num_class = batch_size // samples_each_class - for i in range(iter_size): - random.shuffle(ind) - for n in range(num_class): - lab_ind = ind[n] - label = labs[lab_ind] - data_list = data[label] - random.shuffle(data_list) - for s in range(samples_each_class): - path = DATA_DIR + data_list[s] - yield path, label - - mapper = functools.partial( - process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate) - - return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True) - - -def quadruplet_iterator(data, - mode, - class_num, - samples_each_class, - iter_size, - shuffle=False, - color_jitter=False, - rotate=False): - def reader(): - labs = list(data.keys()) - lab_num = len(labs) - ind = list(range(0, lab_num)) - for i in range(iter_size): - random.shuffle(ind) - ind_sample = ind[:class_num] - - for ind_i in ind_sample: - lab = labs[ind_i] - data_list = data[lab] - data_ind = list(range(0, len(data_list))) - random.shuffle(data_ind) - anchor_ind = data_ind[:samples_each_class] - - for anchor_ind_i in anchor_ind: - anchor_path = DATA_DIR + data_list[anchor_ind_i] - yield anchor_path, lab - - mapper = functools.partial( - process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate) - - return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True) - - -def triplet_iterator(data, - mode, - batch_size, - iter_size, - shuffle=False, - color_jitter=False, - rotate=False): - def reader(): - labs = list(data.keys()) - lab_num = len(labs) - ind = list(range(0, lab_num)) - for i in range(iter_size): - random.shuffle(ind) - ind_pos, ind_neg = ind[:2] - lab_pos = labs[ind_pos] - pos_data_list = data[lab_pos] - data_ind = list(range(0, len(pos_data_list))) - random.shuffle(data_ind) - anchor_ind, pos_ind = data_ind[:2] - - lab_neg = labs[ind_neg] - neg_data_list = data[lab_neg] - neg_ind = random.randint(0, len(neg_data_list) - 1) - - anchor_path = DATA_DIR + pos_data_list[anchor_ind] - yield anchor_path, lab_pos - - pos_path = DATA_DIR + pos_data_list[pos_ind] - yield pos_path, lab_pos - - neg_path = DATA_DIR + neg_data_list[neg_ind] - yield neg_path, lab_neg - - - mapper = functools.partial( - process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate) - - return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True) - - -def image_iterator(data, - mode, - shuffle=False, - color_jitter=False, - rotate=False): - def test_reader(): - for i in range(len(data)): - path, label = data[i] - path = DATA_DIR + path - yield path, label - - def infer_reader(): - for i in range(len(data)): - path = data[i] - path = DATA_DIR + path - yield [path] - - if mode == "test": - mapper = functools.partial( - process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate) - return paddle.reader.xmap_readers(mapper, test_reader, THREAD, BUF_SIZE) - elif mode == "infer": - mapper = functools.partial( - process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate) - return paddle.reader.xmap_readers(mapper, infer_reader, THREAD, BUF_SIZE) - - -def eml_train(batch_size, samples_each_class): - return eml_iterator(train_data, 'train', batch_size, samples_each_class, iter_size = 100, \ - shuffle=True, color_jitter=False, rotate=False) - -def quadruplet_train(class_num, samples_each_class): - return quadruplet_iterator(train_data, 'train', class_num, samples_each_class, iter_size=100, \ - shuffle=True, color_jitter=False, rotate=False) - -def triplet_train(batch_size): - assert(batch_size % 3 == 0) - return triplet_iterator(train_data, 'train', batch_size, iter_size = batch_size//3 * 100, \ - shuffle=True, color_jitter=False, rotate=False) - -def test(): - return image_iterator(test_image_list, "test", shuffle=False) - -def infer(): - return image_iterator(infer_image_list, "infer", shuffle=False) diff --git a/fluid/PaddleCV/metric_learning/losses/emlloss.py b/fluid/PaddleCV/metric_learning/losses/emlloss.py index 362b0ae0..459ca449 100644 --- a/fluid/PaddleCV/metric_learning/losses/emlloss.py +++ b/fluid/PaddleCV/metric_learning/losses/emlloss.py @@ -1,21 +1,20 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import math -import numpy as np import paddle.fluid as fluid -from . import datareader as reader -from .metrics import calculate_order_dist_matrix -from .metrics import get_gpu_num +from utility import get_gpu_num +from .commonfunc import calculate_order_dist_matrix -class emlloss(): +class EmlLoss(): def __init__(self, train_batch_size = 40, samples_each_class=2): - num_gpus = get_gpu_num() self.samples_each_class = samples_each_class self.train_batch_size = train_batch_size + num_gpus = get_gpu_num() assert(train_batch_size % num_gpus == 0) self.cal_loss_batch_size = train_batch_size // num_gpus assert(self.cal_loss_batch_size % samples_each_class == 0) - class_num = train_batch_size // samples_each_class - self.train_reader = reader.eml_train(train_batch_size, samples_each_class) - self.test_reader = reader.test() def surrogate_function(self, beta, theta, bias): x = theta * fluid.layers.exp(bias) @@ -41,7 +40,10 @@ class emlloss(): def loss(self, input): samples_each_class = self.samples_each_class - batch_size = self.cal_loss_batch_size + batch_size = self.cal_loss_batch_size + #input = fluid.layers.l2_normalize(input, axis=1) + #input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) + #input = fluid.layers.elementwise_div(input, input_norm, axis=0) d = calculate_order_dist_matrix(input, self.cal_loss_batch_size, self.samples_each_class) ignore, pos, neg = fluid.layers.split(d, num_or_sections= [1, samples_each_class-1, batch_size-samples_each_class], dim=1) diff --git a/fluid/PaddleCV/metric_learning/losses/quadrupletloss.py b/fluid/PaddleCV/metric_learning/losses/quadrupletloss.py index b80af96e..a14fffcd 100644 --- a/fluid/PaddleCV/metric_learning/losses/quadrupletloss.py +++ b/fluid/PaddleCV/metric_learning/losses/quadrupletloss.py @@ -1,38 +1,40 @@ -import numpy as np +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import paddle.fluid as fluid -from . import datareader as reader -from .metrics import calculate_order_dist_matrix -from .metrics import get_gpu_num +from utility import get_gpu_num +from .commonfunc import calculate_order_dist_matrix -class quadrupletloss(): +class QuadrupletLoss(): def __init__(self, train_batch_size = 80, samples_each_class = 2, - margin=0.1): + margin = 0.1): self.margin = margin - num_gpus = get_gpu_num() self.samples_each_class = samples_each_class self.train_batch_size = train_batch_size + num_gpus = get_gpu_num() assert(train_batch_size % num_gpus == 0) self.cal_loss_batch_size = train_batch_size // num_gpus assert(self.cal_loss_batch_size % samples_each_class == 0) - class_num = train_batch_size // samples_each_class - self.train_reader = reader.quadruplet_train(class_num, samples_each_class) - self.test_reader = reader.test() def loss(self, input): - feature = fluid.layers.l2_normalize(input, axis=1) + #input = fluid.layers.l2_normalize(input, axis=1) + input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) + input = fluid.layers.elementwise_div(input, input_norm, axis=0) + samples_each_class = self.samples_each_class batch_size = self.cal_loss_batch_size margin = self.margin - d = calculate_order_dist_matrix(feature, self.cal_loss_batch_size, self.samples_each_class) + d = calculate_order_dist_matrix(input, self.cal_loss_batch_size, self.samples_each_class) ignore, pos, neg = fluid.layers.split(d, num_or_sections= [1, samples_each_class-1, batch_size-samples_each_class], dim=1) ignore.stop_gradient = True pos_max = fluid.layers.reduce_max(pos) neg_min = fluid.layers.reduce_min(neg) - pos_max = fluid.layers.sqrt(pos_max) - neg_min = fluid.layers.sqrt(neg_min) + #pos_max = fluid.layers.sqrt(pos_max + 1e-6) + #neg_min = fluid.layers.sqrt(neg_min + 1e-6) loss = fluid.layers.relu(pos_max - neg_min + margin) return loss diff --git a/fluid/PaddleCV/metric_learning/losses/softmaxloss.py b/fluid/PaddleCV/metric_learning/losses/softmaxloss.py new file mode 100644 index 00000000..dcf92adf --- /dev/null +++ b/fluid/PaddleCV/metric_learning/losses/softmaxloss.py @@ -0,0 +1,25 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle.fluid as fluid + +class SoftmaxLoss(): + def __init__(self, class_dim): + self.class_dim = class_dim + + def loss(self, input, label): + out = self.fc_product(input, self.class_dim) + loss = fluid.layers.cross_entropy(input=out, label=label) + return loss, out + + def fc_product(self, input, out_dim): + stdv = 1.0 / math.sqrt(input.shape[1] * 1.0) + out = fluid.layers.fc(input=input, + size=out_dim, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, + stdv))) + return out diff --git a/fluid/PaddleCV/metric_learning/losses/tripletloss.py b/fluid/PaddleCV/metric_learning/losses/tripletloss.py index 02ad989c..7ef3bdb4 100644 --- a/fluid/PaddleCV/metric_learning/losses/tripletloss.py +++ b/fluid/PaddleCV/metric_learning/losses/tripletloss.py @@ -1,18 +1,21 @@ -from . import datareader as reader +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import paddle.fluid as fluid -class tripletloss(): - def __init__(self, train_batch_size = 120, margin=0.1): - self.train_reader = reader.triplet_train(train_batch_size) - self.test_reader = reader.test() - self.infer_reader = reader.infer() +class TripletLoss(): + def __init__(self, margin=0.1): self.margin = margin def loss(self, input): margin = self.margin fea_dim = input.shape[1] # number of channels + #input = fluid.layers.l2_normalize(input, axis=1) + input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) + input = fluid.layers.elementwise_div(input, input_norm, axis=0) output = fluid.layers.reshape(input, shape = [-1, 3, fea_dim]) - output = fluid.layers.l2_normalize(output, axis=2) + anchor, positive, negative = fluid.layers.split(output, num_or_sections = 3, dim = 1) anchor = fluid.layers.reshape(anchor, shape = [-1, fea_dim]) @@ -23,7 +26,7 @@ class tripletloss(): a_n = fluid.layers.square(anchor - negative) a_p = fluid.layers.reduce_sum(a_p, dim = 1) a_n = fluid.layers.reduce_sum(a_n, dim = 1) - a_p = fluid.layers.sqrt(a_p) - a_n = fluid.layers.sqrt(a_n) + #a_p = fluid.layers.sqrt(a_p + 1e-6) + #a_n = fluid.layers.sqrt(a_n + 1e-6) loss = fluid.layers.relu(a_p + margin - a_n) return loss diff --git a/fluid/PaddleCV/metric_learning/models/__init__.py b/fluid/PaddleCV/metric_learning/models/__init__.py index b997f3ce..bc92b497 100644 --- a/fluid/PaddleCV/metric_learning/models/__init__.py +++ b/fluid/PaddleCV/metric_learning/models/__init__.py @@ -1,6 +1,6 @@ -from .resnet import ResNet50 -from .resnet import ResNet101 -from .resnet import ResNet152 -from .se_resnext import SE_ResNeXt50_32x4d -from .se_resnext import SE_ResNeXt101_32x4d -from .se_resnext import SE_ResNeXt152_32x4d +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from .resnet_embedding import ResNet50 +from .resnet_embedding import ResNet101 +from .resnet_embedding import ResNet152 diff --git a/fluid/PaddleCV/metric_learning/models/resnet.py b/fluid/PaddleCV/metric_learning/models/resnet_embedding.py similarity index 58% rename from fluid/PaddleCV/metric_learning/models/resnet.py rename to fluid/PaddleCV/metric_learning/models/resnet_embedding.py index 7fbe65b6..9da16b1e 100644 --- a/fluid/PaddleCV/metric_learning/models/resnet.py +++ b/fluid/PaddleCV/metric_learning/models/resnet_embedding.py @@ -1,6 +1,7 @@ import paddle import paddle.fluid as fluid import math +from paddle.fluid.param_attr import ParamAttr __all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"] @@ -22,7 +23,7 @@ class ResNet(): self.params = train_parameters self.layers = layers - def net(self, input, class_dim=1000): + def net(self, input, embedding_size=256): layers = self.layers supported_layers = [50, 101, 152] assert layers in supported_layers, \ @@ -37,7 +38,7 @@ class ResNet(): num_filters = [64, 128, 256, 512] conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=7, stride=2, act='relu') + input=input, num_filters=64, filter_size=7, stride=2, act='relu',name="conv1") conv = fluid.layers.pool2d( input=conv, pool_size=3, @@ -47,21 +48,26 @@ class ResNet(): for block in range(len(depth)): for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name="res"+str(block+2)+"a" + else: + conv_name="res"+str(block+2)+"b"+str(i) + else: + conv_name="res"+str(block+2)+chr(97+i) conv = self.bottleneck_block( input=conv, num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1) + stride=2 if i == 0 and block != 0 else 1,name=conv_name) pool = fluid.layers.pool2d( input=conv, pool_size=7, pool_type='avg', global_pooling=True) - stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) - out = fluid.layers.fc(input=pool, - size=class_dim, - act='softmax', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, - stdv))) - return pool, out + + if embedding_size > 0: + embedding = fluid.layers.fc(input=pool, size=embedding_size) + return embedding + else: + return pool def conv_bn_layer(self, input, @@ -69,7 +75,8 @@ class ResNet(): filter_size, stride=1, groups=1, - act=None): + act=None, + name=None): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, @@ -78,31 +85,44 @@ class ResNet(): padding=(filter_size - 1) // 2, groups=groups, act=None, - bias_attr=False) - return fluid.layers.batch_norm(input=conv, act=act) - - def shortcut(self, input, ch_out, stride): + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1') + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm(input=conv, + act=act, + name=bn_name+'.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance',) + + def shortcut(self, input, ch_out, stride, name): ch_in = input.shape[1] if ch_in != ch_out or stride != 1: - return self.conv_bn_layer(input, ch_out, 1, stride) + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) else: return input - def bottleneck_block(self, input, num_filters, stride): + def bottleneck_block(self, input, num_filters, stride, name): conv0 = self.conv_bn_layer( - input=input, num_filters=num_filters, filter_size=1, act='relu') + input=input, num_filters=num_filters, filter_size=1, act='relu',name=name+"_branch2a") conv1 = self.conv_bn_layer( input=conv0, num_filters=num_filters, filter_size=3, stride=stride, - act='relu') + act='relu', + name=name+"_branch2b") conv2 = self.conv_bn_layer( - input=conv1, num_filters=num_filters * 4, filter_size=1, act=None) + input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name+"_branch2c") - short = self.shortcut(input, num_filters * 4, stride) + short = self.shortcut(input, num_filters * 4, stride, name=name + "_branch1") - return fluid.layers.elementwise_add(x=short, y=conv2, act='relu') + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu',name=name+".add.output.5") def ResNet50(): diff --git a/fluid/PaddleCV/metric_learning/models/se_resnext.py b/fluid/PaddleCV/metric_learning/models/se_resnext.py deleted file mode 100644 index 16cc31bc..00000000 --- a/fluid/PaddleCV/metric_learning/models/se_resnext.py +++ /dev/null @@ -1,166 +0,0 @@ -import paddle -import paddle.fluid as fluid -import math - -__all__ = ["SE_ResNeXt", "SE_ResNeXt50_32x4d", "SE_ResNeXt101_32x4d", "SE_ResNeXt152_32x4d"] - -train_parameters = { - "input_size": [3, 224, 224], - "input_mean": [0.485, 0.456, 0.406], - "input_std": [0.229, 0.224, 0.225], - "learning_strategy": { - "name": "piecewise_decay", - "batch_size": 256, - "epochs": [30, 60, 90], - "steps": [0.1, 0.01, 0.001, 0.0001] - } -} - -class SE_ResNeXt(): - def __init__(self, layers = 50): - self.params = train_parameters - self.layers = layers - - def net(self, input, class_dim = 1000): - layers = self.layers - supported_layers = [50, 101, 152] - assert layers in supported_layers, \ - "supported layers are {} but input layer is {}".format(supported_layers, layers) - if layers == 50: - cardinality = 32 - reduction_ratio = 16 - depth = [3, 4, 6, 3] - num_filters = [128, 256, 512, 1024] - - conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=7, stride=2, act='relu') - conv = fluid.layers.pool2d( - input=conv, - pool_size=3, - pool_stride=2, - pool_padding=1, - pool_type='max') - elif layers == 101: - cardinality = 32 - reduction_ratio = 16 - depth = [3, 4, 23, 3] - num_filters = [128, 256, 512, 1024] - - conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=7, stride=2, act='relu') - conv = fluid.layers.pool2d( - input=conv, - pool_size=3, - pool_stride=2, - pool_padding=1, - pool_type='max') - elif layers == 152: - cardinality = 64 - reduction_ratio = 16 - depth = [3, 8, 36, 3] - num_filters = [128, 256, 512, 1024] - - conv = self.conv_bn_layer( - input=input, num_filters=64, filter_size=3, stride=2, act='relu') - conv = self.conv_bn_layer( - input=conv, num_filters=64, filter_size=3, stride=1, act='relu') - conv = self.conv_bn_layer( - input=conv, num_filters=128, filter_size=3, stride=1, act='relu') - conv = fluid.layers.pool2d( - input=conv, pool_size=3, pool_stride=2, pool_padding=1, \ - pool_type='max') - - for block in range(len(depth)): - for i in range(depth[block]): - conv = self.bottleneck_block( - input=conv, - num_filters=num_filters[block], - stride=2 if i == 0 and block != 0 else 1, - cardinality=cardinality, - reduction_ratio=reduction_ratio) - - pool = fluid.layers.pool2d( - input=conv, pool_size=7, pool_type='avg', global_pooling=True) - drop = fluid.layers.dropout(x=pool, dropout_prob=0.5) - stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0) - out = fluid.layers.fc(input=drop, - size=class_dim, - act='softmax', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, - stdv))) - return pool, out - - def shortcut(self, input, ch_out, stride): - ch_in = input.shape[1] - if ch_in != ch_out or stride != 1: - filter_size = 1 - return self.conv_bn_layer(input, ch_out, filter_size, stride) - else: - return input - - def bottleneck_block(self, input, num_filters, stride, cardinality, reduction_ratio): - conv0 = self.conv_bn_layer( - input=input, num_filters=num_filters, filter_size=1, act='relu') - conv1 = self.conv_bn_layer( - input=conv0, - num_filters=num_filters, - filter_size=3, - stride=stride, - groups=cardinality, - act='relu') - conv2 = self.conv_bn_layer( - input=conv1, num_filters=num_filters * 2, filter_size=1, act=None) - scale = self.squeeze_excitation( - input=conv2, - num_channels=num_filters * 2, - reduction_ratio=reduction_ratio) - - short = self.shortcut(input, num_filters * 2, stride) - - return fluid.layers.elementwise_add(x=short, y=scale, act='relu') - - def conv_bn_layer(self, input, num_filters, filter_size, stride=1, groups=1, - act=None): - conv = fluid.layers.conv2d( - input=input, - num_filters=num_filters, - filter_size=filter_size, - stride=stride, - padding=(filter_size - 1) // 2, - groups=groups, - act=None, - bias_attr=False) - return fluid.layers.batch_norm(input=conv, act=act) - - def squeeze_excitation(self, input, num_channels, reduction_ratio): - pool = fluid.layers.pool2d( - input=input, pool_size=0, pool_type='avg', global_pooling=True) - stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) - squeeze = fluid.layers.fc(input=pool, - size=num_channels / reduction_ratio, - act='relu', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform(-stdv, - stdv))) - stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0) - excitation = fluid.layers.fc(input=squeeze, - size=num_channels, - act='sigmoid', - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.Uniform( - -stdv, stdv))) - scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) - return scale - -def SE_ResNeXt50_32x4d(): - model = SE_ResNeXt(layers = 50) - return model - -def SE_ResNeXt101_32x4d(): - model = SE_ResNeXt(layers = 101) - return model - -def SE_ResNeXt152_32x4d(): - model = SE_ResNeXt(layers = 152) - return model diff --git a/fluid/PaddleCV/metric_learning/reader.py b/fluid/PaddleCV/metric_learning/reader.py new file mode 100644 index 00000000..9c5aaf39 --- /dev/null +++ b/fluid/PaddleCV/metric_learning/reader.py @@ -0,0 +1,175 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import math +import random +import functools +import numpy as np +import paddle +from imgtool import process_image + +random.seed(0) + +DATA_DIR = "./data/Stanford_Online_Products/" +TRAIN_LIST = './data/Stanford_Online_Products/Ebay_train.txt' +VAL_LIST = './data/Stanford_Online_Products/Ebay_test.txt' + + +def init_sop(mode): + if mode == 'train': + train_data = {} + train_image_list = [] + train_list = open(TRAIN_LIST, "r").readlines() + for i, item in enumerate(train_list): + items = item.strip().split() + if items[0] == 'image_id': + continue + path = items[3] + label = int(items[1]) - 1 + train_image_list.append((path, label)) + if label not in train_data: + train_data[label] = [] + train_data[label].append(path) + random.shuffle(train_image_list) + print("{} dataset size: {}".format(mode, len(train_data))) + return train_data, train_image_list + else: + val_data = {} + val_image_list = [] + test_image_list = [] + val_list = open(VAL_LIST, "r").readlines() + for i, item in enumerate(val_list): + items = item.strip().split() + if items[0] == 'image_id': + continue + path = items[3] + label = int(items[1]) + val_image_list.append((path, label)) + test_image_list.append(path) + if label not in val_data: + val_data[label] = [] + val_data[label].append(path) + print("{} dataset size: {}".format(mode, len(val_data))) + if mode == 'val': + return val_data, val_image_list + else: + return test_image_list + +def common_iterator(data, settings): + batch_size = settings.train_batch_size + samples_each_class = settings.samples_each_class + assert (batch_size % samples_each_class == 0) + class_num = batch_size // samples_each_class + def train_iterator(): + labs = list(data.keys()) + lab_num = len(labs) + ind = list(range(0, lab_num)) + while True: + random.shuffle(ind) + ind_sample = ind[:class_num] + for ind_i in ind_sample: + lab = labs[ind_i] + data_list = data[lab] + data_ind = list(range(0, len(data_list))) + random.shuffle(data_ind) + anchor_ind = data_ind[:samples_each_class] + + for anchor_ind_i in anchor_ind: + anchor_path = DATA_DIR + data_list[anchor_ind_i] + yield anchor_path, lab + + return train_iterator + +def triplet_iterator(data, settings): + batch_size = settings.train_batch_size + assert (batch_size % 3 == 0) + def train_iterator(): + labs = list(data.keys()) + lab_num = len(labs) + ind = list(range(0, lab_num)) + while True: + random.shuffle(ind) + ind_pos, ind_neg = ind[:2] + lab_pos = labs[ind_pos] + pos_data_list = data[lab_pos] + data_ind = list(range(0, len(pos_data_list))) + random.shuffle(data_ind) + anchor_ind, pos_ind = data_ind[:2] + + lab_neg = labs[ind_neg] + neg_data_list = data[lab_neg] + neg_ind = random.randint(0, len(neg_data_list) - 1) + + anchor_path = DATA_DIR + pos_data_list[anchor_ind] + yield anchor_path, lab_pos + pos_path = DATA_DIR + pos_data_list[pos_ind] + yield pos_path, lab_pos + neg_path = DATA_DIR + neg_data_list[neg_ind] + yield neg_path, lab_neg + + return train_iterator + +def arcmargin_iterator(data, settings): + def train_iterator(): + while True: + for items in data: + path, label = items + path = DATA_DIR + path + yield path, label + return train_iterator + +def image_iterator(data, mode): + def val_iterator(): + for items in data: + path, label = items + path = DATA_DIR + path + yield path, label + def test_iterator(): + for item in data: + path = item + path = DATA_DIR + path + yield [path] + if mode == 'val': + return val_iterator + else: + return test_iterator + +def createreader(settings, mode): + def metric_reader(): + if mode == 'train': + train_data, train_image_list = init_sop('train') + loss_name = settings.loss_name + if loss_name in ["softmax", "arcmargin"]: + return arcmargin_iterator(train_image_list, settings)() + elif loss_name == 'triplet': + return triplet_iterator(train_data, settings)() + else: + return common_iterator(train_data, settings)() + elif mode == 'val': + val_data, val_image_list = init_sop('val') + return image_iterator(val_image_list, 'val')() + else: + test_image_list = init_sop('test') + return image_iterator(test_image_list, 'test')() + + image_shape = settings.image_shape.split(',') + assert(image_shape[1] == image_shape[2]) + image_size = int(image_shape[2]) + keep_order = False if mode != 'train' or settings.loss_name in ['softmax', 'arcmargin'] else True + image_mapper = functools.partial(process_image, + mode=mode, color_jitter=False, rotate=False, crop_size=image_size) + reader = paddle.reader.xmap_readers( + image_mapper, metric_reader, 8, 1000, order=keep_order) + return reader + + +def train(settings): + return createreader(settings, "train") + +def test(settings): + return createreader(settings, "val") + +def infer(settings): + return createreader(settings, "test") diff --git a/fluid/PaddleCV/metric_learning/train.py b/fluid/PaddleCV/metric_learning/train.py deleted file mode 100644 index d3705bf5..00000000 --- a/fluid/PaddleCV/metric_learning/train.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import sys -import math -import time -import argparse -import functools -import numpy as np -import paddle -import paddle.fluid as fluid -import models -from losses import tripletloss -from losses import quadrupletloss -from losses import emlloss -from losses.metrics import recall_topk -from utility import add_arguments, print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('train_batch_size', int, 80, "Minibatch size.") -add_arg('test_batch_size', int, 10, "Minibatch size.") -add_arg('num_epochs', int, 120, "number of epochs.") -add_arg('image_shape', str, "3,224,224", "input image size") -add_arg('model_save_dir', str, "output", "model save directory") -add_arg('with_mem_opt', bool, True, - "Whether to use memory optimization or not.") -add_arg('pretrained_model', str, None, "Whether to use pretrained model.") -add_arg('checkpoint', str, None, "Whether to resume checkpoint.") -add_arg('lr', float, 0.1, "set learning rate.") -add_arg('lr_strategy', str, "piecewise_decay", - "Set the learning rate decay strategy.") -add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") -add_arg('loss_name', str, "tripletloss", "Set the loss type to use.") -add_arg('samples_each_class', int, 2, "Samples each class.") -add_arg('margin', float, 0.1, "margin.") -add_arg('alpha', float, 0.0, "alpha.") -# yapf: enable - -model_list = [m for m in dir(models) if "__" not in m] - -def optimizer_setting(params): - ls = params["learning_strategy"] - assert ls["name"] == "piecewise_decay", \ - "learning rate strategy must be {}, \ - but got {}".format("piecewise_decay", lr["name"]) - - step = 10000 - bd = [step * e for e in ls["epochs"]] - base_lr = params["lr"] - lr = [] - lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)] - optimizer = fluid.optimizer.Momentum( - learning_rate=fluid.layers.piecewise_decay( - boundaries=bd, values=lr), - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - - return optimizer - -def train(args): - # parameters from arguments - model_name = args.model - checkpoint = args.checkpoint - pretrained_model = args.pretrained_model - with_memory_optimization = args.with_mem_opt - model_save_dir = args.model_save_dir - loss_name = args.loss_name - - image_shape = [int(m) for m in args.image_shape.split(",")] - - assert model_name in model_list, "{} is not in lists: {}".format(args.model, model_list) - - image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - - # model definition - model = models.__dict__[model_name]() - out = model.net(input=image, class_dim=200) - - if loss_name == "tripletloss": - metricloss = tripletloss( - train_batch_size = args.train_batch_size, - margin=args.margin) - cost_metric = metricloss.loss(out[0]) - avg_cost_metric = fluid.layers.mean(x=cost_metric) - elif loss_name == "quadrupletloss": - metricloss = quadrupletloss( - train_batch_size = args.train_batch_size, - samples_each_class = args.samples_each_class, - margin=args.margin) - cost_metric = metricloss.loss(out[0]) - avg_cost_metric = fluid.layers.mean(x=cost_metric) - elif loss_name == "emlloss": - metricloss = emlloss( - train_batch_size = args.train_batch_size, - samples_each_class = args.samples_each_class - ) - cost_metric = metricloss.loss(out[0]) - avg_cost_metric = fluid.layers.mean(x=cost_metric) - - cost_cls = fluid.layers.cross_entropy(input=out[1], label=label) - avg_cost_cls = fluid.layers.mean(x=cost_cls) - acc_top1 = fluid.layers.accuracy(input=out[1], label=label, k=1) - acc_top5 = fluid.layers.accuracy(input=out[1], label=label, k=5) - avg_cost = avg_cost_metric + args.alpha*avg_cost_cls - - - test_program = fluid.default_main_program().clone(for_test=True) - - # parameters from model and arguments - params = model.params - params["lr"] = args.lr - params["num_epochs"] = args.num_epochs - params["learning_strategy"]["batch_size"] = args.train_batch_size - params["learning_strategy"]["name"] = args.lr_strategy - - # initialize optimizer - optimizer = optimizer_setting(params) - opts = optimizer.minimize(avg_cost) - - global_lr = optimizer._global_learning_rate() - - place = fluid.CUDAPlace(0) - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - - if checkpoint is not None: - fluid.io.load_persistables(exe, checkpoint) - - if pretrained_model: - assert(checkpoint is None) - def if_exist(var): - has_var = os.path.exists(os.path.join(pretrained_model, var.name)) - if has_var: - print('var: %s found' % (var.name)) - return has_var - fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) - - train_reader = paddle.batch(metricloss.train_reader, batch_size=args.train_batch_size) - test_reader = paddle.batch(metricloss.test_reader, batch_size=args.test_batch_size) - feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) - - train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) - - fetch_list_train = [avg_cost_metric.name, avg_cost_cls.name, acc_top1.name, acc_top5.name, global_lr.name] - fetch_list_test = [out[0].name] - - if with_memory_optimization: - fluid.memory_optimize(fluid.default_main_program(), skip_opt_set=set(fetch_list_train)) - - for pass_id in range(params["num_epochs"]): - train_info = [[], [], [], []] - for batch_id, data in enumerate(train_reader()): - t1 = time.time() - loss_metric, loss_cls, acc1, acc5, lr = train_exe.run(fetch_list_train, feed=feeder.feed(data)) - t2 = time.time() - period = t2 - t1 - loss_metric = np.mean(np.array(loss_metric)) - loss_cls = np.mean(np.array(loss_cls)) - acc1 = np.mean(np.array(acc1)) - acc5 = np.mean(np.array(acc5)) - lr = np.mean(np.array(lr)) - train_info[0].append(loss_metric) - train_info[1].append(loss_cls) - train_info[2].append(acc1) - train_info[3].append(acc5) - if batch_id % 10 == 0: - print("Pass {0}, trainbatch {1}, lr {2}, loss_metric {3}, loss_cls {4}, acc1 {5}, acc5 {6}, time {7}".format(pass_id, \ - batch_id, lr, loss_metric, loss_cls, acc1, acc5, "%2.2f sec" % period)) - - train_loss_metric = np.array(train_info[0]).mean() - train_loss_cls = np.array(train_info[1]).mean() - train_acc1 = np.array(train_info[2]).mean() - train_acc5 = np.array(train_info[3]).mean() - f = [] - l = [] - for batch_id, data in enumerate(test_reader()): - if len(data) < args.test_batch_size: - continue - t1 = time.time() - [feas] = exe.run(test_program, fetch_list = fetch_list_test, feed=feeder.feed(data)) - label = np.asarray([x[1] for x in data]) - f.append(feas) - l.append(label) - - t2 = time.time() - period = t2 - t1 - if batch_id % 20 == 0: - print("Pass {0}, testbatch {1}, time {2}".format(pass_id, \ - batch_id, "%2.2f sec" % period)) - - f = np.vstack(f) - l = np.hstack(l) - recall = recall_topk(f, l, k = 1) - print("End pass {0}, train_loss_metric {1}, train_loss_cls {2}, train_acc1 {3}, train_acc5 {4}, test_recall {5}".format(pass_id, \ - train_loss_metric, train_loss_cls, train_acc1, train_acc5, recall)) - sys.stdout.flush() - - model_path = os.path.join(model_save_dir + '/' + model_name, - str(pass_id)) - if not os.path.isdir(model_path): - os.makedirs(model_path) - fluid.io.save_persistables(exe, model_path) - -def main(): - args = parser.parse_args() - print_arguments(args) - train(args) - -if __name__ == '__main__': - main() diff --git a/fluid/PaddleCV/metric_learning/train_elem.py b/fluid/PaddleCV/metric_learning/train_elem.py new file mode 100644 index 00000000..126a1ff1 --- /dev/null +++ b/fluid/PaddleCV/metric_learning/train_elem.py @@ -0,0 +1,290 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import math +import time +import logging +import argparse +import functools +import threading +import subprocess +import numpy as np +import paddle +import paddle.fluid as fluid +import models +import reader +from losses import SoftmaxLoss +from losses import ArcMarginLoss +from utility import add_arguments, print_arguments +from utility import fmt_time, recall_topk, get_gpu_num + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('model', str, "ResNet50", "Set the network to use.") +add_arg('embedding_size', int, 0, "Embedding size.") +add_arg('train_batch_size', int, 256, "Minibatch size.") +add_arg('test_batch_size', int, 50, "Minibatch size.") +add_arg('image_shape', str, "3,224,224", "input image size") +add_arg('class_dim', int, 11318 , "Class number.") +add_arg('lr', float, 0.01, "set learning rate.") +add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") +add_arg('lr_steps', str, "30000", "step of lr") +add_arg('total_iter_num', int, 30000, "total_iter_num") +add_arg('display_iter_step', int, 10, "display_iter_step.") +add_arg('test_iter_step', int, 1000, "test_iter_step.") +add_arg('save_iter_step', int, 1000, "save_iter_step.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") +add_arg('checkpoint', str, None, "Whether to resume checkpoint.") +add_arg('model_save_dir', str, "output", "model save directory") +add_arg('loss_name', str, "softmax", "Set the loss type to use.") +add_arg('arc_scale', float, 80.0, "arc scale.") +add_arg('arc_margin', float, 0.15, "arc margin.") +add_arg('arc_easy_margin', bool, False, "arc easy margin.") +add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + +def optimizer_setting(params): + ls = params["learning_strategy"] + assert ls["name"] == "piecewise_decay", \ + "learning rate strategy must be {}, \ + but got {}".format("piecewise_decay", lr["name"]) + + bd = [int(e) for e in ls["lr_steps"].split(',')] + base_lr = params["lr"] + lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)] + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + return optimizer + + +def net_config(image, label, model, args, is_train): + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) + + out = model.net(input=image, embedding_size=args.embedding_size) + if not is_train: + return None, None, None, out + + if args.loss_name == "softmax": + metricloss = SoftmaxLoss( + class_dim=args.class_dim, + ) + elif args.loss_name == "arcmargin": + metricloss = ArcMarginLoss( + class_dim = args.class_dim, + margin = args.arc_margin, + scale = args.arc_scale, + easy_margin = args.arc_easy_margin, + ) + cost, logit = metricloss.loss(out, label) + avg_cost = fluid.layers.mean(x=cost) + acc_top1 = fluid.layers.accuracy(input=logit, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=logit, label=label, k=5) + return avg_cost, acc_top1, acc_top5, out + +def build_program(is_train, main_prog, startup_prog, args): + image_shape = [int(m) for m in args.image_shape.split(",")] + model = models.__dict__[args.model]() + with fluid.program_guard(main_prog, startup_prog): + if is_train: + queue_capacity = 64 + py_reader = fluid.layers.py_reader( + capacity=queue_capacity, + shapes=[[-1] + image_shape, [-1, 1]], + lod_levels=[0, 0], + dtypes=["float32", "int64"], + use_double_buffer=True) + image, label = fluid.layers.read_file(py_reader) + else: + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + with fluid.unique_name.guard(): + avg_cost, acc_top1, acc_top5, out = net_config(image, label, model, args, is_train) + if is_train: + params = model.params + params["lr"] = args.lr + params["learning_strategy"]["lr_steps"] = args.lr_steps + params["learning_strategy"]["name"] = args.lr_strategy + optimizer = optimizer_setting(params) + optimizer.minimize(avg_cost) + global_lr = optimizer._global_learning_rate() + """ + if not is_train: + main_prog = main_prog.clone(for_test=True) + """ + if is_train: + return py_reader, avg_cost, acc_top1, acc_top5, global_lr + else: + return out, image, label + + +def train_async(args): + # parameters from arguments + + logging.debug('enter train') + model_name = args.model + checkpoint = args.checkpoint + pretrained_model = args.pretrained_model + model_save_dir = args.model_save_dir + + startup_prog = fluid.Program() + train_prog = fluid.Program() + tmp_prog = fluid.Program() + + if args.enable_ce: + assert args.model == "ResNet50" + assert args.loss_name == "arcmargin" + np.random.seed(0) + startup_prog.random_seed = 1000 + train_prog.random_seed = 1000 + tmp_prog.random_seed = 1000 + + train_py_reader, train_cost, train_acc1, train_acc5, global_lr = build_program( + is_train=True, + main_prog=train_prog, + startup_prog=startup_prog, + args=args) + test_feas, image, label = build_program( + is_train=False, + main_prog=tmp_prog, + startup_prog=startup_prog, + args=args) + test_prog = tmp_prog.clone(for_test=True) + + train_fetch_list = [global_lr.name, train_cost.name, train_acc1.name, train_acc5.name] + test_fetch_list = [test_feas.name] + + if args.with_mem_opt: + fluid.memory_optimize(train_prog, skip_opt_set=set(train_fetch_list)) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + exe.run(startup_prog) + + logging.debug('after run startup program') + + if checkpoint is not None: + fluid.io.load_persistables(exe, checkpoint, main_program=train_prog) + + if pretrained_model: + + def if_exist(var): + return os.path.exists(os.path.join(pretrained_model, var.name)) + + fluid.io.load_vars( + exe, pretrained_model, main_program=train_prog, predicate=if_exist) + + devicenum = get_gpu_num() + assert (args.train_batch_size % devicenum) == 0 + train_batch_size = args.train_batch_size // devicenum + test_batch_size = args.test_batch_size + + train_reader = paddle.batch(reader.train(args), batch_size=train_batch_size, drop_last=True) + test_reader = paddle.batch(reader.test(args), batch_size=test_batch_size, drop_last=False) + test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) + train_py_reader.decorate_paddle_reader(train_reader) + + train_exe = fluid.ParallelExecutor( + main_program=train_prog, + use_cuda=args.use_gpu, + loss_name=train_cost.name) + + totalruntime = 0 + train_py_reader.start() + iter_no = 0 + train_info = [0, 0, 0, 0] + while iter_no <= args.total_iter_num: + t1 = time.time() + lr, loss, acc1, acc5 = train_exe.run(fetch_list=train_fetch_list) + t2 = time.time() + period = t2 - t1 + lr = np.mean(np.array(lr)) + train_info[0] += np.mean(np.array(loss)) + train_info[1] += np.mean(np.array(acc1)) + train_info[2] += np.mean(np.array(acc5)) + train_info[3] += 1 + if iter_no % args.display_iter_step == 0: + avgruntime = totalruntime / args.display_iter_step + avg_loss = train_info[0] / train_info[3] + avg_acc1 = train_info[1] / train_info[3] + avg_acc5 = train_info[2] / train_info[3] + print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\ + "acc1 %.4f, acc5 %.4f, time %2.2f sec" % \ + (fmt_time(), iter_no, lr, avg_loss, avg_acc1, avg_acc5, avgruntime)) + sys.stdout.flush() + totalruntime = 0 + if iter_no % 1000 == 0: + train_info = [0, 0, 0, 0] + + totalruntime += period + + if iter_no % args.test_iter_step == 0 and iter_no != 0: + f, l = [], [] + for batch_id, data in enumerate(test_reader()): + t1 = time.time() + [feas] = exe.run(test_prog, fetch_list = test_fetch_list, feed=test_feeder.feed(data)) + label = np.asarray([x[1] for x in data]) + f.append(feas) + l.append(label) + + t2 = time.time() + period = t2 - t1 + if batch_id % 20 == 0: + print("[%s] testbatch %d, time %2.2f sec" % \ + (fmt_time(), batch_id, period)) + + f = np.vstack(f) + l = np.hstack(l) + recall = recall_topk(f, l, k=1) + print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \ + (fmt_time(), len(f), iter_no, recall)) + sys.stdout.flush() + + if iter_no % args.save_iter_step == 0 and iter_no != 0: + model_path = os.path.join(model_save_dir + '/' + model_name, + str(iter_no)) + if not os.path.isdir(model_path): + os.makedirs(model_path) + fluid.io.save_persistables(exe, model_path, main_program=train_prog) + + iter_no += 1 + + # This is for continuous evaluation only + if args.enable_ce: + # Use the mean cost/acc for training + print("kpis train_cost %s" % (avg_loss)) + print("kpis test_recall %s" % (recall)) + + +def initlogging(): + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + loglevel = logging.DEBUG + logging.basicConfig( + level=loglevel, + # logger.BASIC_FORMAT, + format= + "%(levelname)s:%(filename)s[%(lineno)s] %(name)s:%(funcName)s->%(message)s", + datefmt='%a, %d %b %Y %H:%M:%S') + +def main(): + args = parser.parse_args() + print_arguments(args) + train_async(args) + + +if __name__ == '__main__': + main() diff --git a/fluid/PaddleCV/metric_learning/train_pair.py b/fluid/PaddleCV/metric_learning/train_pair.py new file mode 100644 index 00000000..da94ec5c --- /dev/null +++ b/fluid/PaddleCV/metric_learning/train_pair.py @@ -0,0 +1,274 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import math +import time +import logging +import argparse +import functools +import threading +import subprocess +import numpy as np +import paddle +import paddle.fluid as fluid +import models +import reader +from losses import TripletLoss +from losses import QuadrupletLoss +from losses import EmlLoss +from utility import add_arguments, print_arguments +from utility import fmt_time, recall_topk, get_gpu_num + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('model', str, "ResNet50", "Set the network to use.") +add_arg('embedding_size', int, 0, "Embedding size.") +add_arg('train_batch_size', int, 120, "Minibatch size.") +add_arg('test_batch_size', int, 50, "Minibatch size.") +add_arg('image_shape', str, "3,224,224", "input image size") +add_arg('class_dim', int, 11318, "Class number.") +add_arg('lr', float, 0.0001, "set learning rate.") +add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") +add_arg('lr_steps', str, "100000", "step of lr") +add_arg('total_iter_num', int, 100000, "total_iter_num") +add_arg('display_iter_step', int, 10, "display_iter_step.") +add_arg('test_iter_step', int, 5000, "test_iter_step.") +add_arg('save_iter_step', int, 5000, "save_iter_step.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.") +add_arg('pretrained_model', str, None, "Whether to use pretrained model.") +add_arg('checkpoint', str, None, "Whether to resume checkpoint.") +add_arg('model_save_dir', str, "output", "model save directory") +add_arg('loss_name', str, "triplet", "Set the loss type to use.") +add_arg('samples_each_class', int, 2, "samples_each_class.") +add_arg('margin', float, 0.1, "margin.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + +def optimizer_setting(params): + ls = params["learning_strategy"] + assert ls["name"] == "piecewise_decay", \ + "learning rate strategy must be {}, \ + but got {}".format("piecewise_decay", lr["name"]) + + bd = [int(e) for e in ls["lr_steps"].split(',')] + base_lr = params["lr"] + lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)] + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + return optimizer + + +def net_config(image, label, model, args, is_train): + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) + + out = model.net(input=image, embedding_size=args.embedding_size) + if not is_train: + return None, out + + if args.loss_name == "triplet": + metricloss = TripletLoss( + margin=args.margin, + ) + elif args.loss_name == "quadruplet": + metricloss = QuadrupletLoss( + train_batch_size = args.train_batch_size, + samples_each_class = args.samples_each_class, + margin=args.margin, + ) + elif args.loss_name == "eml": + metricloss = EmlLoss( + train_batch_size = args.train_batch_size, + samples_each_class = args.samples_each_class, + ) + cost = metricloss.loss(out) + avg_cost = fluid.layers.mean(x=cost) + return avg_cost, out + +def build_program(is_train, main_prog, startup_prog, args): + image_shape = [int(m) for m in args.image_shape.split(",")] + model = models.__dict__[args.model]() + with fluid.program_guard(main_prog, startup_prog): + if is_train: + queue_capacity = 64 + py_reader = fluid.layers.py_reader( + capacity=queue_capacity, + shapes=[[-1] + image_shape, [-1, 1]], + lod_levels=[0, 0], + dtypes=["float32", "int64"], + use_double_buffer=True) + image, label = fluid.layers.read_file(py_reader) + else: + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + with fluid.unique_name.guard(): + avg_cost, out = net_config(image, label, model, args, is_train) + if is_train: + params = model.params + params["lr"] = args.lr + params["learning_strategy"]["lr_steps"] = args.lr_steps + params["learning_strategy"]["name"] = args.lr_strategy + optimizer = optimizer_setting(params) + optimizer.minimize(avg_cost) + global_lr = optimizer._global_learning_rate() + """ + if not is_train: + main_prog = main_prog.clone(for_test=True) + """ + if is_train: + return py_reader, avg_cost, global_lr, out, label + else: + return out, image, label + + +def train_async(args): + # parameters from arguments + + logging.debug('enter train') + model_name = args.model + checkpoint = args.checkpoint + pretrained_model = args.pretrained_model + model_save_dir = args.model_save_dir + + startup_prog = fluid.Program() + train_prog = fluid.Program() + tmp_prog = fluid.Program() + + train_py_reader, train_cost, global_lr, train_feas, train_label = build_program( + is_train=True, + main_prog=train_prog, + startup_prog=startup_prog, + args=args) + test_feas, image, label = build_program( + is_train=False, + main_prog=tmp_prog, + startup_prog=startup_prog, + args=args) + test_prog = tmp_prog.clone(for_test=True) + + train_fetch_list = [global_lr.name, train_cost.name, train_feas.name, train_label.name] + test_fetch_list = [test_feas.name] + + if args.with_mem_opt: + fluid.memory_optimize(train_prog, skip_opt_set=set(train_fetch_list)) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + exe.run(startup_prog) + + logging.debug('after run startup program') + + if checkpoint is not None: + fluid.io.load_persistables(exe, checkpoint, main_program=train_prog) + + if pretrained_model: + + def if_exist(var): + return os.path.exists(os.path.join(pretrained_model, var.name)) + + fluid.io.load_vars( + exe, pretrained_model, main_program=train_prog, predicate=if_exist) + + devicenum = get_gpu_num() + assert (args.train_batch_size % devicenum) == 0 + train_batch_size = args.train_batch_size / devicenum + test_batch_size = args.test_batch_size + + train_reader = paddle.batch(reader.train(args), batch_size=train_batch_size, drop_last=True) + test_reader = paddle.batch(reader.test(args), batch_size=test_batch_size, drop_last=False) + test_feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) + train_py_reader.decorate_paddle_reader(train_reader) + + train_exe = fluid.ParallelExecutor( + main_program=train_prog, + use_cuda=args.use_gpu, + loss_name=train_cost.name) + + totalruntime = 0 + train_py_reader.start() + iter_no = 0 + train_info = [0, 0, 0] + while iter_no <= args.total_iter_num: + t1 = time.time() + lr, loss, feas, label = train_exe.run(fetch_list=train_fetch_list) + t2 = time.time() + period = t2 - t1 + lr = np.mean(np.array(lr)) + train_info[0] += np.mean(np.array(loss)) + train_info[1] += recall_topk(feas, label, k=1) + train_info[2] += 1 + if iter_no % args.display_iter_step == 0: + avgruntime = totalruntime / args.display_iter_step + avg_loss = train_info[0] / train_info[2] + avg_recall = train_info[1] / train_info[2] + print("[%s] trainbatch %d, lr %.6f, loss %.6f, "\ + "recall %.4f, time %2.2f sec" % \ + (fmt_time(), iter_no, lr, avg_loss, avg_recall, avgruntime)) + sys.stdout.flush() + totalruntime = 0 + if iter_no % 1000 == 0: + train_info = [0, 0, 0] + + totalruntime += period + + if iter_no % args.test_iter_step == 0 and iter_no != 0: + f, l = [], [] + for batch_id, data in enumerate(test_reader()): + t1 = time.time() + [feas] = exe.run(test_prog, fetch_list = test_fetch_list, feed=test_feeder.feed(data)) + label = np.asarray([x[1] for x in data]) + f.append(feas) + l.append(label) + + t2 = time.time() + period = t2 - t1 + if batch_id % 20 == 0: + print("[%s] testbatch %d, time %2.2f sec" % \ + (fmt_time(), batch_id, period)) + + f = np.vstack(f) + l = np.hstack(l) + recall = recall_topk(f, l, k=1) + print("[%s] test_img_num %d, trainbatch %d, test_recall %.5f" % \ + (fmt_time(), len(f), iter_no, recall)) + sys.stdout.flush() + + if iter_no % args.save_iter_step == 0 and iter_no != 0: + model_path = os.path.join(model_save_dir + '/' + model_name, + str(iter_no)) + if not os.path.isdir(model_path): + os.makedirs(model_path) + fluid.io.save_persistables(exe, model_path, main_program=train_prog) + + iter_no += 1 + +def initlogging(): + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + loglevel = logging.DEBUG + logging.basicConfig( + level=loglevel, + # logger.BASIC_FORMAT, + format= + "%(levelname)s:%(filename)s[%(lineno)s] %(name)s:%(funcName)s->%(message)s", + datefmt='%a, %d %b %Y %H:%M:%S') + +def main(): + args = parser.parse_args() + print_arguments(args) + train_async(args) + + +if __name__ == '__main__': + main() diff --git a/fluid/PaddleCV/metric_learning/utility.py b/fluid/PaddleCV/metric_learning/utility.py index d8fe4164..e3a109ca 100644 --- a/fluid/PaddleCV/metric_learning/utility.py +++ b/fluid/PaddleCV/metric_learning/utility.py @@ -16,9 +16,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import distutils.util + +import os import six +import time +import subprocess +import distutils.util import numpy as np + from paddle.fluid import core @@ -61,3 +66,38 @@ def add_arguments(argname, type, default, help, argparser, **kwargs): type=type, help=help + ' Default: %(default)s.', **kwargs) + +def fmt_time(): + """ get formatted time for now + """ + now_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) + return now_str + +def recall_topk(fea, lab, k = 1): + fea = np.array(fea) + fea = fea.reshape(fea.shape[0], -1) + n = np.sqrt(np.sum(fea**2, 1)).reshape(-1, 1) + fea = fea / n + a = np.sum(fea ** 2, 1).reshape(-1, 1) + b = a.T + ab = np.dot(fea, fea.T) + d = a + b - 2*ab + d = d + np.eye(len(fea)) * 1e8 + sorted_index = np.argsort(d, 1) + res = 0 + for i in range(len(fea)): + pred = lab[sorted_index[i][0]] + if lab[i] == pred: + res += 1.0 + res = res / len(fea) + return res + +def get_gpu_num(): + visibledevice = os.getenv('CUDA_VISIBLE_DEVICES') + if visibledevice: + devicenum = len(visibledevice.split(',')) + else: + devicenum = subprocess.check_output( + [str.encode('nvidia-smi'), str.encode('-L')]).decode('utf-8').count('\n') + return devicenum + -- GitLab