未验证 提交 3f162619 编写于 作者: T Teng Xi 提交者: GitHub

Update the PATH of CASIA and lfw in FaceNet (#288) (#296)

上级 f881efb1
......@@ -15,5 +15,5 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action test \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_lfw \
--train_data_dir=./CASIA/ \
--test_data_dir=./lfw/ \
......@@ -15,5 +15,6 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action quant \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_lfw \
--train_data_dir=./CASIA/ \
--test_data_dir=./lfw/ \
--seed=1
......@@ -15,7 +15,7 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python -u train_eval.py \
--train_data_dir=/PATH_TO_CASIA_Dataset \
--test_data_dir=/PATH_TO_LFW \
--train_data_dir=./CASIA/ \
--test_data_dir=./lfw/ \
--action train \
--model=SlimFaceNet_B_x0_75
......@@ -120,6 +120,7 @@ def train(exe, train_program, train_out, test_program, test_out, args):
compiled_prog = compiler.CompiledProgram(
train_program, build_strategy=build_strategy).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
best_ave = 0
for epoch_id in range(args.start_epoch, args.total_epoch):
for batch_id, data in enumerate(train_reader()):
loss, acc, global_lr = exe.run(compiled_prog,
......@@ -135,14 +136,17 @@ def train(exe, train_program, train_out, test_program, test_out, args):
model_path = os.path.join(args.save_ckpt, str(epoch_id))
fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_program)
test(exe, test_program, test_out, args)
out_feature, test_reader, flods, flags = test_out
fluid.io.save_inference_model(
executor=exe,
dirname='./out_inference',
feeded_var_names=['image_test'],
target_vars=[out_feature],
main_program=test_program)
temp_ave = test(exe, test_program, test_out, args)
if temp_ave > best_ave:
best_ave = temp_ave
print('Best AVE: {}'.format(best_ave))
out_feature, test_reader, flods, flags = test_out
fluid.io.save_inference_model(
executor=exe,
dirname='./out_inference',
feeded_var_names=['image_test'],
target_vars=[out_feature],
main_program=test_program)
def build_program(program, startup, args, is_train=True):
......@@ -229,7 +233,9 @@ def quant_val_reader_batch():
test_dataset = LFW(nl, nr)
test_reader = paddle.batch(
test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 1)
shuffle_index = args.seed if args.seed else np.random.randint(1000)
print('shuffle_index: {}'.format(shuffle_index))
shuffle_reader = fluid.io.shuffle(test_reader, shuffle_index)
def _reader():
while True:
......@@ -283,6 +289,7 @@ def main():
'--start_epoch', default=0, type=int, help='start_epoch')
parser.add_argument(
'--total_epoch', default=80, type=int, help='total_epoch')
parser.add_argument('--seed', default=None, type=int, help='shuffle seed')
parser.add_argument(
'--save_frequency', default=1, type=int, help='save_frequency')
parser.add_argument(
......
......@@ -362,7 +362,7 @@ def SlimFaceNet_B_x0_75(class_dim=None, scale=0.6, arch=None):
def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75
arch = [1, 1, 2, 1, 0, 2, 1, 0, 1, 0, 1, 1, 2, 2, 3]
arch = [1, 3, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 5, 5, 5]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册