From 56c411abeac5fffb0c0ec8c378f3e6143a67d169 Mon Sep 17 00:00:00 2001 From: xiteng1988 Date: Wed, 6 May 2020 16:51:58 +0800 Subject: [PATCH] add slimfacenet --- demo/slimfacenet/dataloader/LFW.py | 41 ++++ demo/slimfacenet/dataloader/__init__.py | 0 demo/slimfacenet/eval_infer_model.py | 228 +++++++++++++++++++++ demo/slimfacenet/lfw_eval.py | 137 +++++++++++++ demo/slimfacenet/models/__init__.py | 1 + demo/slimfacenet/models/calc_flops.py | 214 +++++++++++++++++++ demo/slimfacenet/models/slimfacenet.py | 261 ++++++++++++++++++++++++ demo/slimfacenet/slim_eval.sh | 14 ++ 8 files changed, 896 insertions(+) create mode 100644 demo/slimfacenet/dataloader/LFW.py create mode 100644 demo/slimfacenet/dataloader/__init__.py create mode 100644 demo/slimfacenet/eval_infer_model.py create mode 100644 demo/slimfacenet/lfw_eval.py create mode 100644 demo/slimfacenet/models/__init__.py create mode 100644 demo/slimfacenet/models/calc_flops.py create mode 100644 demo/slimfacenet/models/slimfacenet.py create mode 100644 demo/slimfacenet/slim_eval.sh diff --git a/demo/slimfacenet/dataloader/LFW.py b/demo/slimfacenet/dataloader/LFW.py new file mode 100644 index 00000000..2a0287b7 --- /dev/null +++ b/demo/slimfacenet/dataloader/LFW.py @@ -0,0 +1,41 @@ +import numpy as np +import scipy.misc + +import paddle +from paddle import fluid + +class LFW(object): + def __init__(self, imgl, imgr): + + self.imgl_list = imgl + self.imgr_list = imgr + self.shuffle_idx = [i for i in range(len(self.imgl_list))] + + def reader(self): + while True: + if len(self.shuffle_idx) == 0: + self.shuffle_idx = [i for i in range(len(self.imgl_list))] + return + index = self.shuffle_idx.pop(0) + + imgl = scipy.misc.imread(self.imgl_list[index]) + if len(imgl.shape) == 2: + imgl = np.stack([imgl] * 3, 2) + imgr = scipy.misc.imread(self.imgr_list[index]) + if len(imgr.shape) == 2: + imgr = np.stack([imgr] * 3, 2) + + imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]] + for i in range(len(imglist)): + imglist[i] = (imglist[i] - 127.5) / 128.0 + imglist[i] = imglist[i].transpose(2, 0, 1) + + imgs = [img.astype('float32') for img in imglist] + yield imgs + + def __len__(self): + return len(self.imgl_list) + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/demo/slimfacenet/dataloader/__init__.py b/demo/slimfacenet/dataloader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demo/slimfacenet/eval_infer_model.py b/demo/slimfacenet/eval_infer_model.py new file mode 100644 index 00000000..90c50b53 --- /dev/null +++ b/demo/slimfacenet/eval_infer_model.py @@ -0,0 +1,228 @@ +import os +import shutil +import subprocess +import argparse +import time +import scipy.io +import numpy as np + +import paddle +from paddle import fluid + +#from dataloader.CASIA import CASIA_Face +from dataloader.LFW import LFW +from lfw_eval import parseList, evaluation_10_fold +from models.slimfacenet import SlimFaceNet + +def now(): + return time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) + + +def creat_optimizer(args, trainset_scale): + start_step = trainset_scale * args.start_epoch // args.train_batchsize + + if args.lr_strategy == 'piecewise_decay': + bd = [trainset_scale * int(e) // args.train_batchsize for e in args.lr_steps.strip().split(',')] + lr = [float(e) for e in args.lr_list.strip().split(',')] + assert len(bd) == len(lr) - 1 + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr), + momentum=0.9, + 
regularization=fluid.regularizer.L2Decay(args.l2_decay)) + elif args.lr_strategy == 'cosine_decay': + lr = args.lr + step_each_epoch = trainset_scale // args.train_batchsize + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.cosine_decay(lr, step_each_epoch, args.total_epoch), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + else: + print('Wrong learning rate strategy') + exit() + return optimizer + + +def test(test_exe, test_program, test_out, args): + featureLs = None + featureRs = None + out_feature, test_reader, flods, flags = test_out + for idx, data in enumerate(test_reader()): + res = [] + res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test1']}, fetch_list = out_feature)) + res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test2']}, fetch_list = out_feature)) + res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test3']}, fetch_list = out_feature)) + res.append(test_exe.run(test_program, feed = {u'image_test': data[0][u'image_test4']}, fetch_list = out_feature)) + featureL = np.concatenate((res[0][0], res[1][0]), 1) + featureR = np.concatenate((res[2][0], res[3][0]), 1) + if featureLs is None: + featureLs = featureL + else: + featureLs = np.concatenate((featureLs, featureL), 0) + if featureRs is None: + featureRs = featureR + else: + featureRs = np.concatenate((featureRs, featureR), 0) + result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags} + scipy.io.savemat(args.feature_save_dir, result) + ACCs = evaluation_10_fold(args.feature_save_dir) + print('eval arch {}'.format(args.arch)) + with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f: + f.writelines('eval arch {}\n'.format(args.arch)) + for i in range(len(ACCs)): + #print('{} {:.2f}'.format(i+1, ACCs[i] * 100)) + print('{} {}'.format(i+1, ACCs[i] * 100)) + with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f: + #f.writelines('{} {:.2f}\n'.format(i+1, ACCs[i] * 100)) + f.writelines('{} {}\n'.format(i+1, ACCs[i] * 100)) + print('--------') + #print('AVE {:.2f}'.format(np.mean(ACCs) * 100)) + print('AVE {}'.format(np.mean(ACCs) * 100)) + with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f: + f.writelines('--------\n') + #f.writelines('AVE {:.2f}\n'.format(np.mean(ACCs) * 100)) + f.writelines('AVE {}\n'.format(np.mean(ACCs) * 100)) + return np.mean(ACCs) * 100 + + +def train(exe, train_program, train_out, test_program, test_out, args): + loss, acc, global_lr, train_reader = train_out + fetch_list_train = [loss.name, acc.name, global_lr.name] + train_exe = fluid.ParallelExecutor( + use_cuda=True, + loss_name=loss.name, + main_program=train_program) + for epoch_id in range(args.start_epoch, args.total_epoch): + for batch_id, data in enumerate(train_reader()): + loss, acc, global_lr = train_exe.run(feed=data, fetch_list=fetch_list_train) + avg_loss = np.mean(np.array(loss)) + avg_acc = np.mean(np.array(acc)) + print('{} Epoch: {:^4d} step: {:^4d} loss: {:.6f}, acc: {:.6f}, lr: {}'.format( + now(), epoch_id, batch_id, avg_loss, avg_acc, float(np.mean(np.array(global_lr))))) + + #test(exe, test_program, test_out, args) + if batch_id % args.save_frequency == 0: + model_path = os.path.join(args.save_ckpt, str(epoch_id)) + fluid.io.save_persistables(executor=exe, dirname=model_path, main_program=train_program) + + test(exe, test_program, test_out, args) + + +def build_program(program, startup, args, is_train=True): + num_trainers = 
len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace() + + train_dataset = CASIA_Face(root = args.train_data_dir) + trainset_scale = len(train_dataset) + + with fluid.program_guard(main_program=program, startup_program=startup): + with fluid.unique_name.guard(): + # Model construction + arch = [int(a) for a in args.arch.strip().split(',')] + model = SlimFaceNet(class_dim = train_dataset.class_nums, arch = arch) + + if is_train: + image = fluid.layers.data(name='image', shape=[-1, 3, 112, 112], dtype='float32') + label = fluid.layers.data(name='label', shape=[-1, 1], dtype='int64') + train_reader = paddle.batch(train_dataset.reader, batch_size = args.train_batchsize // num_trainers, drop_last = False) + reader = fluid.io.PyReader(feed_list=[image, label], capacity=64, iterable=True, return_list=False) + reader.decorate_sample_list_generator(train_reader, places=places) + + model.extract_feature = False + loss, acc = model.net(image, label) + optimizer = creat_optimizer(args, trainset_scale) + optimizer.minimize(loss) + global_lr = optimizer._global_learning_rate() + out = (loss, acc, global_lr, reader) + + else: + nl, nr, flods, flags = parseList(args.test_data_dir) + test_dataset = LFW(nl, nr) + test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False) + image_test = fluid.layers.data(name='image_test', shape=[-1, 3, 112, 112], dtype='float32') + image_test1 = fluid.layers.data(name='image_test1', shape=[-1, 3, 112, 112], dtype='float32') + image_test2 = fluid.layers.data(name='image_test2', shape=[-1, 3, 112, 112], dtype='float32') + image_test3 = fluid.layers.data(name='image_test3', shape=[-1, 3, 112, 112], dtype='float32') + image_test4 = fluid.layers.data(name='image_test4', shape=[-1, 3, 112, 112], dtype='float32') + reader = fluid.io.PyReader(feed_list=[image_test1, image_test2, image_test3, image_test4], capacity=64, iterable=True, return_list=False) + reader.decorate_sample_list_generator(test_reader, fluid.core.CPUPlace()) + + model.extract_feature = True + feature = model.net(image_test) + out = (feature, reader, flods, flags) + + return out + + +def main(): + global args + parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet') + parser.add_argument('--action', default='final', type=str, help='test/final') + parser.add_argument('--model', default='slimfacenet', type=str, help='slimfacenet/slimfacenet_v1') + parser.add_argument('--arch', default='1,1,0,1,1,1,1,0,1,0,1,3,2,2,3', type=str, help='arch') + parser.add_argument('--use_gpu', default=1, type=int, help='Use GPU or not, 0 is not used') + parser.add_argument('--use_multiGPU', default=0, type=int, help='Use multi GPU or not, 0 is not used') + parser.add_argument('--lr_strategy', default='piecewise_decay', type=str, help='lr_strategy') + parser.add_argument('--lr', default=0.1, type=float, help='learning rate') + parser.add_argument('--lr_list', default='0.1,0.01,0.001,0.0001', type=str, help='learning rate list (piecewise_decay)') + parser.add_argument('--lr_steps', default='36,52,58', type=str, help='learning rate decay at which epochs') + parser.add_argument('--l2_decay', default=4e-5, type=float, help='base l2_decay') + parser.add_argument('--train_data_dir', default='./CASIA', type=str, help='train_data_dir') + parser.add_argument('--test_data_dir', default='./lfw', type=str, help='lfw_data_dir') + parser.add_argument('--train_batchsize', default=512, type=int, help='train_batchsize') + 
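+    # The LFW verification protocol used here has 6,000 pairs in total
+    # (10 folds x 600 pairs, see parseList in lfw_eval.py), so the default
+    # test_batchsize of 500 yields 12 evenly sized test batches.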
parser.add_argument('--test_batchsize', default=500, type=int, help='test_batchsize') + parser.add_argument('--img_shape', default='3,112,96', type=str, help='img_shape') + parser.add_argument('--start_epoch', default=0, type=int, help='start_epoch') + parser.add_argument('--total_epoch', default=80, type=int, help='total_epoch') + parser.add_argument('--save_frequency', default=1, type=int, help='save_frequency') + parser.add_argument('--save_ckpt', default='output', type=str, help='save_ckpt') + parser.add_argument('--resume', default='', type=str, help='resume') + parser.add_argument('--feature_save_dir', default='result.mat', type=str, help='The path of the extract features save, must be .mat file') + args = parser.parse_args() + + num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + print(args) + print('num_trainers: {}'.format(num_trainers)) + if args.save_ckpt == None: + args.save_ckpt = 'output' + if not os.path.exists(args.save_ckpt): + subprocess.call(['mkdir', '-p', args.save_ckpt]) + shutil.copyfile(__file__, os.path.join(args.save_ckpt, 'train.py')) + shutil.copyfile('models/slimfacenet.py', os.path.join(args.save_ckpt, 'model.py')) + with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f: + f.writelines(str(args) + '\n') + f.writelines('num_trainers: {}'.format(num_trainers) + '\n') + + startup_program = fluid.Program() + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_program) + + [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname='./quant_model/', + model_filename=None, + params_filename=None, + executor=exe) + + #if args.action == 'final': + # train(exe, train_program, train_out, test_program, test_out, args) + if args.action == 'test': + nl, nr, flods, flags = parseList(args.test_data_dir) + test_dataset = LFW(nl, nr) + test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False) + image_test = fluid.layers.data(name='image_test', shape=[-1, 3, 112, 96], dtype='float32') + image_test1 = fluid.layers.data(name='image_test1', shape=[-1, 3, 112, 96], dtype='float32') + image_test2 = fluid.layers.data(name='image_test2', shape=[-1, 3, 112, 96], dtype='float32') + image_test3 = fluid.layers.data(name='image_test3', shape=[-1, 3, 112, 96], dtype='float32') + image_test4 = fluid.layers.data(name='image_test4', shape=[-1, 3, 112, 96], dtype='float32') + reader = fluid.io.PyReader(feed_list=[image_test1, image_test2, image_test3, image_test4], capacity=64, iterable=True, return_list=False) + reader.decorate_sample_list_generator(test_reader, fluid.core.CPUPlace()) + test_out = (fetch_targets, reader, flods, flags) + print('fetch_targets[0]: ', fetch_targets[0]) + print('feed_target_names: ', feed_target_names) + test(exe, inference_program, test_out, args) + else: + print('WRONG ACTION') + + +if __name__ == '__main__': + main() diff --git a/demo/slimfacenet/lfw_eval.py b/demo/slimfacenet/lfw_eval.py new file mode 100644 index 00000000..b96425d4 --- /dev/null +++ b/demo/slimfacenet/lfw_eval.py @@ -0,0 +1,137 @@ +import os +import argparse +import time +import scipy.io +import numpy as np + +import paddle +from paddle import fluid + +#from dataloader.CASIA import CASIA_Face +from dataloader.LFW import LFW +from models.slimfacenet import SlimFaceNet + + +def parseList(root): + with open(os.path.join(root, 'pairs.txt')) as f: + pairs = f.read().splitlines()[1:] + folder_name = 'lfw-112X96' + nameLs = [] + nameRs = [] 
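+    # pairs.txt follows the standard LFW protocol: the first line is a header
+    # (skipped above), then each line is tab-separated. A 3-field line
+    # (name, idx1, idx2) is a matched pair of one identity (flag = 1); a
+    # 4-field line (name1, idx1, name2, idx2) is a mismatched pair (flag = -1).
+    # Every 600 consecutive pairs form one of the 10 evaluation folds.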
+ folds = [] + flags = [] + for i, p in enumerate(pairs): + p = p.split('\t') + if len(p) == 3: + nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1]))) + nameR = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[2]))) + fold = i // 600 + flag = 1 + elif len(p) == 4: + nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1]))) + nameR = os.path.join(root, folder_name, p[2], p[2] + '_' + '{:04}.jpg'.format(int(p[3]))) + fold = i // 600 + flag = -1 + nameLs.append(nameL) + nameRs.append(nameR) + folds.append(fold) + flags.append(flag) + return [nameLs, nameRs, folds, flags] + + +def getAccuracy(scores, flags, threshold): + p = np.sum(scores[flags == 1] > threshold) + n = np.sum(scores[flags == -1] < threshold) + return 1.0 * (p + n) / len(scores) + + +def getThreshold(scores, flags, thrNum): + accuracys = np.zeros((2 * thrNum + 1, 1)) + thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum + for i in range(2 * thrNum + 1): + accuracys[i] = getAccuracy(scores, flags, thresholds[i]) + + max_index = np.squeeze(accuracys == np.max(accuracys)) + bestThreshold = np.mean(thresholds[max_index]) + return bestThreshold + + +def evaluation_10_fold(root='result.mat'): + ACCs = np.zeros(10) + result = scipy.io.loadmat(root) + for i in range(10): + fold = result['fold'] + flags = result['flag'] + featureLs = result['fl'] + featureRs = result['fr'] + + valFold = fold != i + testFold = fold == i + flags = np.squeeze(flags) + + mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0) + mu = np.expand_dims(mu, 0) + featureLs = featureLs - mu + featureRs = featureRs - mu + featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1) + featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1) + + scores = np.sum(np.multiply(featureLs, featureRs), 1) + threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000) + ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold) + return ACCs + + +def test(test_reader, flods, flags, net, args): + net.eval() + featureLs = None + featureRs = None + for idx, data in enumerate(test_reader()): + data_list = [[] for _ in range(4)] + for _ in range(len(data)): + data_list[0].append(data[_][0]) + data_list[1].append(data[_][1]) + data_list[2].append(data[_][2]) + data_list[3].append(data[_][3]) + res = [net(fluid.dygraph.to_variable(np.array(d))).numpy() for d in data_list] + featureL = np.concatenate((res[0], res[1]), 1) + featureR = np.concatenate((res[2], res[3]), 1) + if featureLs is None: + featureLs = featureL + else: + featureLs = np.concatenate((featureLs, featureL), 0) + if featureRs is None: + featureRs = featureR + else: + featureRs = np.concatenate((featureRs, featureR), 0) + result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags} + scipy.io.savemat(args.feature_save_dir, result) + ACCs = evaluation_10_fold(args.feature_save_dir) + for i in range(len(ACCs)): + print('{} {:.2f}'.format(i+1, ACCs[i] * 100)) + print('--------') + print('AVE {:.2f}'.format(np.mean(ACCs) * 100)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet') + parser.add_argument('--use_gpu', default=0, type=int, help='Use GPU or not, 0 is not used') + parser.add_argument('--test_data_dir', default='./lfw', type=str, help='lfw_data_dir') + parser.add_argument('--resume', default='output/0', type=str, help='resume') + 
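+    # The features extracted by test() are stored as a .mat file with keys
+    # 'fl'/'fr' (left/right image features), 'fold' (10-fold index) and
+    # 'flag' (+1 matched / -1 mismatched), which evaluation_10_fold() reads back.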
parser.add_argument('--feature_save_dir', default='result.mat', type=str, help='The path of the extract features save, must be .mat file') + args = parser.parse_args() + + place = fluid.CPUPlace() if args.use_gpu == 0 else fluid.CUDAPlace(0) + with fluid.dygraph.guard(place): + train_dataset = CASIA_Face(root = args.train_data_dir) + nl, nr, flods, flags = parseList(args.test_data_dir) + test_dataset = LFW(nl, nr) + test_reader = paddle.batch(test_dataset.reader, batch_size = args.test_batchsize, drop_last = False) + net = SlimFaceNet(train_dataset.class_nums, args.img_shape) + + if args.resume: + assert os.path.exists(args.resume + ".pdparams"), "Given dir {}.pdparams not exist.".format(args.resume) + para_dict, opti_dict = fluid.dygraph.load_dygraph(args.resume) + net.set_dict(para_dict) + + test(test_reader, flods, flags, net, args) diff --git a/demo/slimfacenet/models/__init__.py b/demo/slimfacenet/models/__init__.py new file mode 100644 index 00000000..34153bf9 --- /dev/null +++ b/demo/slimfacenet/models/__init__.py @@ -0,0 +1 @@ +from .slimfacenet import SlimFaceNet diff --git a/demo/slimfacenet/models/calc_flops.py b/demo/slimfacenet/models/calc_flops.py new file mode 100644 index 00000000..82f971c2 --- /dev/null +++ b/demo/slimfacenet/models/calc_flops.py @@ -0,0 +1,214 @@ +from collections import OrderedDict +from prettytable import PrettyTable +import distutils.util +import numpy as np +import six + +def summary(main_prog): + ''' + It can summary model's PARAMS, FLOPs until now. + It support common operator like conv, fc, pool, relu, sigmoid, bn etc. + Args: + main_prog: main program + Returns: + print summary on terminal + ''' + collected_ops_list = [] + is_quantize = False + for one_b in main_prog.blocks: + block_vars = one_b.vars + for one_op in one_b.ops: + # if str(one_op.type).find('quantize') > -1: + # is_quantize = True + op_info = OrderedDict() + spf_res = _summary_model(block_vars, one_op) + if spf_res is None: + continue + # TODO: get the operator name + op_info['type'] = one_op.type + op_info['input_shape'] = spf_res[0][1:] + op_info['out_shape'] = spf_res[1][1:] + op_info['PARAMs'] = spf_res[2] + op_info['FLOPs'] = spf_res[3] + collected_ops_list.append(op_info) + + + summary_table, total = _format_summary(collected_ops_list) + _print_summary(summary_table, total) + return total, is_quantize + + +def _summary_model(block_vars, one_op): + ''' + Compute operator's params and flops. + Args: + block_vars: all vars of one block + one_op: one operator to count + Returns: + in_data_shape: one operator's input data shape + out_data_shape: one operator's output data shape + params: one operator's PARAMs + flops: : one operator's FLOPs + ''' + if one_op.type in ['conv2d', 'depthwise_conv2d']: + k_arg_shape = block_vars[one_op.input("Filter")[0]].shape + in_data_shape = block_vars[one_op.input("Input")[0]].shape + out_data_shape = block_vars[one_op.output("Output")[0]].shape + c_out, c_in, k_h, k_w = k_arg_shape + _, c_out_, h_out, w_out = out_data_shape + #assert c_out == c_out_, 'shape error!' 
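+        # Per-output-element multiply count for a (grouped) convolution:
+        # each output channel sees c_in / groups input channels through a
+        # k_h x k_w kernel, so kernel_ops = k_h * k_w * c_in / groups.
+        # PARAMs = c_out * (kernel_ops + bias_ops); FLOPs covers both the
+        # multiplies and the adds over the h_out * w_out output map, hence
+        # the doubling below. E.g. a 3x3 conv with c_in = c_out = 64,
+        # groups = 1 on a 56x56 map: 64 * (9 * 64) = 36,864 params and
+        # 2 * 56 * 56 * 36,864 ~= 231 MFLOPs.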
+ k_groups = one_op.attr("groups") + kernel_ops = k_h * k_w * (in_data_shape[1] / k_groups) + try: + bias_ops = 0 if one_op.input("Bias") == [] else 1 + except: + bias_ops = 0 + params = c_out * (kernel_ops + bias_ops) + flops = h_out * w_out * c_out * (kernel_ops + bias_ops) + # base nvidia paper, include mul and add + flops = 2 * flops + if one_op.type == 'depthwise_conv2d': + pass + + # var_name = block_vars[one_op.input("Filter")[0]].name + # if var_name.endswith('.int8'): + # flops /= 2.0 + + elif one_op.type == 'pool2d': + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + _, c_out, h_out, w_out = out_data_shape + k_size = one_op.attr("ksize") + params = 0 + flops = h_out * w_out * c_out * (k_size[0] * k_size[1]) + + elif one_op.type == 'mul': + k_arg_shape = block_vars[one_op.input("Y")[0]].shape + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + # TODO: fc has mul ops + # add attr to mul op, tell us whether it belongs to 'fc' + # this's not the best way + if 'fc' not in one_op.output("Out")[0]: + return None + k_in, k_out = k_arg_shape + # bias in sum op + params = k_in * k_out + 1 + flops = k_in * k_out + + # var_name = block_vars[one_op.input("Y")[0]].name + # if var_name.endswith('.int8'): + # flops /= 2.0 + + elif one_op.type in ['sigmoid', 'tanh', 'relu', 'leaky_relu', 'prelu']: + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Out")[0]].shape + params = 0 + if one_op.type == 'prelu': + params = 1 + flops = 1 + for one_dim in in_data_shape[1:]: + flops *= one_dim + + elif one_op.type == 'batch_norm': + in_data_shape = block_vars[one_op.input("X")[0]].shape + out_data_shape = block_vars[one_op.output("Y")[0]].shape + _, c_in, h_out, w_out = in_data_shape + # gamma, beta + params = c_in * 2 + # compute mean and std + flops = h_out * w_out * c_in * 2 + + else: + return None + + return in_data_shape, out_data_shape, params, flops + + +def _format_summary(collected_ops_list): + ''' + Format summary report. + Args: + collected_ops_list: the collected operator with summary + Returns: + summary_table: summary report format + total: sum param and flops + ''' + summary_table = PrettyTable( + ["No.", "TYPE", "INPUT", "OUTPUT", "PARAMs", "FLOPs"]) + summary_table.align = 'r' + + total = {} + total_params = [] + total_flops = [] + for i, one_op in enumerate(collected_ops_list): + # notice the order + table_row = [ + i, + one_op['type'], + one_op['input_shape'], + one_op['out_shape'], + int(one_op['PARAMs']), + int(one_op['FLOPs']), + ] + summary_table.add_row(table_row) + total_params.append(int(one_op['PARAMs'])) + total_flops.append(int(one_op['FLOPs'])) + + total['params'] = total_params + total['flops'] = total_flops + + return summary_table, total + + +def _print_summary(summary_table, total): + ''' + Print all the summary on terminal. 
+ Args: + summary_table: summary report format + total: sum param and flops + ''' + parmas = total['params'] + flops = total['flops'] + print(summary_table) + print('Total PARAMs: {}({:.4f}M)'.format( + sum(parmas), sum(parmas) / (10.0 ** 6))) + print('Total FLOPs: {}({:.4f}G)'.format(sum(flops), sum(flops) / 10.0 ** 6)) + print('Total MAdds: {}({:.4f}G)'.format(sum(flops)/2, sum(flops) / 10.0 ** 6 / 2)) + print( + "Notice: \n now supported ops include [Conv, DepthwiseConv, FC(mul), BatchNorm, Pool, Activation(sigmoid, tanh, relu, leaky_relu, prelu)]" + ) + + +def get_batch_dt_res(nmsed_out_v, data, contiguous_category_id_to_json_id, batch_size): + dts_res = [] + lod = nmsed_out_v[0].lod()[0] + nmsed_out_v = np.array(nmsed_out_v[0]) + real_batch_size = min(batch_size, len(data)) + assert (len(lod) == real_batch_size + 1), \ + "Error Lod Tensor offset dimension. Lod({}) vs. batch_size({})".format(len(lod), batch_size) + k = 0 + for i in range(real_batch_size): + dt_num_this_img = lod[i + 1] - lod[i] + image_id = int(data[i][4][0]) + image_width = int(data[i][4][1]) + image_height = int(data[i][4][2]) + for j in range(dt_num_this_img): + dt = nmsed_out_v[k] + k = k + 1 + category_id, score, xmin, ymin, xmax, ymax = dt.tolist() + xmin = max(min(xmin, 1.0), 0.0) * image_width + ymin = max(min(ymin, 1.0), 0.0) * image_height + xmax = max(min(xmax, 1.0), 0.0) * image_width + ymax = max(min(ymax, 1.0), 0.0) * image_height + w = xmax - xmin + h = ymax - ymin + bbox = [xmin, ymin, w, h] + dt_res = { + 'image_id': image_id, + 'category_id': contiguous_category_id_to_json_id[category_id], + 'bbox': bbox, + 'score': score + } + dts_res.append(dt_res) + return dts_res diff --git a/demo/slimfacenet/models/slimfacenet.py b/demo/slimfacenet/models/slimfacenet.py new file mode 100644 index 00000000..609458ed --- /dev/null +++ b/demo/slimfacenet/models/slimfacenet.py @@ -0,0 +1,261 @@ +import math +import datetime +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +class SlimFaceNet(): + def __init__(self, class_dim, scale=0.6, arch=None): + + assert arch is not None + self.arch = arch + self.class_dim = class_dim + kernels = [3] + expansions = [2, 4, 6] + SE = [0, 1] + self.table = [] + for k in kernels: + for e in expansions: + for se in SE: + self.table.append((k, e, se)) + + if scale == 1.0: + # 100% - channel + self.Slimfacenet_bottleneck_setting = [ + # t, c , n ,s + [2, 64, 5, 2], + [4, 128, 1, 2], + [2, 128, 6, 1], + [4, 128, 1, 2], + [2, 128, 2, 1] + ] + elif scale == 0.9: + # 90% - channel + self.Slimfacenet_bottleneck_setting = [ + # t, c , n ,s + [2, 56, 5, 2], + [4, 116, 1, 2], + [2, 116, 6, 1], + [4, 116, 1, 2], + [2, 116, 2, 1] + ] + elif scale == 0.75: + # 75% - channel + self.Slimfacenet_bottleneck_setting = [ + # t, c , n ,s + [2, 48, 5, 2], + [4, 96, 1, 2], + [2, 96, 6, 1], + [4, 96, 1, 2], + [2, 96, 2, 1] + ] + elif scale == 0.6: + # 60% - channel + self.Slimfacenet_bottleneck_setting = [ + # t, c , n ,s + [2, 40, 5, 2], + [4, 76, 1, 2], + [2, 76, 6, 1], + [4, 76, 1, 2], + [2, 76, 2, 1] + ] + else: + print('WRONG scale') + exit() + self.extract_feature = True + + def set_extract_feature_flag(self, flag): + self.extract_feature = flag + + def net(self, input, label=None): + x = self.conv_bn_layer(input, filter_size=3, num_filters=64, stride=2, padding=1, num_groups=1, if_act=True, name='conv3x3') + x = self.conv_bn_layer(x, filter_size=3, num_filters=64, stride=1, padding=1, 
num_groups=64, if_act=True, name='dw_conv3x3') + + in_c = 64 + cnt = 0 + for _exp, out_c , times, _stride in self.Slimfacenet_bottleneck_setting: + for i in range(times): + stride = _stride if i==0 else 1 + filter_size, exp, se = self.table[self.arch[cnt]] + se = False if se==0 else True + x = self.residual_unit(x, num_in_filter=in_c, num_out_filter=out_c, stride=stride, filter_size=filter_size, expansion_factor=exp, use_se=se, name='residual_unit'+str(cnt+1)) + cnt += 1 + in_c = out_c + + out_c = 512 + x = self.conv_bn_layer(x, filter_size=1, num_filters=out_c, stride=1, padding=0, num_groups=1, if_act=True, name='conv1x1') + # Replace dw_conv7x7 with dw_conv5x5 + dw_conv3x3 + x = self.conv_bn_layer(x, filter_size=(7,6), num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv7x7') + # x = self.conv_bn_layer(x, filter_size=5, num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv5x5') + # x = self.conv_bn_layer(x, filter_size=3, num_filters=out_c, stride=1, padding=0, num_groups=out_c, if_act=False, name='global_dw_conv3x3') + # 128dim, L2Decay = 4e-4 + x = fluid.layers.conv2d(x, num_filters=128, filter_size=1, stride=1, padding=0, groups=1, act=None, use_cudnn=True, param_attr=ParamAttr(name='linear_conv1x1_weights', initializer=MSRA(), regularizer=fluid.regularizer.L2Decay(4e-4)), bias_attr=False) + bn_name = 'linear_conv1x1_bn' + x = fluid.layers.batch_norm(x, param_attr=ParamAttr(name=bn_name + "_scale"), bias_attr=ParamAttr(name=bn_name + "_offset"), moving_mean_name=bn_name + '_mean', moving_variance_name=bn_name + '_variance') + + x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]]) + + if self.extract_feature: + return x + + out = self.arc_margin_product(x, label, self.class_dim, s = 32.0, m = 0.50, mode = 2) + softmax = fluid.layers.softmax(input=out) + cost = fluid.layers.cross_entropy(input=softmax, label=label) + loss = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=out, label=label, k=1) + return loss, acc + + def residual_unit(self, + input, + num_in_filter, + num_out_filter, + stride, + filter_size, + expansion_factor, + use_se=False, + name=None): + + num_expfilter = int(round(num_in_filter * expansion_factor)) + input_data = input + + expand_conv = self.conv_bn_layer( + input=input, + filter_size=1, + num_filters=num_expfilter, + stride=1, + padding=0, + if_act=True, + name=name + '_expand') + + depthwise_conv = self.conv_bn_layer( + input=expand_conv, + filter_size=filter_size, + num_filters=num_expfilter, + stride=stride, + padding=int((filter_size - 1) // 2), + if_act=True, + num_groups=num_expfilter, + use_cudnn=True, + name=name + '_depthwise') + + if use_se: + depthwise_conv = self.se_block(input=depthwise_conv, num_out_filter=num_expfilter, name=name + '_se') + + linear_conv = self.conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=num_out_filter, + stride=1, + padding=0, + if_act=False, + name=name + '_linear') + if num_in_filter != num_out_filter or stride != 1: + return linear_conv + else: + return fluid.layers.elementwise_add(x=input_data, y=linear_conv, act=None) + + def se_block(self, input, num_out_filter, ratio=4, name=None): + num_mid_filter = int(num_out_filter // ratio) + pool = fluid.layers.pool2d(input=input, pool_type='avg', global_pooling=True, use_cudnn=False) + conv1 = fluid.layers.conv2d( + input=pool, + filter_size=1, + num_filters=num_mid_filter, + act=None, + param_attr=ParamAttr(name=name + '_1_weights'), + 
bias_attr=ParamAttr(name=name + '_1_offset')) + conv1 = fluid.layers.prelu(conv1, mode='channel', param_attr = ParamAttr(name=name + '_prelu', regularizer=fluid.regularizer.L2Decay(0.0))) + conv2 = fluid.layers.conv2d( + input=conv1, + filter_size=1, + num_filters=num_out_filter, + act='hard_sigmoid', + param_attr=ParamAttr(name=name + '_2_weights'), + bias_attr=ParamAttr(name=name + '_2_offset')) + scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0) + return scale + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + if_act=True, + name=None, + use_cudnn=True): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights', initializer=MSRA()), + bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + # print(bn.shape) + if if_act: + return fluid.layers.prelu(bn, mode='channel', param_attr = ParamAttr(name=name + '_prelu', regularizer=fluid.regularizer.L2Decay(0.0))) + else: + return bn + + def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50, mode=2): + input_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(input), dim=1)) + input = fluid.layers.elementwise_div(input, input_norm, axis=0) + + weight = fluid.layers.create_parameter( + shape=[out_dim, input.shape[1]], + dtype='float32', + name='weight_norm', + attr=fluid.param_attr.ParamAttr(initializer=fluid.initializer.Xavier(), regularizer=fluid.regularizer.L2Decay(4e-4))) + + weight_norm = fluid.layers.sqrt(fluid.layers.reduce_sum(fluid.layers.square(weight), dim=1)) + weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0) + weight = fluid.layers.transpose(weight, perm=[1, 0]) + cosine = fluid.layers.mul(input, weight) + sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine)) + + cos_m = math.cos(m) + sin_m = math.sin(m) + phi = cosine * cos_m - sine * sin_m + + th = math.cos(math.pi - m) + mm = math.sin(math.pi - m) * m + + if mode == 1: + phi = self.paddle_where_more_than(cosine, 0, phi, cosine) + elif mode == 2: + phi = self.paddle_where_more_than(cosine, th, phi, cosine - mm) + else: + pass + # print('***** IMPORTANT WARNING *****') + # print('Please determine if phi is correct.') + + one_hot = fluid.layers.one_hot(input=label, depth=out_dim) + output = fluid.layers.elementwise_mul(one_hot, phi) + fluid.layers.elementwise_mul((1.0 - one_hot), cosine) + output = output * s + return output + + def paddle_where_more_than(self, target, limit, x, y): + mask = fluid.layers.cast(x=(target > limit), dtype='float32') + output = fluid.layers.elementwise_mul(mask, x) + fluid.layers.elementwise_mul((1.0 - mask), y) + return output + +if __name__ == "__main__": + x = fluid.layers.data(name='x', shape=[3, 112, 112], dtype='float32') + print(x.shape) + model = SlimFaceNet(10000, [1,3,3,1,1,0,0,1,0,1,1,0,5,5,3]) + y = model.net(x) diff --git a/demo/slimfacenet/slim_eval.sh b/demo/slimfacenet/slim_eval.sh new file mode 100644 index 00000000..e1b4e260 --- /dev/null +++ b/demo/slimfacenet/slim_eval.sh @@ -0,0 +1,14 @@ +# ================================================================ +# Copyright (C) 2020 BAIDU CORPORATION. All rights reserved. 
+#
+# Filename : slim_eval.sh
+# Author   : paddleslim@baidu.com
+# Date     : 2020-05-06
+# Describe : eval the performance of slimfacenet on lfw
+#
+# ================================================================

+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+#export LD_LIBRARY_PATH='PATH to CUDA and CUDNN'
+python eval_infer_model.py --action test
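
For reference, a minimal sketch of re-scoring previously extracted features without rerunning inference, assuming a result.mat written by a `--action test` run of eval_infer_model.py; it relies only on evaluation_10_fold from the lfw_eval.py added above:

# rescore_lfw.py (sketch): reload saved features and recompute fold accuracies
import numpy as np
from lfw_eval import evaluation_10_fold

ACCs = evaluation_10_fold('result.mat')  # one accuracy per LFW fold
for i, acc in enumerate(ACCs):
    print('fold {} {:.2f}'.format(i + 1, acc * 100))
print('AVE {:.2f}'.format(np.mean(ACCs) * 100))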