From 539a457f21cd96810476c099dee4ba3177746aa4 Mon Sep 17 00:00:00 2001 From: Teng Xi Date: Mon, 25 May 2020 15:23:57 +0800 Subject: [PATCH] add Paddle 1.8 PaddleSlim 1.1.1 demo (#304) (#306) * add Paddle 1.8 PaddleSlim 1.1.1 demo --- demo/slimfacenet/README.md | 2 ++ demo/slimfacenet/train_eval.py | 11 +++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/demo/slimfacenet/README.md b/demo/slimfacenet/README.md index f91de459..d7e6d273 100755 --- a/demo/slimfacenet/README.md +++ b/demo/slimfacenet/README.md @@ -2,6 +2,8 @@ 本示例将演示如何训练`slimfacenet`及评测`slimfacenet`量化模型。 +本示例依赖:Paddle 1.8 PaddleSlim 1.1.1 + 当前示例支持以下人脸识别模型: - `SlimFaceNet_A_x0_60` diff --git a/demo/slimfacenet/train_eval.py b/demo/slimfacenet/train_eval.py index 421496f0..a779feb3 100644 --- a/demo/slimfacenet/train_eval.py +++ b/demo/slimfacenet/train_eval.py @@ -233,9 +233,7 @@ def quant_val_reader_batch(): test_dataset = LFW(nl, nr) test_reader = paddle.batch( test_dataset.reader, batch_size=1, drop_last=False) - shuffle_index = args.seed if args.seed else np.random.randint(1000) - print('shuffle_index: {}'.format(shuffle_index)) - shuffle_reader = fluid.io.shuffle(test_reader, shuffle_index) + shuffle_reader = fluid.io.shuffle(test_reader, 3) def _reader(): while True: @@ -289,7 +287,6 @@ def main(): '--start_epoch', default=0, type=int, help='start_epoch') parser.add_argument( '--total_epoch', default=80, type=int, help='total_epoch') - parser.add_argument('--seed', default=None, type=int, help='shuffle seed') parser.add_argument( '--save_frequency', default=1, type=int, help='save_frequency') parser.add_argument( @@ -336,8 +333,10 @@ def main(): sample_generator=quant_val_reader_batch(), model_filename=None, #'model', params_filename=None, #'params', - batch_size=100, - batch_nums=10) + save_model_filename=None, #'model', + save_params_filename=None, #'params', + batch_size=np.random.randint(80, 160), + batch_nums=np.random.randint(4, 10)) elif args.action == 'test': [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model( -- GitLab