Commit d65c9edf authored by kbChen, committed by qingqing01

fix metric learning (#1276)

* fix metric learning
* fix review
Parent 58002048
wget http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz
tar zxf images.tgz
find $PWD/images|grep jpg|grep -v "\._" > list.txt
find images|grep jpg|grep -v "\._" > list.txt
python split.py
rm -rf images.tgz list.txt
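
The split.py script itself is not shown in this diff. A minimal sketch of what it could look like is below, assuming that list.txt holds relative paths such as "images/001.Black_footed_Albatross/xxx.jpg", that the 1-based label comes from the numeric prefix of the class directory, and that the conventional metric-learning split (classes 1-100 for training, 101-200 for validation) is used; the real script and output locations may differ.

    # Illustrative sketch only -- the real split.py is not part of this diff.
    with open("list.txt") as f:
        lines = [line.strip() for line in f if line.strip()]

    train, val = [], []
    for path in lines:
        class_id = int(path.split("/")[-2].split(".")[0])   # "001.Black_footed_Albatross" -> 1
        record = "%s %d\n" % (path, class_id)               # datareader later does int(label) - 1
        (train if class_id <= 100 else val).append(record)  # assumed first-100 / last-100 split

    open("CUB200_train.txt", "w").writelines(train)
    open("CUB200_val.txt", "w").writelines(val)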
......@@ -49,14 +49,13 @@ def eval(args):
out = model.net(input=image, class_dim=200)
if loss_name == "tripletloss":
metricloss = tripletloss(test_batch_size=args.batch_size, margin=0.1)
metricloss = tripletloss()
cost = metricloss.loss(out[0])
elif loss_name == "quadrupletloss":
metricloss = quadrupletloss(test_batch_size=args.batch_size)
metricloss = quadrupletloss()
cost = metricloss.loss(out[0])
elif loss_name == "emlloss":
metricloss = emlloss(
test_batch_size=args.batch_size, samples_each_class=2, fea_dim=2048)
metricloss = emlloss()
cost = metricloss.loss(out[0])
avg_cost = fluid.layers.mean(x=cost)
......@@ -76,7 +75,7 @@ def eval(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
test_reader = metricloss.test_reader
test_reader = paddle.batch(metricloss.test_reader, batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
fetch_list = [avg_cost.name, out[0].name]
......@@ -84,11 +83,14 @@ def eval(args):
test_info = [[]]
f = []
l = []
for batch_id, (data, label) in enumerate(test_reader()):
for batch_id, data in enumerate(test_reader()):
if len(data) < args.batch_size:
continue
t1 = time.time()
loss, feas = exe.run(test_program,
fetch_list=fetch_list,
feed=feeder.feed(data))
label = np.asarray([x[1] for x in data])
f.append(feas)
l.append(label)
......@@ -96,15 +98,13 @@ def eval(args):
period = t2 - t1
loss = np.mean(np.array(loss))
test_info[0].append(loss)
if batch_id % 10 == 0:
recall = recall_topk(feas, label, k=1)
print("testbatch {0}, loss {1}, recall {2}, time {3}".format( \
batch_id, loss, recall, "%2.2f sec" % period))
sys.stdout.flush()
if batch_id % 20 == 0:
print("testbatch {0}, loss {1}, time {2}".format( \
batch_id, loss, "%2.2f sec" % period))
test_loss = np.array(test_info[0]).mean()
f = np.vstack(f)
l = np.vstack(l)
l = np.hstack(l)
recall = recall_topk(f, l, k=1)
print("End test, test_loss {0}, test recall {1}".format( \
test_loss, recall))
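
The switch from np.vstack to np.hstack for the labels follows from the new reader: feas from each batch is a [batch_size, fea_dim] array, while label is now a flat [batch_size] vector. A quick NumPy illustration (shapes are assumed example values):

    import numpy as np
    feas_batches  = [np.zeros((50, 2048)), np.zeros((50, 2048))]   # per-batch feature matrices
    label_batches = [np.zeros(50), np.zeros(50)]                   # per-batch 1-D label vectors
    f = np.vstack(feas_batches)    # (100, 2048) -- stack feature rows
    l = np.hstack(label_batches)   # (100,)      -- concatenate flat label vectors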
......
......@@ -56,7 +56,7 @@ def infer(args):
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
infer_reader = tripletloss(infer_batch_size=args.batch_size).infer_reader
infer_reader = paddle.batch(tripletloss().infer_reader, batch_size=args.batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image])
fetch_list = [out[0].name]
......
......@@ -18,15 +18,17 @@ BUF_SIZE = 1024000
DATA_DIR = "./data/"
TRAIN_LIST = './data/CUB200_train.txt'
TEST_LIST = './data/CUB200_val.txt'
#DATA_DIR = "./thirdparty/paddlemodels/metric_learning/data/"
#TRAIN_LIST = './thirdparty/paddlemodels/metric_learning/data/CUB200_train.txt'
#TEST_LIST = './thirdparty/paddlemodels/metric_learning/data/CUB200_val.txt'
#DATA_DIR = "./data/CUB200/"
#TRAIN_LIST = './data/CUB200/CUB200_train.txt'
#TEST_LIST = './data/CUB200/CUB200_val.txt'
train_data = {}
test_data = {}
train_list = open(TRAIN_LIST, "r").readlines()
train_image_list = []
for i, item in enumerate(train_list):
path, label = item.strip().split()
label = int(label) - 1
train_image_list.append((path, label))
if label not in train_data:
train_data[label] = []
train_data[label].append(path)
......@@ -53,7 +55,6 @@ img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
......@@ -198,25 +199,24 @@ def process_image_imagepath(sample, mode, color_jitter, rotate):
img -= img_mean
img /= img_std
if mode == 'train' or mode == 'test':
if mode in ['train', 'test']:
return img, sample[1]
elif mode == 'infer':
return img
return [img]
def eml_iterator(data,
mode,
batch_size,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
mode,
batch_size,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
counter = np.zeros(lab_num)
ind = range(0, lab_num)
batchdata = []
assert batch_size % samples_each_class == 0, "batch_size % samples_each_class != 0"
num_class = batch_size/samples_each_class
for i in range(iter_size):
......@@ -228,31 +228,26 @@ def eml_iterator(data,
random.shuffle(data_list)
for s in range(samples_each_class):
path = DATA_DIR + data_list[s]
# print("path:", path)
img, _ = process_image_imagepath(sample = [path, label], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img, label])
#print("batch_size:", len(batchdata))
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
return reader
yield path, label
mapper = functools.partial(
process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True)
def quadruplet_iterator(data,
mode,
class_num,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
mode,
class_num,
samples_each_class,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
ind = range(0, lab_num)
batchdata = []
for i in range(iter_size):
random.shuffle(ind)
ind_sample = ind[:class_num]
......@@ -266,28 +261,25 @@ def quadruplet_iterator(data,
for anchor_ind_i in anchor_ind:
anchor_path = DATA_DIR + data_list[anchor_ind_i]
anchor_img, _ = process_image_imagepath(sample = [anchor_path, lab], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([anchor_img, lab])
yield batchdata
batchdata = []
return reader
yield anchor_path, lab
mapper = functools.partial(
process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True)
def triplet_iterator(data,
mode,
batch_size,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
mode,
batch_size,
iter_size,
shuffle=False,
color_jitter=False,
rotate=False):
def reader():
labs = data.keys()
lab_num = len(labs)
counter = np.zeros(lab_num)
ind = range(0, lab_num)
batchdata = []
for i in range(iter_size):
random.shuffle(ind)
ind_pos, ind_neg = ind[:2]
......@@ -302,92 +294,63 @@ def triplet_iterator(data,
neg_ind = random.randint(0, len(neg_data_list) - 1)
anchor_path = DATA_DIR + pos_data_list[anchor_ind]
anchor_img, _ = process_image_imagepath(sample = [anchor_path, lab_pos], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
yield anchor_path, lab_pos
pos_path = DATA_DIR + pos_data_list[pos_ind]
pos_img, _ = process_image_imagepath(sample = [pos_path, lab_pos], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
yield pos_path, lab_pos
neg_path = DATA_DIR + neg_data_list[neg_ind]
neg_img, _ = process_image_imagepath(sample = [neg_path, lab_neg], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([anchor_img, lab_pos])
batchdata.append([pos_img, lab_pos])
batchdata.append([neg_img, lab_neg])
#print("batchdata:", len(batchdata))
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
if batchdata:
yield batchdata
return reader
yield neg_path, lab_neg
mapper = functools.partial(
process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE, order=True)
def image_iterator(data,
mode,
batch_size,
shuffle=False,
color_jitter=False,
rotate=False):
def infer_reader():
batchdata = []
mode,
shuffle=False,
color_jitter=False,
rotate=False):
def test_reader():
for i in range(len(data)):
path = data[i]
path, label = data[i]
path = DATA_DIR + path
img = process_image_imagepath(sample = [path], \
mode = "infer", \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img])
if len(batchdata) == batch_size:
yield batchdata
batchdata = []
def reader():
batchdata = []
batchlabel = []
yield path, label
def infer_reader():
for i in range(len(data)):
path, label = data[i]
path = data[i]
path = DATA_DIR + path
img, label = process_image_imagepath(sample = [path, label], \
mode = mode, \
color_jitter = color_jitter, \
rotate = rotate)
batchdata.append([img, label])
batchlabel.append([label])
if len(batchdata) == batch_size:
yield batchdata, batchlabel
batchdata = []
batchlabel = []
if mode in ["train", "test"]:
return reader
else:
return infer_reader
yield [path]
if mode == "test":
mapper = functools.partial(
process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, test_reader, THREAD, BUF_SIZE)
elif mode == "infer":
mapper = functools.partial(
process_image_imagepath, mode=mode, color_jitter=color_jitter, rotate=rotate)
return paddle.reader.xmap_readers(mapper, infer_reader, THREAD, BUF_SIZE)
def eml_train(batch_size, samples_each_class):
return eml_iterator(train_data, 'train', batch_size, samples_each_class, iter_size = 100, \
shuffle=True, color_jitter=False, rotate=False)
def quadruplet_train(batch_size, class_num, samples_each_class):
print('train batch size ', batch_size)
def quadruplet_train(class_num, samples_each_class):
return quadruplet_iterator(train_data, 'train', class_num, samples_each_class, iter_size=100, \
shuffle=True, color_jitter=False, rotate=False)
def triplet_train(batch_size):
assert(batch_size % 3 == 0)
print('train batch size ', batch_size)
return triplet_iterator(train_data, 'train', batch_size, iter_size = batch_size/3 * 100, \
shuffle=True, color_jitter=False, rotate=False)
def test(batch_size):
print('test batch size ', batch_size)
return image_iterator(test_image_list, "test", batch_size, shuffle=False)
def test():
return image_iterator(test_image_list, "test", shuffle=False)
def infer(batch_size):
print('inference batch size ', batch_size)
return image_iterator(infer_image_list, "infer", batch_size, shuffle=False)
def infer():
return image_iterator(infer_image_list, "infer", shuffle=False)
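
The net effect of this datareader refactor: every reader now yields one undecoded sample at a time, image decoding runs in THREAD xmap worker threads via paddle.reader.xmap_readers, and batching moves out into the driver scripts via paddle.batch. A hedged usage sketch (not part of the commit; the import path for the loss class and the presence of the data lists are assumptions):

    import numpy as np
    import paddle
    from losses import tripletloss   # assumed import path

    metricloss = tripletloss()
    test_reader = paddle.batch(metricloss.test_reader, batch_size=10)
    for batch in test_reader():
        label = np.asarray([x[1] for x in batch])   # (img, label) pairs, as eval.py does
        break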
......@@ -2,102 +2,63 @@ import datareader as reader
import math
import numpy as np
import paddle.fluid as fluid
from metrics import calculate_order_dist_matrix
from metrics import get_gpu_num
class emlloss():
def __init__(self, train_batch_size = 50, test_batch_size = 50, infer_batch_size = 50, samples_each_class=2, fea_dim=2048):
self.train_reader = reader.eml_train(train_batch_size, samples_each_class)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
def __init__(self, train_batch_size = 40, samples_each_class=2):
num_gpus = get_gpu_num()
self.samples_each_class = samples_each_class
self.fea_dim = fea_dim
self.batch_size = train_batch_size
def surrogate_function(self, input):
beta = 100000
output = fluid.layers.log(1+beta*input)/math.log(1+beta)
self.train_batch_size = train_batch_size
assert(train_batch_size % num_gpus == 0)
self.cal_loss_batch_size = train_batch_size / num_gpus
assert(self.cal_loss_batch_size % samples_each_class == 0)
class_num = train_batch_size / samples_each_class
self.train_reader = reader.eml_train(train_batch_size, samples_each_class)
self.test_reader = reader.test()
def surrogate_function(self, beta, theta, bias):
x = theta * fluid.layers.exp(bias)
output = fluid.layers.log(1+beta*x)/math.log(1+beta)
return output
def surrogate_function_approximate(self, input, bias):
beta = 100000
output = (fluid.layers.log(beta*input)+bias)/math.log(1+beta)
def surrogate_function_approximate(self, beta, theta, bias):
output = (fluid.layers.log(theta) + bias + math.log(beta))/math.log(1+beta)
return output
def generate_index(self, batch_size, samples_each_class):
a = np.arange(0, batch_size*batch_size) # N*N x 1
a = a.reshape(-1,batch_size) # N x N
steps = batch_size//samples_each_class
res = []
for i in range(batch_size):
step = i // samples_each_class
start = step * samples_each_class
end = (step + 1) * samples_each_class
p = []
n = []
for j, k in enumerate(a[i]):
if j >= start and j < end:
p.append(k)
else:
n.append(k)
comb = p + n
res += comb
res = np.array(res).astype(np.int32)
return res
def surrogate_function_stable(self, beta, theta, target, thresh):
max_gap = fluid.layers.fill_constant([1], dtype='float32', value=thresh)
max_gap.stop_gradient = True
target_max = fluid.layers.elementwise_max(target, max_gap)
target_min = fluid.layers.elementwise_min(target, max_gap)
loss1 = self.surrogate_function(beta, theta, target_min)
loss2 = self.surrogate_function_approximate(beta, theta, target_max)
bias = self.surrogate_function(beta, theta, max_gap)
loss = loss1 + loss2 - bias
return loss
def loss(self, input):
samples_each_class = self.samples_each_class
fea_dim = self.fea_dim
batch_size = self.batch_size
feature1 = fluid.layers.reshape(input, shape = [-1, fea_dim])
feature2 = fluid.layers.transpose(feature1, perm = [1,0])
ab = fluid.layers.mul(x=feature1, y=feature2)
a2 = fluid.layers.square(feature1)
a2 = fluid.layers.reduce_sum(a2, dim = 1)
a2 = fluid.layers.reshape(a2, shape = [-1])
b2 = fluid.layers.square(feature2)
b2 = fluid.layers.reduce_sum(b2, dim = 0)
b2 = fluid.layers.reshape(b2, shape = [-1])
d = fluid.layers.elementwise_add(-2*ab, a2, axis = 0)
d = fluid.layers.elementwise_add(d, b2, axis = 1)
d = fluid.layers.reshape(d, shape=[-1, 1])
index = self.generate_index(batch_size, samples_each_class)
index_var = fluid.layers.create_global_var(
shape=[batch_size*batch_size], value=0, dtype='int32', persistable=True)
index_var = fluid.layers.assign(index, index_var)
index_var.stop_gradient = True
d = fluid.layers.gather(d, index=index_var)
d = fluid.layers.reshape(d, shape=[-1, batch_size])
pos, neg = fluid.layers.split(d,
num_or_sections= [samples_each_class,batch_size-samples_each_class],
dim=1)
batch_size = self.cal_loss_batch_size
d = calculate_order_dist_matrix(input, self.cal_loss_batch_size, self.samples_each_class)
ignore, pos, neg = fluid.layers.split(d, num_or_sections= [1,
samples_each_class-1, batch_size-samples_each_class], dim=1)
ignore.stop_gradient = True
pos_max = fluid.layers.reduce_max(pos, dim=1)
pos_max = fluid.layers.reshape(pos_max, shape=[-1, 1])
pos = fluid.layers.exp(pos-pos_max)
pos = fluid.layers.exp(pos - pos_max)
pos_mean = fluid.layers.reduce_mean(pos, dim=1)
neg_min = fluid.layers.reduce_min(neg, dim=1)
neg_min = fluid.layers.reshape(neg_min, shape=[-1, 1])
neg = fluid.layers.exp(-1*(neg-neg_min))
neg_mean = fluid.layers.reduce_mean(neg, dim=1)
bias = pos_max - neg_min
theta = fluid.layers.reshape(neg_mean * pos_mean, shape=[-1,1])
max_gap = fluid.layers.fill_constant([1], dtype='float32', value=20.0)
max_gap.stop_gradient = True
target = pos_max - neg_min
target_max = fluid.layers.elementwise_max(target, max_gap)
target_min = fluid.layers.elementwise_min(target, max_gap)
expadj_min = fluid.layers.exp(target_min)
loss1 = self.surrogate_function(theta*expadj_min)
loss2 = self.surrogate_function_approximate(theta, target_max)
bias = fluid.layers.exp(max_gap)
bias = self.surrogate_function(theta*bias)
loss = loss1 + loss2 - bias
thresh = 20.0
beta = 100000
loss = self.surrogate_function_stable(beta, theta, bias, thresh)
return loss
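
To see why surrogate_function_stable avoids overflow, here is a rough NumPy check (illustrative only, example values assumed): below thresh the exact term log(1 + beta*theta*e^bias)/log(1 + beta) is used; above thresh, beta*theta*e^bias is so large that log(1 + x) is effectively log(x), so the approximate branch (log(theta) + bias + log(beta))/log(1 + beta) matches it while never exponentiating a large argument, and subtracting the value at thresh makes the two branches add up to the exact surrogate in both regimes.

    import numpy as np
    beta, theta, thresh = 1e5, 0.5, 20.0

    def exact(b):     # log(1 + beta*theta*e^b) / log(1 + beta)
        return np.log1p(beta * theta * np.exp(b)) / np.log1p(beta)

    def approx(b):    # valid when beta*theta*e^b >> 1; never calls exp(b)
        return (np.log(theta) + b + np.log(beta)) / np.log1p(beta)

    def stable(b):
        return exact(min(b, thresh)) + approx(max(b, thresh)) - exact(thresh)

    for b in [-5.0, 5.0, 19.0, 25.0]:
        print(b, exact(b), stable(b))   # the two columns agree closely
    # exact(800.0) overflows to inf inside exp(); stable(800.0) stays finite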
......@@ -17,3 +17,60 @@ def recall_topk(fea, lab, k = 1):
res += 1.0
res = res/len(fea)
return res
import subprocess
import os
def get_gpu_num():
visibledevice = os.getenv('CUDA_VISIBLE_DEVICES')
if visibledevice:
devicenum = len(visibledevice.split(','))
else:
devicenum = subprocess.check_output(['nvidia-smi', '-L']).count('\n')
return devicenum
import paddle as paddle
import paddle.fluid as fluid
def generate_index(batch_size, samples_each_class):
a = np.arange(0, batch_size * batch_size)
a = a.reshape(-1, batch_size)
steps = batch_size // samples_each_class
res = []
for i in range(batch_size):
step = i // samples_each_class
start = step * samples_each_class
end = (step + 1) * samples_each_class
p = []
n = []
for j, k in enumerate(a[i]):
if j >= start and j < end:
if j == i:
p.insert(0, k)
else:
p.append(k)
else:
n.append(k)
comb = p + n
res += comb
res = np.array(res).astype(np.int32)
return res
def calculate_order_dist_matrix(feature, batch_size, samples_each_class):
assert(batch_size % samples_each_class == 0)
feature = fluid.layers.reshape(feature, shape=[batch_size, -1])
ab = fluid.layers.matmul(feature, feature, False, True)
a2 = fluid.layers.square(feature)
a2 = fluid.layers.reduce_sum(a2, dim = 1)
d = fluid.layers.elementwise_add(-2*ab, a2, axis = 0)
d = fluid.layers.elementwise_add(d, a2, axis = 1)
d = fluid.layers.reshape(d, shape = [-1, 1])
index = generate_index(batch_size, samples_each_class)
index_var = fluid.layers.create_global_var(shape=[batch_size*batch_size], value=0, dtype='int32', persistable=True)
index_var = fluid.layers.assign(index, index_var)
d = fluid.layers.gather(d, index=index_var)
d = fluid.layers.reshape(d, shape=[-1, batch_size])
return d
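
As a plain-NumPy illustration (not part of the commit; generate_index_np is a hypothetical helper mirroring generate_index, with batch_size=4, samples_each_class=2 assumed): every row of the reordered distance matrix starts with the distance of a sample to itself, then to the other same-class samples, then to everything else, which is exactly what the [1, samples_each_class-1, batch_size-samples_each_class] split in the losses relies on.

    import numpy as np

    def generate_index_np(batch_size, samples_each_class):
        # mirrors metrics.generate_index: per row, self first, then same-class, then the rest
        a = np.arange(batch_size * batch_size).reshape(-1, batch_size)
        res = []
        for i in range(batch_size):
            start = (i // samples_each_class) * samples_each_class
            end = start + samples_each_class
            p, n = [], []
            for j, k in enumerate(a[i]):
                if start <= j < end:
                    if j == i:
                        p.insert(0, k)
                    else:
                        p.append(k)
                else:
                    n.append(k)
            res += p + n
        return np.array(res, dtype=np.int32)

    feature = np.random.rand(4, 8).astype(np.float32)            # batch_size=4, fea_dim=8
    sq = (feature ** 2).sum(axis=1)
    d = -2 * feature.dot(feature.T) + sq[:, None] + sq[None, :]  # d[i, j] = ||f_i - f_j||^2
    d = d.reshape(-1)[generate_index_np(4, 2)].reshape(-1, 4)
    ignore, pos, neg = d[:, :1], d[:, 1:2], d[:, 2:]
    # ignore ~= 0 (self-distance), pos = same-class distances, neg = other-class distances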
import numpy as np
import datareader as reader
import paddle.fluid as fluid
from metrics import calculate_order_dist_matrix
from metrics import get_gpu_num
class quadrupletloss():
def __init__(self,
train_batch_size = 80,
test_batch_size = 10,
infer_batch_size = 10,
samples_each_class = 2,
num_gpus=8,
margin=0.1):
self.margin = margin
self.num_gpus = num_gpus
num_gpus = get_gpu_num()
self.samples_each_class = samples_each_class
self.train_batch_size = train_batch_size
assert(train_batch_size % (samples_each_class*num_gpus) == 0)
self.class_num = train_batch_size / self.samples_each_class
self.train_reader = reader.quadruplet_train(train_batch_size, self.class_num, self.samples_each_class)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
assert(train_batch_size % num_gpus == 0)
self.cal_loss_batch_size = train_batch_size / num_gpus
assert(self.cal_loss_batch_size % samples_each_class == 0)
class_num = train_batch_size / samples_each_class
self.train_reader = reader.quadruplet_train(class_num, samples_each_class)
self.test_reader = reader.test()
def loss(self, input):
feature = fluid.layers.l2_normalize(input, axis=1)
samples_each_class = self.samples_each_class
batch_size = self.cal_loss_batch_size
margin = self.margin
batch_size = self.train_batch_size / self.num_gpus
fea_dim = input.shape[1] # number of channels
output = fluid.layers.reshape(input, shape = [-1, fea_dim])
output = fluid.layers.l2_normalize(output, axis=1)
#scores = fluid.layers.matmul(output, output, transpose_x=False, transpose_y=True)
output_t = fluid.layers.transpose(output, perm = [1, 0])
scores = fluid.layers.mul(x=output, y=output_t)
mask_np = np.zeros((batch_size, batch_size), dtype=np.float32)
for i in xrange(batch_size):
for j in xrange(batch_size):
if i / self.samples_each_class == j / self.samples_each_class:
mask_np[i, j] = 100.
#mask = fluid.layers.create_tensor(dtype='float32')
mask = fluid.layers.create_global_var(
shape=[batch_size, batch_size], value=0, dtype='float32', persistable=True)
fluid.layers.assign(mask_np, mask)
scores = fluid.layers.scale(x=scores, scale=-1.0) + mask
scores_max = fluid.layers.reduce_max(scores, dim=0, keep_dim=True)
ind_max = fluid.layers.argmax(scores, axis=0)
ind_max = fluid.layers.cast(x=ind_max, dtype='float32')
ind2 = fluid.layers.argmax(scores_max, axis=1)
ind2 = fluid.layers.cast(x=ind2, dtype='int32')
ind1 = fluid.layers.gather(ind_max, ind2)
ind1 = fluid.layers.cast(x=ind1, dtype='int32')
scores_min = fluid.layers.reduce_min(scores, dim=0, keep_dim=True)
ind_min = fluid.layers.argmin(scores, axis=0)
ind_min = fluid.layers.cast(x=ind_min, dtype='float32')
ind4 = fluid.layers.argmin(scores_min, axis=1)
ind4 = fluid.layers.cast(x=ind4, dtype='int32')
ind3 = fluid.layers.gather(ind_min, ind4)
ind3 = fluid.layers.cast(x=ind3, dtype='int32')
f1 = fluid.layers.gather(output, ind1)
f2 = fluid.layers.gather(output, ind2)
f3 = fluid.layers.gather(output, ind3)
f4 = fluid.layers.gather(output, ind4)
ind1.stop_gradient = True
ind2.stop_gradient = True
ind3.stop_gradient = True
ind4.stop_gradient = True
ind_max.stop_gradient = True
ind_min.stop_gradient = True
scores_max.stop_gradient = True
scores_min.stop_gradient = True
scores.stop_gradient = True
mask.stop_gradient = True
output_t.stop_gradient = True
f1_2 = fluid.layers.square(f1 - f2)
f3_4 = fluid.layers.square(f3 - f4)
s1 = fluid.layers.reduce_sum(f1_2, dim = 1)
s2 = fluid.layers.reduce_sum(f3_4, dim = 1)
s1 = fluid.layers.sqrt(s1)
s2 = fluid.layers.sqrt(s2)
loss = fluid.layers.relu(s1 - s2 + margin)
d = calculate_order_dist_matrix(feature, self.cal_loss_batch_size, self.samples_each_class)
ignore, pos, neg = fluid.layers.split(d, num_or_sections= [1,
samples_each_class-1, batch_size-samples_each_class], dim=1)
ignore.stop_gradient = True
pos_max = fluid.layers.reduce_max(pos)
neg_min = fluid.layers.reduce_min(neg)
pos_max = fluid.layers.sqrt(pos_max)
neg_min = fluid.layers.sqrt(neg_min)
loss = fluid.layers.relu(pos_max - neg_min + margin)
return loss
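
A rough NumPy rendering of the rewritten quadruplet loss (illustrative only; it reuses generate_index_np from the sketch after the metrics.py hunk, and batch_size/margin are example values): after L2-normalisation, only the hardest positive pair (largest same-class distance) and hardest negative pair (smallest cross-class distance) in the sub-batch contribute, through a hinge on their gap.

    import numpy as np
    margin, batch_size, samples_each_class = 0.1, 4, 2

    feature = np.random.rand(batch_size, 8).astype(np.float32)
    feature /= np.linalg.norm(feature, axis=1, keepdims=True)        # l2_normalize
    sq = (feature ** 2).sum(axis=1)
    d = -2 * feature.dot(feature.T) + sq[:, None] + sq[None, :]      # squared distances
    d = d.reshape(-1)[generate_index_np(batch_size, samples_each_class)].reshape(-1, batch_size)
    pos, neg = d[:, 1:samples_each_class], d[:, samples_each_class:]
    loss = max(0.0, np.sqrt(pos.max()) - np.sqrt(neg.min()) + margin)  # relu(pos_max - neg_min + margin)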
......@@ -2,14 +2,10 @@ import datareader as reader
import paddle.fluid as fluid
class tripletloss():
def __init__(self,
train_batch_size = 120,
test_batch_size = 120,
infer_batch_size = 120,
margin=0.1):
def __init__(self, train_batch_size = 120, margin=0.1):
self.train_reader = reader.triplet_train(train_batch_size)
self.test_reader = reader.test(test_batch_size)
self.infer_reader = reader.infer(infer_batch_size)
self.test_reader = reader.test()
self.infer_reader = reader.infer()
self.margin = margin
def loss(self, input):
......
......@@ -17,39 +17,38 @@ from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('train_batch_size', int, 80, "Minibatch size.")
add_arg('test_batch_size', int, 10, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('num_epochs', int, 120, "Number of epochs.")
add_arg('image_shape', str, "3,224,224", "Input image size.")
add_arg('model_save_dir', str, "output", "Model save directory")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.1, "Set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('loss_name', str, "tripletloss", "Set the loss type to use.")
add_arg('samples_each_class', int, 2, "Samples each class.")
add_arg('num_gpus', int, 8, "Number of gpus.")
add_arg('margin', float, 0.1, "Parameter margin.")
add_arg('alpha', float, 0.0, "Parameter alpha.")
add_arg('train_batch_size', int, 80, "Minibatch size.")
add_arg('test_batch_size', int, 10, "Minibatch size.")
add_arg('num_epochs', int, 120, "number of epochs.")
add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('model_save_dir', str, "output", "model save directory")
add_arg('with_mem_opt', bool, True,
"Whether to use memory optimization or not.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('checkpoint', str, None, "Whether to resume checkpoint.")
add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay",
"Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('loss_name', str, "tripletloss", "Set the loss type to use.")
add_arg('samples_each_class', int, 2, "Samples each class.")
add_arg('margin', float, 0.1, "margin.")
add_arg('alpha', float, 0.0, "alpha.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def optimizer_setting(params):
ls = params["learning_strategy"]
assert ls[
"name"] == "piecewise_decay", "learning rate strategy must be {}, but got {}".format(
"piecewise_decay", lr["name"])
assert ls["name"] == "piecewise_decay", \
"learning rate strategy must be {}, \
but got {}".format("piecewise_decay", lr["name"])
step = 10000
bd = [step * e for e in ls["epochs"]]
base_lr = params["lr"]
lr = []
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
......@@ -58,7 +57,6 @@ def optimizer_setting(params):
return optimizer
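
To make the schedule concrete (the epoch boundaries and base learning rate below are assumed example values, not taken from the commit):

    step, epochs, base_lr = 10000, [30, 60, 90], 0.1
    bd = [step * e for e in epochs]                          # [300000, 600000, 900000]
    lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]  # [0.1, 0.01, 0.001, 0.0001]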
def train(args):
# parameters from arguments
model_name = args.model
......@@ -70,54 +68,47 @@ def train(args):
image_shape = [int(m) for m in args.image_shape.split(",")]
assert model_name in model_list, "{} is not in lists: {}".format(args.model,
model_list)
assert model_name in model_list, "{} is not in lists: {}".format(args.model, model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[model_name]()
out = model.net(input=image, class_dim=200)
if loss_name == "tripletloss":
metricloss = tripletloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
margin=0.1)
train_batch_size = args.train_batch_size,
margin=args.margin)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
elif loss_name == "quadrupletloss":
metricloss = quadrupletloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
samples_each_class=args.samples_each_class,
num_gpus=args.num_gpus,
margin=args.margin)
train_batch_size = args.train_batch_size,
samples_each_class = args.samples_each_class,
margin=args.margin)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
elif loss_name == "emlloss":
metricloss = emlloss(
train_batch_size=args.train_batch_size,
test_batch_size=args.test_batch_size,
samples_each_class=2,
fea_dim=2048)
train_batch_size = args.train_batch_size,
samples_each_class=2
)
cost_metric = metricloss.loss(out[0])
avg_cost_metric = fluid.layers.mean(x=cost_metric)
else:
print("loss name is not supported!")
exit()
cost_cls = fluid.layers.cross_entropy(input=out[1], label=label)
avg_cost_cls = fluid.layers.mean(x=cost_cls)
acc_top1 = fluid.layers.accuracy(input=out[1], label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out[1], label=label, k=5)
avg_cost = avg_cost_metric + args.alpha * avg_cost_cls
avg_cost = avg_cost_metric + args.alpha*avg_cost_cls
test_program = fluid.default_main_program().clone(for_test=True)
# parameters from model and arguments
params = model.params
params = model.params
params["lr"] = args.lr
params["num_epochs"] = args.num_epochs
params["learning_strategy"]["batch_size"] = args.train_batch_size
......@@ -129,7 +120,7 @@ def train(args):
global_lr = optimizer._global_learning_rate()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -137,35 +128,31 @@ def train(args):
fluid.io.load_persistables(exe, checkpoint)
if pretrained_model:
assert(checkpoint is None)
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
has_var = os.path.exists(os.path.join(pretrained_model, var.name))
if has_var:
print('var: %s found' % (var.name))
return has_var
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_reader = metricloss.train_reader
test_reader = metricloss.test_reader
train_reader = paddle.batch(metricloss.train_reader, batch_size=args.train_batch_size)
test_reader = paddle.batch(metricloss.test_reader, batch_size=args.test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False, loss_name=avg_cost.name)
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
fetch_list_train = [
avg_cost_metric.name, avg_cost_cls.name, acc_top1.name, acc_top5.name,
global_lr.name
]
fetch_list_train = [avg_cost_metric.name, avg_cost_cls.name, acc_top1.name, acc_top5.name, global_lr.name]
fetch_list_test = [out[0].name]
if with_memory_optimization:
fluid.memory_optimize(
fluid.default_main_program(), skip_opt_set=set(fetch_list_train))
fluid.memory_optimize(fluid.default_main_program(), skip_opt_set=set(fetch_list_train))
for pass_id in range(params["num_epochs"]):
train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss_metric, loss_cls, acc1, acc5, lr = train_exe.run(
fetch_list_train, feed=feeder.feed(data))
loss_metric, loss_cls, acc1, acc5, lr = train_exe.run(fetch_list_train, feed=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss_metric = np.mean(np.array(loss_metric))
......@@ -187,24 +174,24 @@ def train(args):
train_acc5 = np.array(train_info[3]).mean()
f = []
l = []
for batch_id, (data, label) in enumerate(test_reader()):
for batch_id, data in enumerate(test_reader()):
if len(data) < args.test_batch_size:
continue
t1 = time.time()
[feas] = exe.run(test_program,
fetch_list=fetch_list_test,
feed=feeder.feed(data))
[feas] = exe.run(test_program, fetch_list = fetch_list_test, feed=feeder.feed(data))
label = np.asarray([x[1] for x in data])
f.append(feas)
l.append(label)
t2 = time.time()
period = t2 - t1
if batch_id % 10 == 0:
recall = recall_topk(feas, label, k=1)
print("Pass {0}, testbatch {1}, recall {2}, time {3}".format(pass_id, \
batch_id, recall, "%2.2f sec" % period))
if batch_id % 20 == 0:
print("Pass {0}, testbatch {1}, time {2}".format(pass_id, \
batch_id, "%2.2f sec" % period))
f = np.vstack(f)
l = np.vstack(l)
recall = recall_topk(f, l, k=1)
l = np.hstack(l)
recall = recall_topk(f, l, k = 1)
print("End pass {0}, train_loss_metric {1}, train_loss_cls {2}, train_acc1 {3}, train_acc5 {4}, test_recall {5}".format(pass_id, \
train_loss_metric, train_loss_cls, train_acc1, train_acc5, recall))
sys.stdout.flush()
......@@ -215,12 +202,10 @@ def train(args):
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
def main():
args = parser.parse_args()
print_arguments(args)
train(args)
if __name__ == '__main__':
main()