未验证 提交 3f162619 编写于 作者: T Teng Xi 提交者: GitHub

Update the PATH of CASIA and lfw in FaceNet (#288) (#296)

上级 f881efb1
...@@ -15,5 +15,5 @@ ...@@ -15,5 +15,5 @@
#!/bin/bash #!/bin/bash
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action test \ python train_eval.py --action test \
--train_data_dir=/PATH_TO_CASIA_Dataset \ --train_data_dir=./CASIA/ \
--test_data_dir=/PATH_TO_lfw \ --test_data_dir=./lfw/ \
...@@ -15,5 +15,6 @@ ...@@ -15,5 +15,6 @@
#!/bin/bash #!/bin/bash
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python train_eval.py --action quant \ python train_eval.py --action quant \
--train_data_dir=/PATH_TO_CASIA_Dataset \ --train_data_dir=./CASIA/ \
--test_data_dir=/PATH_TO_lfw \ --test_data_dir=./lfw/ \
--seed=1
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#!/bin/bash #!/bin/bash
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python -u train_eval.py \ python -u train_eval.py \
--train_data_dir=/PATH_TO_CASIA_Dataset \ --train_data_dir=./CASIA/ \
--test_data_dir=/PATH_TO_LFW \ --test_data_dir=./lfw/ \
--action train \ --action train \
--model=SlimFaceNet_B_x0_75 --model=SlimFaceNet_B_x0_75
...@@ -120,6 +120,7 @@ def train(exe, train_program, train_out, test_program, test_out, args): ...@@ -120,6 +120,7 @@ def train(exe, train_program, train_out, test_program, test_out, args):
compiled_prog = compiler.CompiledProgram( compiled_prog = compiler.CompiledProgram(
train_program, build_strategy=build_strategy).with_data_parallel( train_program, build_strategy=build_strategy).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy) loss_name=loss.name, build_strategy=build_strategy)
best_ave = 0
for epoch_id in range(args.start_epoch, args.total_epoch): for epoch_id in range(args.start_epoch, args.total_epoch):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
loss, acc, global_lr = exe.run(compiled_prog, loss, acc, global_lr = exe.run(compiled_prog,
...@@ -135,14 +136,17 @@ def train(exe, train_program, train_out, test_program, test_out, args): ...@@ -135,14 +136,17 @@ def train(exe, train_program, train_out, test_program, test_out, args):
model_path = os.path.join(args.save_ckpt, str(epoch_id)) model_path = os.path.join(args.save_ckpt, str(epoch_id))
fluid.io.save_persistables( fluid.io.save_persistables(
executor=exe, dirname=model_path, main_program=train_program) executor=exe, dirname=model_path, main_program=train_program)
test(exe, test_program, test_out, args) temp_ave = test(exe, test_program, test_out, args)
out_feature, test_reader, flods, flags = test_out if temp_ave > best_ave:
fluid.io.save_inference_model( best_ave = temp_ave
executor=exe, print('Best AVE: {}'.format(best_ave))
dirname='./out_inference', out_feature, test_reader, flods, flags = test_out
feeded_var_names=['image_test'], fluid.io.save_inference_model(
target_vars=[out_feature], executor=exe,
main_program=test_program) dirname='./out_inference',
feeded_var_names=['image_test'],
target_vars=[out_feature],
main_program=test_program)
def build_program(program, startup, args, is_train=True): def build_program(program, startup, args, is_train=True):
...@@ -229,7 +233,9 @@ def quant_val_reader_batch(): ...@@ -229,7 +233,9 @@ def quant_val_reader_batch():
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.batch( test_reader = paddle.batch(
test_dataset.reader, batch_size=1, drop_last=False) test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 1) shuffle_index = args.seed if args.seed else np.random.randint(1000)
print('shuffle_index: {}'.format(shuffle_index))
shuffle_reader = fluid.io.shuffle(test_reader, shuffle_index)
def _reader(): def _reader():
while True: while True:
...@@ -283,6 +289,7 @@ def main(): ...@@ -283,6 +289,7 @@ def main():
'--start_epoch', default=0, type=int, help='start_epoch') '--start_epoch', default=0, type=int, help='start_epoch')
parser.add_argument( parser.add_argument(
'--total_epoch', default=80, type=int, help='total_epoch') '--total_epoch', default=80, type=int, help='total_epoch')
parser.add_argument('--seed', default=None, type=int, help='shuffle seed')
parser.add_argument( parser.add_argument(
'--save_frequency', default=1, type=int, help='save_frequency') '--save_frequency', default=1, type=int, help='save_frequency')
parser.add_argument( parser.add_argument(
......
...@@ -362,7 +362,7 @@ def SlimFaceNet_B_x0_75(class_dim=None, scale=0.6, arch=None): ...@@ -362,7 +362,7 @@ def SlimFaceNet_B_x0_75(class_dim=None, scale=0.6, arch=None):
def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None): def SlimFaceNet_C_x0_75(class_dim=None, scale=0.6, arch=None):
scale = 0.75 scale = 0.75
arch = [1, 1, 2, 1, 0, 2, 1, 0, 1, 0, 1, 1, 2, 2, 3] arch = [1, 3, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 5, 5, 5]
return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch) return SlimFaceNet(class_dim=class_dim, scale=scale, arch=arch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册