未验证 提交 86b5f49b 编写于 作者: X xiteng1988 提交者: GitHub

cherry pick 254 284 (#286)

* add slimfacenet to paddleslim/models (#284)

* add slimfacenets (#254)

* add slimfacenet
上级 60544b52
......@@ -4,6 +4,7 @@ from .resnet import ResNet34, ResNet50
from .resnet_vd import ResNet50_vd, ResNet101_vd
from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2
from .pvanet import PVANet
from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75
__all__ = [
"model_list", "MobileNet", "ResNet34", "ResNet50", "MobileNetV2", "PVANet",
"ResNet50_vd", "ResNet101_vd", "MobileNetV2_x0_25"
......
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import datetime
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
class SlimFaceNet():
def __init__(self, class_dim, scale=0.6, arch=None):
assert arch is not None
self.arch = arch
self.class_dim = class_dim
kernels = [3]
expansions = [2, 4, 6]
SE = [0, 1]
self.table = []
for k in kernels:
for e in expansions:
for se in SE:
self.table.append((k, e, se))
if scale == 1.0:
# 100% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 64, 5, 2],
[4, 128, 1, 2],
[2, 128, 6, 1],
[4, 128, 1, 2],
[2, 128, 2, 1]
]
elif scale == 0.9:
# 90% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 56, 5, 2],
[4, 116, 1, 2],
[2, 116, 6, 1],
[4, 116, 1, 2],
[2, 116, 2, 1]
]
elif scale == 0.75:
# 75% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 48, 5, 2],
[4, 96, 1, 2],
[2, 96, 6, 1],
[4, 96, 1, 2],
[2, 96, 2, 1]
]
elif scale == 0.6:
# 60% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 40, 5, 2],
[4, 76, 1, 2],
[2, 76, 6, 1],
[4, 76, 1, 2],
[2, 76, 2, 1]
]
else:
print('WRONG scale')
exit()
self.extract_feature = True
def set_extract_feature_flag(self, flag):
self.extract_feature = flag
def net(self, input, label=None):
x = self.conv_bn_layer(
input,
filter_size=3,
num_filters=64,
stride=2,
padding=1,
num_groups=1,
if_act=True,
name='conv3x3')
x = self.conv_bn_layer(
x,
filter_size=3,
num_filters=64,
stride=1,
padding=1,
num_groups=64,
if_act=True,
name='dw_conv3x3')
in_c = 64
cnt = 0
for _exp, out_c, times, _stride in self.Slimfacenet_bottleneck_setting:
for i in range(times):
stride = _stride if i == 0 else 1
filter_size, exp, se = self.table[self.arch[cnt]]
se = False if se == 0 else True
x = self.residual_unit(
x,
num_in_filter=in_c,
num_out_filter=out_c,
stride=stride,
filter_size=filter_size,
expansion_factor=exp,
use_se=se,
name='residual_unit' + str(cnt + 1))
cnt += 1
in_c = out_c
out_c = 512
x = self.conv_bn_layer(
x,
filter_size=1,
num_filters=out_c,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='conv1x1')
x = self.conv_bn_layer(
x,
filter_size=(7, 6),
num_filters=out_c,
stride=1,
padding=0,
num_groups=out_c,
if_act=False,
name='global_dw_conv7x7')
x = fluid.layers.conv2d(
x,
num_filters=128,
filter_size=1,
stride=1,
padding=0,
groups=1,
act=None,
use_cudnn=True,
param_attr=ParamAttr(
name='linear_conv1x1_weights',
initializer=MSRA(),
regularizer=fluid.regularizer.L2Decay(4e-4)),
bias_attr=False)
bn_name = 'linear_conv1x1_bn'
x = fluid.layers.batch_norm(
x,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
if self.extract_feature:
return x
out = self.arc_margin_product(
x, label, self.class_dim, s=32.0, m=0.50, mode=2)
softmax = fluid.layers.softmax(input=out)
cost = fluid.layers.cross_entropy(input=softmax, label=label)
loss = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=out, label=label, k=1)
return loss, acc
def residual_unit(self,
input,
num_in_filter,
num_out_filter,
stride,
filter_size,
expansion_factor,
use_se=False,
name=None):
num_expfilter = int(round(num_in_filter * expansion_factor))
input_data = input
expand_conv = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_expfilter,
stride=1,
padding=0,
if_act=True,
name=name + '_expand')
depthwise_conv = self.conv_bn_layer(
input=expand_conv,
filter_size=filter_size,
num_filters=num_expfilter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
num_groups=num_expfilter,
use_cudnn=True,
name=name + '_depthwise')
if use_se:
depthwise_conv = self.se_block(
input=depthwise_conv,
num_out_filter=num_expfilter,
name=name + '_se')
linear_conv = self.conv_bn_layer(
input=depthwise_conv,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear')
if num_in_filter != num_out_filter or stride != 1:
return linear_conv
else:
return fluid.layers.elementwise_add(
x=input_data, y=linear_conv, act=None)
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act=None,
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv1 = fluid.layers.prelu(
conv1,
mode='channel',
param_attr=ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
name=name + '_weights', initializer=MSRA()),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.prelu(
bn,
mode='channel',
param_attr=ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
else:
return bn
def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50,
mode=2):
input_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
weight = fluid.layers.create_parameter(
shape=[out_dim, input.shape[1]],
dtype='float32',
name='weight_norm',
attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(),
regularizer=fluid.regularizer.L2Decay(4e-4)))
weight_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(weight), dim=1))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0)
weight = fluid.layers.transpose(weight, perm=[1, 0])
cosine = fluid.layers.mul(input, weight)
sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine))
cos_m = math.cos(m)
sin_m = math.sin(m)
phi = cosine * cos_m - sine * sin_m
th = math.cos(math.pi - m)
mm = math.sin(math.pi - m) * m
if mode == 1:
phi = self.paddle_where_more_than(cosine, 0, phi, cosine)
elif mode == 2:
phi = self.paddle_where_more_than(cosine, th, phi, cosine - mm)
else:
pass
one_hot = fluid.layers.one_hot(input=label, depth=out_dim)
output = fluid.layers.elementwise_mul(
one_hot, phi) + fluid.layers.elementwise_mul(
(1.0 - one_hot), cosine)
output = output * s
return output
def paddle_where_more_than(self, target, limit, x, y):
mask = fluid.layers.cast(x=(target > limit), dtype='float32')
output = fluid.layers.elementwise_mul(
mask, x) + fluid.layers.elementwise_mul((1.0 - mask), y)
return output
def SlimFaceNet_A_x0_60(class_dim=None, scale=0.6, arch=None):
scale = 0.6
arch = [0, 1, 5, 1, 0, 2, 1, 2, 0, 1, 2, 1, 1, 0, 1]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
def SlimFaceNet_B_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75
arch = [1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 3, 2, 2, 3]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75
arch = [1, 1, 2, 1, 0, 2, 1, 0, 1, 0, 1, 1, 2, 2, 3]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
if __name__ == "__main__":
x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32')
print(x.shape)
model = SlimFaceNet(10000, [1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3])
y = model.net(x)
# slimfacenet使用示例
本示例将演示如何训练`slimfacenet`及评测`slimfacenet`量化模型。
当前示例支持以下人脸识别模型:
- `SlimFaceNet_A_x0_60`
- `SlimFaceNet_B_x0_75`
## 1. 数据准备
本示例支持`CASIA``lfw`两种公开数据集默认情况:
1). 训练数据集位置`./CASIA`
2). 测试数据集位置`./lfw`
## 2. 下载预训练模型
如果使用预先训练并量化好的`slimfacenet`模型,可以从以下地址下载
## 3. 启动`slimfacenet`训练任务
通过以下命令启动训练任务:
```
sh slim_train.sh
或者
export CUDA_VISIBLE_DEVICES=0
python -u train_eval.py \
--action train \
--model=SlimFaceNet_B_x0_75
```
其中,SlimFaceNet_A_x0_60是`slimfacenet`搜索空间中的一个模型结构,通道数的缩放系数为0.6,
在每个缩放系数下搜索空间中都共有6**15(约4700亿)种不同的模型结构。模型训练好之后会保存在`./out_inference/`
## 4. 将float32模型量化为int8模型
通过以下命令启动训练任务:
```
sh slim_quant.sh
或者
export CUDA_VISIBLE_DEVICES=0
python -u train_eval.py --action quant
```
执行完之后量化模型会保存在`./quant_model/`, 注当前阶段量化模型还是是按float32保存的,转paddlelite后会变为int8
## 4. 加载和评估量化模型
本节介绍如何加载并评测预先训练好并量化后的模型。
执行以下代码加载模型并评估模型在测试集上的指标。
```
将量化模型默认地址在`./quant_model/`
sh slim_eval.sh
或者
export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action test
```
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.misc
import os
import paddle
from paddle import fluid
class CASIA_Face(object):
def __init__(self, root):
self.root = root
img_txt_dir = os.path.join(root, 'CASIA-WebFace-112X96.txt')
image_list = []
label_list = []
with open(img_txt_dir) as f:
img_label_list = f.read().splitlines()
for info in img_label_list:
image_dir, label_name = info.split(' ')
image_list.append(
os.path.join(root, 'CASIA-WebFace-112X96', image_dir))
label_list.append(int(label_name))
self.image_list = image_list
self.label_list = label_list
self.class_nums = len(np.unique(self.label_list))
self.shuffle_idx = list(
np.random.choice(
len(self.image_list), len(self.image_list), False))
def reader(self):
while True:
if len(self.shuffle_idx) == 0:
self.shuffle_idx = list(
np.random.choice(
len(self.image_list), len(self.image_list), False))
return
index = self.shuffle_idx.pop()
img_path = self.image_list[index]
target = self.label_list[index]
try:
img = scipy.misc.imread(img_path)
except:
continue
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
flip = np.random.choice(2) * 2 - 1
img = img[:, ::flip, :]
img = (img - 127.5) / 128.0
img = img.transpose(2, 0, 1)
yield img, target
def __len__(self):
return len(self.image_list)
if __name__ == '__main__':
data_dir = 'PATH to CASIA dataset'
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
dataset = CASIA_Face(root=data_dir)
print(len(dataset))
print(dataset.class_nums)
trainloader = paddle.batch(
dataset.reader, batch_size=1, drop_last=False)
for i in range(10):
for data in trainloader():
img = np.array([x[0] for x in data]).astype('float32')
img = fluid.dygraph.to_variable(img)
print(img.shape)
label = np.array([x[1] for x in data]).astype('int64').reshape(
-1, 1)
label = fluid.dygraph.to_variable(label)
print(label.shape)
print(len(dataset))
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.misc
import paddle
from paddle import fluid
class LFW(object):
def __init__(self, imgl, imgr):
self.imgl_list = imgl
self.imgr_list = imgr
self.shuffle_idx = [i for i in range(len(self.imgl_list))]
def reader(self):
while True:
if len(self.shuffle_idx) == 0:
self.shuffle_idx = [i for i in range(len(self.imgl_list))]
return
index = self.shuffle_idx.pop(0)
imgl = scipy.misc.imread(self.imgl_list[index])
if len(imgl.shape) == 2:
imgl = np.stack([imgl] * 3, 2)
imgr = scipy.misc.imread(self.imgr_list[index])
if len(imgr.shape) == 2:
imgr = np.stack([imgr] * 3, 2)
imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]]
for i in range(len(imglist)):
imglist[i] = (imglist[i] - 127.5) / 128.0
imglist[i] = imglist[i].transpose(2, 0, 1)
imgs = [img.astype('float32') for img in imglist]
yield imgs
def __len__(self):
return len(self.imgl_list)
if __name__ == '__main__':
pass
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import time
import scipy.io
import numpy as np
import paddle
from paddle import fluid
from dataloader.casia import CASIA_Face
from dataloader.lfw import LFW
from paddleslim import models
def parse_filelist(root):
with open(os.path.join(root, 'pairs.txt')) as f:
pairs = f.read().splitlines()[1:]
folder_name = 'lfw-112X96'
nameLs = []
nameRs = []
folds = []
flags = []
for i, p in enumerate(pairs):
p = p.split('\t')
if len(p) == 3:
nameL = os.path.join(root, folder_name, p[0],
p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
nameR = os.path.join(root, folder_name, p[0],
p[0] + '_' + '{:04}.jpg'.format(int(p[2])))
fold = i // 600
flag = 1
elif len(p) == 4:
nameL = os.path.join(root, folder_name, p[0],
p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
nameR = os.path.join(root, folder_name, p[2],
p[2] + '_' + '{:04}.jpg'.format(int(p[3])))
fold = i // 600
flag = -1
nameLs.append(nameL)
nameRs.append(nameR)
folds.append(fold)
flags.append(flag)
return [nameLs, nameRs, folds, flags]
def get_accuracy(scores, flags, threshold):
p = np.sum(scores[flags == 1] > threshold)
n = np.sum(scores[flags == -1] < threshold)
return 1.0 * (p + n) / len(scores)
def get_threshold(scores, flags, thrNum):
accuracys = np.zeros((2 * thrNum + 1, 1))
thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
for i in range(2 * thrNum + 1):
accuracys[i] = get_accuracy(scores, flags, thresholds[i])
max_index = np.squeeze(accuracys == np.max(accuracys))
bestThreshold = np.mean(thresholds[max_index])
return bestThreshold
def evaluation_10_fold(root='result.mat'):
ACCs = np.zeros(10)
result = scipy.io.loadmat(root)
for i in range(10):
fold = result['fold']
flags = result['flag']
featureLs = result['fl']
featureRs = result['fr']
valFold = fold != i
testFold = fold == i
flags = np.squeeze(flags)
mu = np.mean(
np.concatenate(
(featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
mu = np.expand_dims(mu, 0)
featureLs = featureLs - mu
featureRs = featureRs - mu
featureLs = featureLs / np.expand_dims(
np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
featureRs = featureRs / np.expand_dims(
np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)
scores = np.sum(np.multiply(featureLs, featureRs), 1)
threshold = get_threshold(scores[valFold[0]], flags[valFold[0]], 10000)
ACCs[i] = get_accuracy(scores[testFold[0]], flags[testFold[0]],
threshold)
return ACCs
def test(test_reader, flods, flags, net, args):
net.eval()
featureLs = None
featureRs = None
for idx, data in enumerate(test_reader()):
data_list = [[] for _ in range(4)]
for _ in range(len(data)):
data_list[0].append(data[_][0])
data_list[1].append(data[_][1])
data_list[2].append(data[_][2])
data_list[3].append(data[_][3])
res = [
net(fluid.dygraph.to_variable(np.array(d))).numpy()
for d in data_list
]
featureL = np.concatenate((res[0], res[1]), 1)
featureR = np.concatenate((res[2], res[3]), 1)
if featureLs is None:
featureLs = featureL
else:
featureLs = np.concatenate((featureLs, featureL), 0)
if featureRs is None:
featureRs = featureR
else:
featureRs = np.concatenate((featureRs, featureR), 0)
result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags}
scipy.io.savemat(args.feature_save_dir, result)
ACCs = evaluation_10_fold(args.feature_save_dir)
for i in range(len(ACCs)):
print('{} {:.2f}'.format(i + 1, ACCs[i] * 100))
print('--------')
print('AVE {:.2f}'.format(np.mean(ACCs) * 100))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet')
parser.add_argument(
'--use_gpu', default=0, type=int, help='Use GPU or not, 0 is not used')
parser.add_argument(
'--test_data_dir', default='./lfw', type=str, help='lfw_data_dir')
parser.add_argument(
'--resume', default='output/0', type=str, help='resume')
parser.add_argument(
'--feature_save_dir',
default='result.mat',
type=str,
help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
place = fluid.CPUPlace() if args.use_gpu == 0 else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
train_dataset = CASIA_Face(root=args.train_data_dir)
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
net = models.__dict__[args.model](class_dim=train_dataset.class_nums)
if args.resume:
assert os.path.exists(args.resume + ".pdparams"
), "Given dir {}.pdparams not exist.".format(
args.resume)
para_dict, opti_dict = fluid.dygraph.load_dygraph(args.resume)
net.set_dict(para_dict)
test(test_reader, flods, flags, net, args)
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action test \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_lfw \
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action quant \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_lfw \
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python -u train_eval.py \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_LFW \
--action train \
--model=SlimFaceNet_B_x0_75
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
import argparse
import time
import scipy.io
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
from dataloader.casia import CASIA_Face
from dataloader.lfw import LFW
from lfw_eval import parse_filelist, evaluation_10_fold
from paddleslim import models
from paddleslim.quant import quant_post
def now():
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
def creat_optimizer(args, trainset_scale):
start_step = trainset_scale * args.start_epoch // args.train_batchsize
if args.lr_strategy == 'piecewise_decay':
bd = [
trainset_scale * int(e) // args.train_batchsize
for e in args.lr_steps.strip().split(',')
]
lr = [float(e) for e in args.lr_list.strip().split(',')]
assert len(bd) == len(lr) - 1
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.piecewise_decay(
boundaries=bd, values=lr),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
elif args.lr_strategy == 'cosine_decay':
lr = args.lr
step_each_epoch = trainset_scale // args.train_batchsize
optimizer = fluid.optimizer.Momentum(
learning_rate=fluid.layers.cosine_decay(lr, step_each_epoch,
args.total_epoch),
momentum=0.9,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
else:
print('Wrong learning rate strategy')
exit()
return optimizer
def test(test_exe, test_program, test_out, args):
featureLs = None
featureRs = None
out_feature, test_reader, flods, flags = test_out
for idx, data in enumerate(test_reader()):
res = []
res.append(
test_exe.run(test_program,
feed={u'image_test': data[0][u'image_test1']},
fetch_list=out_feature))
res.append(
test_exe.run(test_program,
feed={u'image_test': data[0][u'image_test2']},
fetch_list=out_feature))
res.append(
test_exe.run(test_program,
feed={u'image_test': data[0][u'image_test3']},
fetch_list=out_feature))
res.append(
test_exe.run(test_program,
feed={u'image_test': data[0][u'image_test4']},
fetch_list=out_feature))
featureL = np.concatenate((res[0][0], res[1][0]), 1)
featureR = np.concatenate((res[2][0], res[3][0]), 1)
if featureLs is None:
featureLs = featureL
else:
featureLs = np.concatenate((featureLs, featureL), 0)
if featureRs is None:
featureRs = featureR
else:
featureRs = np.concatenate((featureRs, featureR), 0)
result = {'fl': featureLs, 'fr': featureRs, 'fold': flods, 'flag': flags}
scipy.io.savemat(args.feature_save_dir, result)
ACCs = evaluation_10_fold(args.feature_save_dir)
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
f.writelines('eval model {}\n'.format(args.model))
for i in range(len(ACCs)):
print('{} {}'.format(i + 1, ACCs[i] * 100))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
f.writelines('{} {}\n'.format(i + 1, ACCs[i] * 100))
print('--------')
print('AVE {}'.format(np.mean(ACCs) * 100))
with open(os.path.join(args.save_ckpt, 'log.txt'), 'a+') as f:
f.writelines('--------\n')
f.writelines('AVE {}\n'.format(np.mean(ACCs) * 100))
return np.mean(ACCs) * 100
def train(exe, train_program, train_out, test_program, test_out, args):
loss, acc, global_lr, train_reader = train_out
fetch_list_train = [loss.name, acc.name, global_lr.name]
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_optimizer_ops = True
compiled_prog = compiler.CompiledProgram(
train_program, build_strategy=build_strategy).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
for epoch_id in range(args.start_epoch, args.total_epoch):
for batch_id, data in enumerate(train_reader()):
loss, acc, global_lr = exe.run(compiled_prog,
feed=data,
fetch_list=fetch_list_train)
avg_loss = np.mean(np.array(loss))
avg_acc = np.mean(np.array(acc))
print(
'{} Epoch: {:^4d} step: {:^4d} loss: {:.6f}, acc: {:.6f}, lr: {}'.
format(now(), epoch_id, batch_id, avg_loss, avg_acc,
float(np.mean(np.array(global_lr)))))
if batch_id % args.save_frequency == 0:
model_path = os.path.join(args.save_ckpt, str(epoch_id))
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_program)
test(exe, test_program, test_out, args)
out_feature, test_reader, flods, flags = test_out
fluid.io.save_inference_model(
executor=exe,
dirname='./out_inference',
feeded_var_names=['image_test'],
target_vars=[out_feature],
main_program=test_program)
def build_program(program, startup, args, is_train=True):
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir)
trainset_scale = len(train_dataset)
with fluid.program_guard(main_program=program, startup_program=startup):
with fluid.unique_name.guard():
# Model construction
model = models.__dict__[args.model](
class_dim=train_dataset.class_nums)
if is_train:
image = fluid.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.batch(
train_dataset.reader,
batch_size=args.train_batchsize // num_trainers,
drop_last=False)
reader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=64,
iterable=True,
return_list=False)
reader.set_sample_list_generator(train_reader, places=places)
model.extract_feature = False
loss, acc = model.net(image, label)
optimizer = creat_optimizer(args, trainset_scale)
optimizer.minimize(loss)
global_lr = optimizer._global_learning_rate()
out = (loss, acc, global_lr, reader)
else:
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
image_test = fluid.data(
name='image_test', shape=[-1, 3, 112, 96], dtype='float32')
image_test1 = fluid.data(
name='image_test1',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test2 = fluid.data(
name='image_test2',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test3 = fluid.data(
name='image_test3',
shape=[-1, 3, 112, 96],
dtype='float32')
image_test4 = fluid.data(
name='image_test4',
shape=[-1, 3, 112, 96],
dtype='float32')
reader = fluid.io.DataLoader.from_generator(
feed_list=[
image_test1, image_test2, image_test3, image_test4
],
capacity=64,
iterable=True,
return_list=False)
reader.set_sample_list_generator(
test_reader,
places=fluid.cuda_places()
if args.use_gpu else fluid.CPUPlace())
model.extract_feature = True
feature = model.net(image_test)
out = (feature, reader, flods, flags)
return out
def quant_val_reader_batch():
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(
test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 1)
def _reader():
while True:
for idx, data in enumerate(shuffle_reader()):
yield np.expand_dims(data[0][0], axis=0)
return _reader
def main():
global args
parser = argparse.ArgumentParser(description='PaddlePaddle SlimFaceNet')
parser.add_argument(
'--action', default='train', type=str, help='train/test/quant')
parser.add_argument(
'--model',
default='SlimFaceNet_B_x0_75',
type=str,
help='SlimFaceNet_B_x0_75/SlimFaceNet_C_x0_75/SlimFaceNet_A_x0_60')
parser.add_argument(
'--use_gpu', default=1, type=int, help='Use GPU or not, 0 is not used')
parser.add_argument(
'--lr_strategy',
default='piecewise_decay',
type=str,
help='lr_strategy')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument(
'--lr_list',
default='0.1,0.01,0.001,0.0001',
type=str,
help='learning rate list (piecewise_decay)')
parser.add_argument(
'--lr_steps',
default='36,52,58',
type=str,
help='learning rate decay at which epochs')
parser.add_argument(
'--l2_decay', default=4e-5, type=float, help='base l2_decay')
parser.add_argument(
'--train_data_dir', default='./CASIA', type=str, help='train_data_dir')
parser.add_argument(
'--test_data_dir', default='./lfw', type=str, help='lfw_data_dir')
parser.add_argument(
'--train_batchsize', default=512, type=int, help='train_batchsize')
parser.add_argument(
'--test_batchsize', default=500, type=int, help='test_batchsize')
parser.add_argument(
'--img_shape', default='3,112,96', type=str, help='img_shape')
parser.add_argument(
'--start_epoch', default=0, type=int, help='start_epoch')
parser.add_argument(
'--total_epoch', default=80, type=int, help='total_epoch')
parser.add_argument(
'--save_frequency', default=1, type=int, help='save_frequency')
parser.add_argument(
'--save_ckpt', default='output', type=str, help='save_ckpt')
parser.add_argument(
'--feature_save_dir',
default='result.mat',
type=str,
help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
print(args)
print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None:
args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt])
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
if args.action == 'train':
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
if args.action == 'train':
train_out = build_program(train_program, startup_program, args, True)
test_out = build_program(test_program, startup_program, args, False)
test_program = test_program.clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
if args.action == 'train':
train(exe, train_program, train_out, test_program, test_out, args)
elif args.action == 'quant':
quant_post(
executor=exe,
model_dir='./out_inference/',
quantize_model_path='./quant_model/',
sample_generator=quant_val_reader_batch(),
model_filename=None, #'model',
params_filename=None, #'params',
batch_size=100,
batch_nums=10)
elif args.action == 'test':
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
dirname='./quant_model/',
model_filename=None,
params_filename=None,
executor=exe)
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
image_test = fluid.data(
name='image_test', shape=[-1, 3, 112, 96], dtype='float32')
image_test1 = fluid.data(
name='image_test1', shape=[-1, 3, 112, 96], dtype='float32')
image_test2 = fluid.data(
name='image_test2', shape=[-1, 3, 112, 96], dtype='float32')
image_test3 = fluid.data(
name='image_test3', shape=[-1, 3, 112, 96], dtype='float32')
image_test4 = fluid.data(
name='image_test4', shape=[-1, 3, 112, 96], dtype='float32')
reader = fluid.io.DataLoader.from_generator(
feed_list=[image_test1, image_test2, image_test3, image_test4],
capacity=64,
iterable=True,
return_list=False)
reader.set_sample_list_generator(
test_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.CPUPlace())
test_out = (fetch_targets, reader, flods, flags)
print('fetch_targets[0]: ', fetch_targets[0])
print('feed_target_names: ', feed_target_names)
test(exe, inference_program, test_out, args)
else:
print('WRONG ACTION')
if __name__ == '__main__':
main()
......@@ -14,5 +14,6 @@
from __future__ import absolute_import
from .util import image_classification
from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75
__all__ = ["image_classification"]
# ================================================================
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import datetime
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
class SlimFaceNet():
def __init__(self, class_dim, scale=0.6, arch=None):
assert arch is not None
self.arch = arch
self.class_dim = class_dim
kernels = [3]
expansions = [2, 4, 6]
SE = [0, 1]
self.table = []
for k in kernels:
for e in expansions:
for se in SE:
self.table.append((k, e, se))
if scale == 1.0:
# 100% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 64, 5, 2],
[4, 128, 1, 2],
[2, 128, 6, 1],
[4, 128, 1, 2],
[2, 128, 2, 1]
]
elif scale == 0.9:
# 90% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 56, 5, 2],
[4, 116, 1, 2],
[2, 116, 6, 1],
[4, 116, 1, 2],
[2, 116, 2, 1]
]
elif scale == 0.75:
# 75% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 48, 5, 2],
[4, 96, 1, 2],
[2, 96, 6, 1],
[4, 96, 1, 2],
[2, 96, 2, 1]
]
elif scale == 0.6:
# 60% - channel
self.Slimfacenet_bottleneck_setting = [
# t, c , n ,s
[2, 40, 5, 2],
[4, 76, 1, 2],
[2, 76, 6, 1],
[4, 76, 1, 2],
[2, 76, 2, 1]
]
else:
print('WRONG scale')
exit()
self.extract_feature = True
def set_extract_feature_flag(self, flag):
self.extract_feature = flag
def net(self, input, label=None):
x = self.conv_bn_layer(
input,
filter_size=3,
num_filters=64,
stride=2,
padding=1,
num_groups=1,
if_act=True,
name='conv3x3')
x = self.conv_bn_layer(
x,
filter_size=3,
num_filters=64,
stride=1,
padding=1,
num_groups=64,
if_act=True,
name='dw_conv3x3')
in_c = 64
cnt = 0
for _exp, out_c, times, _stride in self.Slimfacenet_bottleneck_setting:
for i in range(times):
stride = _stride if i == 0 else 1
filter_size, exp, se = self.table[self.arch[cnt]]
se = False if se == 0 else True
x = self.residual_unit(
x,
num_in_filter=in_c,
num_out_filter=out_c,
stride=stride,
filter_size=filter_size,
expansion_factor=exp,
use_se=se,
name='residual_unit' + str(cnt + 1))
cnt += 1
in_c = out_c
out_c = 512
x = self.conv_bn_layer(
x,
filter_size=1,
num_filters=out_c,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name='conv1x1')
x = self.conv_bn_layer(
x,
filter_size=(7, 6),
num_filters=out_c,
stride=1,
padding=0,
num_groups=out_c,
if_act=False,
name='global_dw_conv7x7')
x = fluid.layers.conv2d(
x,
num_filters=128,
filter_size=1,
stride=1,
padding=0,
groups=1,
act=None,
use_cudnn=True,
param_attr=ParamAttr(
name='linear_conv1x1_weights',
initializer=MSRA(),
regularizer=fluid.regularizer.L2Decay(4e-4)),
bias_attr=False)
bn_name = 'linear_conv1x1_bn'
x = fluid.layers.batch_norm(
x,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
if self.extract_feature:
return x
out = self.arc_margin_product(
x, label, self.class_dim, s=32.0, m=0.50, mode=2)
softmax = fluid.layers.softmax(input=out)
cost = fluid.layers.cross_entropy(input=softmax, label=label)
loss = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=out, label=label, k=1)
return loss, acc
def residual_unit(self,
input,
num_in_filter,
num_out_filter,
stride,
filter_size,
expansion_factor,
use_se=False,
name=None):
num_expfilter = int(round(num_in_filter * expansion_factor))
input_data = input
expand_conv = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_expfilter,
stride=1,
padding=0,
if_act=True,
name=name + '_expand')
depthwise_conv = self.conv_bn_layer(
input=expand_conv,
filter_size=filter_size,
num_filters=num_expfilter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
num_groups=num_expfilter,
use_cudnn=True,
name=name + '_depthwise')
if use_se:
depthwise_conv = self.se_block(
input=depthwise_conv,
num_out_filter=num_expfilter,
name=name + '_se')
linear_conv = self.conv_bn_layer(
input=depthwise_conv,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear')
if num_in_filter != num_out_filter or stride != 1:
return linear_conv
else:
return fluid.layers.elementwise_add(
x=input_data, y=linear_conv, act=None)
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = int(num_out_filter // ratio)
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act=None,
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv1 = fluid.layers.prelu(
conv1,
mode='channel',
param_attr=ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
name=None,
use_cudnn=True):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(
name=name + '_weights', initializer=MSRA()),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.prelu(
bn,
mode='channel',
param_attr=ParamAttr(
name=name + '_prelu',
regularizer=fluid.regularizer.L2Decay(0.0)))
else:
return bn
def arc_margin_product(self, input, label, out_dim, s=32.0, m=0.50,
mode=2):
input_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(input), dim=1))
input = fluid.layers.elementwise_div(input, input_norm, axis=0)
weight = fluid.layers.create_parameter(
shape=[out_dim, input.shape[1]],
dtype='float32',
name='weight_norm',
attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Xavier(),
regularizer=fluid.regularizer.L2Decay(4e-4)))
weight_norm = fluid.layers.sqrt(
fluid.layers.reduce_sum(
fluid.layers.square(weight), dim=1))
weight = fluid.layers.elementwise_div(weight, weight_norm, axis=0)
weight = fluid.layers.transpose(weight, perm=[1, 0])
cosine = fluid.layers.mul(input, weight)
sine = fluid.layers.sqrt(1.0 - fluid.layers.square(cosine))
cos_m = math.cos(m)
sin_m = math.sin(m)
phi = cosine * cos_m - sine * sin_m
th = math.cos(math.pi - m)
mm = math.sin(math.pi - m) * m
if mode == 1:
phi = self.paddle_where_more_than(cosine, 0, phi, cosine)
elif mode == 2:
phi = self.paddle_where_more_than(cosine, th, phi, cosine - mm)
else:
pass
one_hot = fluid.layers.one_hot(input=label, depth=out_dim)
output = fluid.layers.elementwise_mul(
one_hot, phi) + fluid.layers.elementwise_mul(
(1.0 - one_hot), cosine)
output = output * s
return output
def paddle_where_more_than(self, target, limit, x, y):
mask = fluid.layers.cast(x=(target > limit), dtype='float32')
output = fluid.layers.elementwise_mul(
mask, x) + fluid.layers.elementwise_mul((1.0 - mask), y)
return output
def SlimFaceNet_A_x0_60(class_dim=None, scale=0.6, arch=None):
scale = 0.6
arch = [0, 1, 5, 1, 0, 2, 1, 2, 0, 1, 2, 1, 1, 0, 1]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
def SlimFaceNet_B_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75
arch = [1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 3, 2, 2, 3]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75
arch = [1, 1, 2, 1, 0, 2, 1, 0, 1, 0, 1, 1, 2, 2, 3]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
if __name__ == "__main__":
x = fluid.data(name='x', shape=[-1, 3, 112, 112], dtype='float32')
print(x.shape)
model = SlimFaceNet(10000, [1, 3, 3, 1, 1, 0, 0, 1, 0, 1, 1, 0, 5, 5, 3])
y = model.net(x)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册