From 2ead22d06fd62b20b9a46357dfacf9d5ca9ab9fd Mon Sep 17 00:00:00 2001 From: ceci3 Date: Sat, 23 Apr 2022 18:07:44 +0800 Subject: [PATCH] fix ptq demo (#1067) --- demo/auto-compression/README.md | 4 ++-- demo/auto-compression/demo_glue.py | 2 +- demo/auto-compression/demo_imagenet.py | 8 +++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/demo/auto-compression/README.md b/demo/auto-compression/README.md index b61d67d6..e69329ea 100644 --- a/demo/auto-compression/README.md +++ b/demo/auto-compression/README.md @@ -88,7 +88,7 @@ python demo_imagenet.py \ --model_dir='infermodel_mobilenetv2' \ --model_filename='inference.pdmodel' \ --params_filename='./inference.pdiparams' \ - --save_dir='./save_qat_mbv2/' \ + --save_dir='./save_ptq_mbv2/' \ --devices='gpu' \ --batch_size=64 \ --data_dir='../data/ILSVRC2012/' \ @@ -118,7 +118,7 @@ python demo_imagenet.py \ --model_dir='infermodel_mobilenetv2' \ --model_filename='inference.pdmodel' \ --params_filename='./inference.pdiparams' \ - --save_dir='./save_qat_mbv2/' \ + --save_dir='./save_asp_mbv2/' \ --devices='gpu' \ --batch_size=64 \ --data_dir='../data/ILSVRC2012/' \ diff --git a/demo/auto-compression/demo_glue.py b/demo/auto-compression/demo_glue.py index 569bd7cc..a2b0cfde 100644 --- a/demo/auto-compression/demo_glue.py +++ b/demo/auto-compression/demo_glue.py @@ -189,7 +189,7 @@ if __name__ == '__main__': strategy_config=compress_config, train_config=train_config, train_dataloader=train_dataloader, - eval_callback=eval_function, + eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else eval_dataloader, devices=args.devices) ac.compress() diff --git a/demo/auto-compression/demo_imagenet.py b/demo/auto-compression/demo_imagenet.py index 62394810..04543228 100644 --- a/demo/auto-compression/demo_imagenet.py +++ b/demo/auto-compression/demo_imagenet.py @@ -37,10 +37,12 @@ def reader_wrapper(reader): return gen +def eval_reader(data_dir, batch_size): + val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=batch_size) + return val_reader def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): - - val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1) + val_reader = eval_reader(data_dir, batch_size=1) image = paddle.static.data( name='x', shape=[None, 3, 224, 224], dtype='float32') label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') @@ -102,7 +104,7 @@ if __name__ == '__main__': strategy_config=compress_config, train_config=train_config, train_dataloader=train_dataloader, - eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else None, + eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else reader_wrapper(eval_reader(data_dir, 64)), devices=args.devices) ac.compress() -- GitLab