提交 56c411ab 编写于 作者: X xiteng1988

add slimfacenet

上级 2dc89a6b
import numpy as np
import scipy.misc
import paddle
from paddle import fluid
class LFW(object):
def __init__(self, imgl, imgr):
self.imgl_list = imgl
self.imgr_list = imgr
self.shuffle_idx = [i for i in range(len(self.imgl_list))]
def reader(self):
while True:
if len(self.shuffle_idx) == 0:
self.shuffle_idx = [i for i in range(len(self.imgl_list))]
return
index = self.shuffle_idx.pop(0)
imgl = scipy.misc.imread(self.imgl_list[index])
if len(imgl.shape) == 2:
imgl = np.stack([imgl] * 3, 2)
imgr = scipy.misc.imread(self.imgr_list[index])
if len(imgr.shape) == 2:
imgr = np.stack([imgr] * 3, 2)
imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]]
for i in range(len(imglist)):
imglist[i] = (imglist[i] - 127.5) / 128.0
imglist[i] = imglist[i].transpose(2, 0, 1)
imgs = [img.astype('float32') for img in imglist]
yield imgs
def __len__(self):
return len(self.imgl_list)
if __name__ == '__main__':
pass
\ No newline at end of file
import os
import shutil
import subprocess
import argparse
import time
import scipy.io
import numpy as np
import paddle
from paddle import fluid
#from dataloader.CASIA import CASIA_Face
from dataloader.LFW import LFW
from lfw_eval import parseList, evaluation_10_fold
from models.slimfacenet import SlimFaceNet
def now():
return time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
def creat_optimizer(args, trainset_scale):
start_step = trainset_scale * args.start_epoch // args.train_batchsize
if args.lr_strategy == 'piecewise_decay':
bd = [trainset_scale * int(e) // args.train_batchsize for e in args.lr_steps.strip().split(',')]
lr = [float(e) for e in args.lr_list.strip().split(',')]
assert len(bd) == len(lr) - 1
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
elif args.lr_strategy == 'cosine_decay':
lr = args.lr
step_each_epoch = trainset_scale // args.train_batchsize
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.cosine_decay(lr, step_each_epoch, args.total_epoch),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
else:
print('Wrong learning rate strategy')
exit()
return optimizer
def test(test_exe, test_program, test_out, args):
featureLs = None
featureRs = None
out_feature, test_reader, flods, flags = test_out
for idx, data in enumerate(test_reader()):
res = []
res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test1']}, fetch_list = out_feature))
res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test2']}, fetch_list = out_feature))
res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test3']}, fetch_list = out_feature))
res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test4']}, fetch_list = out_feature))
featureL = np.concatenate((res[0][0], res[1][0]), 1)
featureR = np.concatenate((res[2][0], res[3][0]), 1)
if featureLs is None:
featureLs = featureL
else:
featureLs = np.concatenate((featureLs, featureL), 0)
if featureRs is None:
featureRs = featureR
else:
featureRs = np.concatenate((featureRs, featureR), 0)
result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags}
scipy.io.savemat(args.feature_save_dir, result)
ACCs = evaluation_10_fold(args.feature_save_dir)
print('eval arch {}'.format(args.arch))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
f.writelines('eval arch {}\n'.format(args.arch))
for i in range(len(ACCs)):
#print('{} {:.2f}'.format(i+1, ACCs[i] * 100))
print('{} {}'.format(i+1, ACCs[i] * 100))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
#f.writelines('{} {:.2f}\n'.format(i+1, ACCs[i] * 100))
f.writelines('{} {}\n'.format(i+1, ACCs[i] * 100))
print('--------')
#print('AVE {:.2f}'.format(np.mean(ACCs) * 100))
print('AVE {}'.format(np.mean(ACCs) * 100))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
f.writelines('--------\n')
#f.writelines('AVE {:.2f}\n'.format(np.mean(ACCs) * 100))
f.writelines('AVE {}\n'.format(np.mean(ACCs) * 100))
return np.mean(ACCs) * 100
def train(exe, train_program, train_out, test_program, test_out, args):
loss, acc, global_lr, train_reader = train_out
fetch_list_train = [loss.name, acc.name, global_lr.name]
train_exe = fluid.ParallelExecutor(
use_cuda=True,
loss_name=loss.name,
main_program=train_program)
for epoch_id in range(args.start_epoch, args.total_epoch):
for batch_id, data in enumerate(train_reader()):
loss, acc, global_lr = train_exe.run(feed=data, fetch_list=fetch_list_train)
avg_loss = np.mean(np.array(loss))
avg_acc = np.mean(np.array(acc))
print('{} Epoch: {:^4d} step: {:^4d} loss: {:.6f}, acc: {:.6f}, lr: {}'.format(
now(), epoch_id, batch_id, avg_loss, avg_acc, float(np.mean(np.array(global_lr)))))
#test(exe, test_program, test_out, args)
if batch_id % args.save_frequency == 0:
model_path = os.path.join(args.save_ckpt, str(epoch_id))
fluid.io.save_persistables(executor=exe, dirname=model_path, main_program=train_program)
test(exe, test_program, test_out, args)
def build_program(program, startup, args, is_train=True):
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root = args.train_data_dir)
trainset_scale = len(train_dataset)
with fluid.program_guard(main_program=program, startup_program=startup):
with fluid.unique_name.guard():
# Model construction
arch = [int(a) for a in args.arch.strip().split(',')]
model = SlimFaceNet(class_dim = train_dataset.class_nums, arch = arch)
if is_train:
image = fluid.layers.data(name='image', shape=[-1, 3, 112, 112], dtype='float32')
label = fluid.layers.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.batch(train_dataset.reader, batch_size = args.train_batchsize // num_trainers, drop_last = False)
reader = fluid.io.PyReader(feed_list=[image, label], capacity=64, iterable=True, return_list=False)
reader.decorate_sample_list_generator(train_reader, places=places)
model.extract_feature = False
loss, acc = model.net(image, label)
optimizer = creat_optimizer(args, trainset_scale)
optimizer.minimize(loss)
global_lr = optimizer._global_learning_rate()
out = (loss, acc, global_lr, reader)
else:
nl, nr, flods, flags = parseList(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False)
image_test = fluid.layers.data(name='image_test', shape=[-1, 3, 112, 112], dtype='float32')
image_test1 = fluid.layers.data(name='image_test1', shape=[-1, 3, 112, 112], dtype='float32')
image_test2 = fluid.layers.data(name='image_test2', shape=[-1, 3, 112, 112], dtype='float32')
image_test3 = fluid.layers.data(name='image_test3', shape=[-1, 3, 112, 112], dtype='float32')
image_test4 = fluid.layers.data(name='image_test4', shape=[-1, 3, 112, 112], dtype='float32')
reader = fluid.io.PyReader(feed_list=[image_test1, image_test2, image_test3, image_test4], capacity=64, iterable=True, return_list=False)
reader.decorate_sample_list_generator(test_reader, fluid.core.CPUPlace())
model.extract_feature = True
feature = model.net(image_test)
out = (feature, reader, flods, flags)
return out
def main():
global args
parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet')
parser.add_argument('--action', default='final', type=str, help='test/final')
parser.add_argument('--model', default='slimfacenet', type=str, help='slimfacenet/slimfacenet_v1')
parser.add_argument('--arch', default='1,1,0,1,1,1,1,0,1,0,1,3,2,2,3', type=str, help='arch')
parser.add_argument('--use_gpu', default=1, type=int, help='Use GPU or not, 0 is not used')
parser.add_argument('--use_multiGPU', default=0, type=int, help='Use multi GPU or not, 0 is not used')
parser.add_argument('--lr_strategy', default='piecewise_decay', type=str, help='lr_strategy')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--lr_list', default='0.1,0.01,0.001,0.0001', type=str, help='learning rate list (piecewise_decay)')
parser.add_argument('--lr_steps', default='36,52,58', type=str, help='learning rate decay at which epochs')
parser.add_argument('--l2_decay', default=4e-5, type=float, help='base l2_decay')
parser.add_argument('--train_data_dir', default='./CASIA', type=str, help='train_data_dir')
parser.add_argument('--test_data_dir', default='./lfw', type=str, help='lfw_data_dir')
parser.add_argument('--train_batchsize', default=512, type=int, help='train_batchsize')
parser.add_argument('--test_batchsize', default=500, type=int, help='test_batchsize')
parser.add_argument('--img_shape', default='3,112,96', type=str, help='img_shape')
parser.add_argument('--start_epoch', default=0, type=int, help='start_epoch')
parser.add_argument('--total_epoch', default=80, type=int, help='total_epoch')
parser.add_argument('--save_frequency', default=1, type=int, help='save_frequency')
parser.add_argument('--save_ckpt', default='output', type=str, help='save_ckpt')
parser.add_argument('--resume', default='', type=str, help='resume')
parser.add_argument('--feature_save_dir', default='result.mat', type=str, help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
print(args)
print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None:
args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt])
shutil.copyfile(__file__, os.path.join(args.save_ckpt, 'train.py'))
shutil.copyfile('models/slimfacenet.py', os.path.join(args.save_ckpt, 'model.py'))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
startup_program = fluid.Program()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname='./quant_model/',
model_filename=None,
params_filename=None,
executor=exe)
#if args.action == 'final':
# train(exe, train_program, train_out, test_program, test_out, args)
if args.action == 'test':
nl, nr, flods, flags = parseList(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False)
image_test = fluid.layers.data(name='image_test', shape=[-1, 3, 112, 96], dtype='float32')
image_test1 = fluid.layers.data(name='image_test1', shape=[-1, 3, 112, 96], dtype='float32')
image_test2 = fluid.layers.data(name='image_test2', shape=[-1, 3, 112, 96], dtype='float32')
image_test3 = fluid.layers.data(name='image_test3', shape=[-1, 3, 112, 96], dtype='float32')
image_test4 = fluid.layers.data(name='image_test4', shape=[-1, 3, 112, 96], dtype='float32')
reader = fluid.io.PyReader(feed_list=[image_test1, image_test2, image_test3, image_test4], capacity=64, iterable=True, return_list=False)
reader.decorate_sample_list_generator(test_reader, fluid.core.CPUPlace())
test_out = (fetch_targets, reader, flods, flags)
print('fetch_targets[0]: ', fetch_targets[0])
print('feed_target_names: ', feed_target_names)
test(exe, inference_program, test_out, args)
else:
print('WRONG ACTION')
if __name__ == '__main__':
main()
import os
import argparse
import time
import scipy.io
import numpy as np
import paddle
from paddle import fluid
#from dataloader.CASIA import CASIA_Face
from dataloader.LFW import LFW
from models.slimfacenet import SlimFaceNet
def parseList(root):
with open(os.path.join(root, 'pairs.txt')) as f:
pairs = f.read().splitlines()[1:]
folder_name = 'lfw-112X96'
nameLs = []
nameRs = []
folds = []
flags = []
for i, p in enumerate(pairs):
p = p.split('\t')
if len(p) == 3:
nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
nameR = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[2])))
fold = i // 600
flag = 1
elif len(p) == 4:
nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
nameR = os.path.join(root, folder_name, p[2], p[2] + '_' + '{:04}.jpg'.format(int(p[3])))
fold = i // 600
flag = -1
nameLs.append(nameL)
nameRs.append(nameR)
folds.append(fold)
flags.append(flag)
return [nameLs, nameRs, folds, flags]
def getAccuracy(scores, flags, threshold):
p = np.sum(scores[flags == 1] > threshold)
n = np.sum(scores[flags == -1] < threshold)
return 1.0 * (p + n) / len(scores)
def getThreshold(scores, flags, thrNum):
accuracys = np.zeros((2 * thrNum + 1, 1))
thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
for i in range(2 * thrNum + 1):
accuracys[i] = getAccuracy(scores, flags, thresholds[i])
max_index = np.squeeze(accuracys == np.max(accuracys))
bestThreshold = np.mean(thresholds[max_index])
return bestThreshold
def evaluation_10_fold(root='result.mat'):
ACCs = np.zeros(10)
result = scipy.io.loadmat(root)
for i in range(10):
fold = result['fold']
flags = result['flag']
featureLs = result['fl']
featureRs = result['fr']
valFold = fold != i
testFold = fold == i
flags = np.squeeze(flags)
mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
mu = np.expand_dims(mu, 0)
featureLs = featureLs - mu
featureRs = featureRs - mu
featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)
scores = np.sum(np.multiply(featureLs, featureRs), 1)
threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)
return ACCs
def test(test_reader, flods, flags, net, args):
net.eval()
featureLs = None
featureRs = None
for idx, data in enumerate(test_reader()):
data_list = [[] for _ in range(4)]
for _ in range(len(data)):
data_list[0].append(data[_][0])
data_list[1].append(data[_][1])
data_list[2].append(data[_][2])
data_list[3].append(data[_][3])
res = [net(fluid.dygraph.to_variable(np.array(d))).numpy() for d in data_list]
featureL = np.concatenate((res[0], res[1]), 1)
featureR = np.concatenate((res[2], res[3]), 1)
if featureLs is None:
featureLs = featureL
else:
featureLs = np.concatenate((featureLs, featureL), 0)
if featureRs is None:
featureRs = featureR
else:
featureRs = np.concatenate((featureRs, featureR), 0)
result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags}
scipy.io.savemat(args.feature_save_dir, result)
ACCs = evaluation_10_fold(args.feature_save_dir)
for i in range(len(ACCs)):
print('{} {:.2f}'.format(i+1, ACCs[i] * 100))
print('--------')
print('AVE {:.2f}'.format(np.mean(ACCs) * 100))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet')
parser.add_argument('--use_gpu', default=0, type=int, help='Use GPU or not, 0 is not used')
parser.add_argument('--test_data_dir', default='./lfw', type=str, help='lfw_data_dir')
parser.add_argument('--resume', default='output/0', type=str, help='resume')
parser.add_argument('--feature_save_dir', default='result.mat', type=str, help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
place = fluid.CPUPlace() if args.use_gpu == 0 else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
train_dataset = CASIA_Face(root = args.train_data_dir)
nl, nr, flods, flags = parseList(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False)
net = SlimFaceNet(train_dataset.class_nums, args.img_shape)
if args.resume:
assert os.path.exists(args.resume + ".pdparams"), "Given dir {}.pdparams not exist.".format(args.resume)
para_dict, opti_dict = fluid.dygraph.load_dygraph(args.resume)
net.set_dict(para_dict)
test(test_reader, flods, flags, net, args)
from .slimfacenet import SlimFaceNet
from collections import OrderedDict
from prettytable import PrettyTable
import distutils.util
import numpy as np
import six
def summary(main_prog):
'''
It can summary model's PARAMS, FLOPs until now.
It support common operator like conv, fc, pool, relu, sigmoid, bn etc.
Args:
main_prog: main program
Returns:
print summary on terminal
'''
collected_ops_list = []
is_quantize = False
for one_b in main_prog.blocks:
block_vars = one_b.vars
for one_op in one_b.ops:
# if str(one_op.type).find('quantize') > -1:
# is_quantize = True
op_info = OrderedDict()
spf_res = _summary_model(block_vars, one_op)
if spf_res is None:
continue
# TODO: get the operator name
op_info['type'] = one_op.type
op_info['input_shape'] = spf_res[0][1:]
op_info['out_shape'] = spf_res[1][1:]
op_info['PARAMs'] = spf_res[2]
op_info['FLOPs'] = spf_res[3]
collected_ops_list.append(op_info)
summary_table, total = _format_summary(collected_ops_list)
_print_summary(summary_table, total)
return total, is_quantize
def _summary_model(block_vars, one_op):
'''
Compute operator's params and flops.
Args:
block_vars: all vars of one block
one_op: one operator to count
Returns:
in_data_shape: one operator's input data shape
out_data_shape: one operator's output data shape
params: one operator's PARAMs
flops: : one operator's FLOPs
'''
if one_op.type in ['conv2d', 'depthwise_conv2d']:
k_arg_shape = block_vars[one_op.input("Filter")[0]].shape
in_data_shape = block_vars[one_op.input("Input")[0]].shape
out_data_shape = block_vars[one_op.output("Output")[0]].shape
c_out, c_in, k_h, k_w = k_arg_shape
_, c_out_, h_out, w_out = out_data_shape
#assert c_out == c_out_, 'shape error!'
k_groups = one_op.attr("groups")
kernel_ops = k_h * k_w * (in_data_shape[1] / k_groups)
try:
bias_ops = 0 if one_op.input("Bias") == [] else 1
except:
bias_ops = 0
params = c_out * (kernel_ops + bias_ops)
flops = h_out * w_out * c_out * (kernel_ops + bias_ops)
# base nvidia paper, include mul and add
flops = 2 * flops
if one_op.type == 'depthwise_conv2d':
pass
# var_name = block_vars[one_op.input("Filter")[0]].name
# if var_name.endswith('.int8'):
# flops /= 2.0
elif one_op.type == 'pool2d':
in_data_shape = block_vars[one_op.input("X")[0]].shape
out_data_shape = block_vars[one_op.output("Out")[0]].shape
_, c_out, h_out, w_out = out_data_shape
k_size = one_op.attr("ksize")
params = 0
flops = h_out * w_out * c_out * (k_size[0] * k_size[1])
elif one_op.type == 'mul':
k_arg_shape = block_vars[one_op.input("Y")[0]].shape
in_data_shape = block_vars[one_op.input("X")[0]].shape
out_data_shape = block_vars[one_op.output("Out")[0]].shape
# TODO: fc has mul ops
# add attr to mul op, tell us whether it belongs to 'fc'
# this's not the best way
if 'fc' not in one_op.output("Out")[0]:
return None
k_in, k_out = k_arg_shape
# bias in sum op
params = k_in * k_out + 1
flops = k_in * k_out
# var_name = block_vars[one_op.input("Y")[0]].name
# if var_name.endswith('.int8'):
# flops /= 2.0
elif one_op.type in ['sigmoid', 'tanh', 'relu', 'leaky_relu', 'prelu']:
in_data_shape = block_vars[one_op.input("X")[0]].shape
out_data_shape = block_vars[one_op.output("Out")[0]].shape
params = 0
if one_op.type == 'prelu':
params = 1
flops = 1
for one_dim in in_data_shape[1:]:
flops *= one_dim
elif one_op.type == 'batch_norm':
in_data_shape = block_vars[one_op.input("X")[0]].shape
out_data_shape = block_vars[one_op.output("Y")[0]].shape
_, c_in, h_out, w_out = in_data_shape
# gamma, beta
params = c_in * 2
# compute mean and std
flops = h_out * w_out * c_in * 2
else:
return None
return in_data_shape, out_data_shape, params, flops
def _format_summary(collected_ops_list):
'''
Format summary report.
Args:
collected_ops_list: the collected operator with summary
Returns:
summary_table: summary report format
total: sum param and flops
'''
summary_table = PrettyTable(
["No.", "TYPE", "INPUT", "OUTPUT", "PARAMs", "FLOPs"])
summary_table.align = 'r'
total = {}
total_params = []
total_flops = []
for i, one_op in enumerate(collected_ops_list):
# notice the order
table_row = [
i,
one_op['type'],
one_op['input_shape'],
one_op['out_shape'],
int(one_op['PARAMs']),
int(one_op['FLOPs']),
]
summary_table.add_row(table_row)
total_params.append(int(one_op['PARAMs']))
total_flops.append(int(one_op['FLOPs']))
total['params'] = total_params
total['flops'] = total_flops
return summary_table, total
def _print_summary(summary_table, total):
'''
Print all the summary on terminal.
Args:
summary_table: summary report format
total: sum param and flops
'''
parmas = total['params']
flops = total['flops']
print(summary_table)
print('Total PARAMs: {}({:.4f}M)'.format(
sum(parmas), sum(parmas) / (10.0 ** 6)))
print('Total FLOPs: {}({:.4f}G)'.format(sum(flops), sum(flops) / 10.0 ** 6))
print('Total MAdds: {}({:.4f}G)'.format(sum(flops)/2, sum(flops) / 10.0 ** 6 / 2))
print(
"Notice: \n now supported ops include [Conv, DepthwiseConv, FC(mul), BatchNorm, Pool, Activation(sigmoid, tanh, relu, leaky_relu, prelu)]"
)
def get_batch_dt_res(nmsed_out_v, data, contiguous_category_id_to_json_id, batch_size):
dts_res = []
lod = nmsed_out_v[0].lod()[0]
nmsed_out_v = np.array(nmsed_out_v[0])
real_batch_size = min(batch_size, len(data))
assert (len(lod) == real_batch_size + 1), \
"Error Lod Tensor offset dimension. Lod({}) vs. batch_size({})".format(len(lod), batch_size)
k = 0
for i in range(real_batch_size):
dt_num_this_img = lod[i + 1] - lod[i]
image_id = int(data[i][4][0])
image_width = int(data[i][4][1])
image_height = int(data[i][4][2])
for j in range(dt_num_this_img):
dt = nmsed_out_v[k]
k = k + 1
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
xmin = max(min(xmin, 1.0), 0.0) * image_width
ymin = max(min(ymin, 1.0), 0.0) * image_height
xmax = max(min(xmax, 1.0), 0.0) * image_width
ymax = max(min(ymax, 1.0), 0.0) * image_height
w = xmax - xmin
h = ymax - ymin
bbox = [xmin, ymin, w, h]
dt_res = {
'image_id': image_id,
'category_id': contiguous_category_id_to_json_id[category_id],
'bbox': bbox,
'score': score
}
dts_res.append(dt_res)
return dts_res
import math
import datetime
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
class SlimFaceNet():
def __init__(self, class_dim, scale=0.6, arch=None):
assert arch is not None
self.arch = arch
self.class_dim = class_dim
kernels = [3]
expansions = [2, 4, 6]
SE = [0, 1]
self.table = []
for k in kernels:
for e in expansions:
for se in SE:
self.table.append((k, e, se))
if scale == 1.0:
# 100% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 64, 5, 2],
[4, 128, 1, 2],
[2, 128, 6, 1],
[4, 128, 1, 2],
[2, 128, 2, 1]
]
elif scale == 0.9:
# 90% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 56, 5, 2],
[4, 116, 1, 2],
[2, 116, 6, 1],
[4, 116, 1, 2],
[2, 116, 2, 1]
]
elif scale == 0.75:
# 75% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 48, 5, 2],
[4, 96, 1, 2],
[2, 96, 6, 1],
[4, 96, 1, 2],
[2, 96, 2, 1]
]
elif scale == 0.6:
# 60% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 40, 5, 2],
[4, 76, 1, 2],
[2, 76, 6, 1],
[4, 76, 1, 2],
[2, 76, 2, 1]
]
else:
print('WRONG scale')
exit()
self.extract_feature = True
def set_extract_feature_flag(self, flag):
self.extract_feature = flag
def net(self, input, label=None):
x = self.conv_bn_layer(input, filter_size=3, num_filters=64, stride=2, padding=1, num_groups=1, if_act=True, name='conv3x3')
x = self.conv_bn_layer(x, filter_size=3, num_filters=64, stride=1, padding=1, num_groups=64, if_act=True, name='dw_conv3x3')
in_c = 64
cnt = 0
for _exp, out_c , times, _stride in self.Slimfacenet_bottleneck_setting:
for i in range(times):
stride = _stride if i==0 else 1
filter_size, exp, se = self.table[self.arch[cnt]]
se = False if se==0 else True
x = self.residual_unit(x, num_in_filter=in_c, num_out_filter=out_c, stride=stride, filter_size=filter_size, expansion_factor=exp, use_se=se, name='residual_unit'+str(cnt+1))
cnt += 1
in_c = out_c
out_c = 512
x = self.conv_bn_layer(x, filter_size=1, num_filters=out_c, stride=1, padding=0, num_groups=1, if_act=True, name='conv1x1')
# Replace dw_conv7x7 with dw_conv5x5 + dw_conv3x3
x = self.conv_bn_layer(x, filter_size=(7,6), num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv7x7')
# x = self.conv_bn_layer(x, filter_size=5, num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv5x5')
# x = self.conv_bn_layer(x, filter_size=3, num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv3x3')
# 128dim, L2Decay = 4e-4
x = fluid.layers.conv2d(x, num_filters=128, filter_size=1, stride=1, padding=0, groups=1, act=None, use_cudnn=True, param_attr=ParamAttr(name='linear_conv1x1_weights', initializer=MSRA(), regularizer=fluid.regularizer.L2Decay(4e-4)), bias_attr=False)
bn_name = 'linear_conv1x1_bn'
x = fluid.layers.batch_norm(x, param_attr=ParamAttr(name=bn_name + "_scale"), bias_attr=ParamAttr(name=bn_name + "_offset"), moving_mean_name=bn_name + '_mean', moving_variance_name=bn_name + '_variance')
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
if self.extract_feature:
return x
out = self.arc_margin_product(x, label, self.class_dim, s = 32.0, m = 0.50, mode = 2)
softmax = fluid.layers.softmax(input=out)
cost = fluid.layers.cross_entropy(input=softmax, label=label)
loss = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=out, label=label, k=1)
return loss, acc
def residual_unit(self,
input,
num_in_filter,
num_out_filter,
stride,
filter_size,
expansion_factor,
use_se=False,
name=None):
num_expfilter = int(round(num_in_filter * expansion_factor))
input_data = input
expand_conv = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_expfilter,
stride=1,
padding=0,
if_act=True,
name=name + '_expand')
depthwise_conv = self.conv_bn_layer(
input=expand_conv,
filter_size=filter_size,
num_filters=num_expfilter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
num_groups=num_expfilter,
use_cudnn=True,
name=name + '_depthwise')
if use_se:
depthwise_conv = self.se_block(input=depthwise_conv, num_out_filter=num_expfilter, name=name + '_se')
linear_conv = self.conv_bn_layer(
input=depthwise_conv,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear')
if num_in_filter != num_out_filter or stride != 1:
return linear_conv
else:
return fluid.layers.elementwise_add(x=input_data, y=linear_conv, act=None)
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act=None,
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv1 = fluid.layers.prelu(conv1, mode='channel', param_attr = ParamAttr(name=name + '_prelu', regularizer=fluid.regularizer.L2Decay(0.0)))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights', initializer=MSRA()),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
# print(bn.shape)
if if_act:
return fluid.layers.prelu(bn, mode='channel', param_attr = ParamAttr(name=name + '_prelu', regularizer=fluid.regularizer.L2Decay(0.0)))
else:
return bn
def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50, mode=2):
input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
weight = fluid.layers.create_parameter(
shape=[out_dim, input.shape[1]],
dtype='float32',
name='weight_norm',
attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Xavier(), regularizer=fluid.regularizer.L2Decay(4e-4)))
weight_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(weight), dim=1))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0)
weight = fluid.layers.transpose(weight, perm=[1, 0])
cosine = fluid.layers.mul(input, weight)
sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine))
cos_m = math.cos(m)
sin_m = math.sin(m)
phi = cosine * cos_m - sine * sin_m
th = math.cos(math.pi - m)
mm = math.sin(math.pi - m) * m
if mode == 1:
phi = self.paddle_where_more_than(cosine, 0, phi, cosine)
elif mode == 2:
phi = self.paddle_where_more_than(cosine, th, phi, cosine - mm)
else:
pass
# print('***** IMPORTANT WARNING *****')
# print('Please determine if phi is correct.')
one_hot = fluid.layers.one_hot(input=label, depth=out_dim)
output = fluid.layers.elementwise_mul(one_hot, phi) + fluid.layers.elementwise_mul((1.0 - one_hot), cosine)
output = output * s
return output
def paddle_where_more_than(self, target, limit, x, y):
mask = fluid.layers.cast(x=(target > limit), dtype='float32')
output = fluid.layers.elementwise_mul(mask, x) + fluid.layers.elementwise_mul((1.0 - mask), y)
return output
if __name__ == "__main__":
x = fluid.layers.data(name='x', shape=[3, 112, 112], dtype='float32')
print(x.shape)
model = SlimFaceNet(10000, [1,3,3,1,1,0,0,1,0,1,1,0,5,5,3])
y = model.net(x)
# ================================================================
# Copyright (C) 2020 BAIDU CORPORATION. All rights reserved.
#
# Filename : slim_eval.sh
# Author : paddleslim@baidu.com
# Date : 2020-05-06
# Describe : eval the performace of slimfacenet on lfw
#
# ================================================================
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
#export LD_LIBRARY_PATH='PATH to CUDA and CUDNN'
python eval_infer_model.py --action test
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册