未验证 提交 2ead22d0 编写于 作者: C ceci3 提交者: GitHub

fix ptq demo (#1067)

上级 6b8b683e
...@@ -88,7 +88,7 @@ python demo_imagenet.py \ ...@@ -88,7 +88,7 @@ python demo_imagenet.py \
--model_dir='infermodel_mobilenetv2' \ --model_dir='infermodel_mobilenetv2' \
--model_filename='inference.pdmodel' \ --model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \ --params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \ --save_dir='./save_ptq_mbv2/' \
--devices='gpu' \ --devices='gpu' \
--batch_size=64 \ --batch_size=64 \
--data_dir='../data/ILSVRC2012/' \ --data_dir='../data/ILSVRC2012/' \
...@@ -118,7 +118,7 @@ python demo_imagenet.py \ ...@@ -118,7 +118,7 @@ python demo_imagenet.py \
--model_dir='infermodel_mobilenetv2' \ --model_dir='infermodel_mobilenetv2' \
--model_filename='inference.pdmodel' \ --model_filename='inference.pdmodel' \
--params_filename='./inference.pdiparams' \ --params_filename='./inference.pdiparams' \
--save_dir='./save_qat_mbv2/' \ --save_dir='./save_asp_mbv2/' \
--devices='gpu' \ --devices='gpu' \
--batch_size=64 \ --batch_size=64 \
--data_dir='../data/ILSVRC2012/' \ --data_dir='../data/ILSVRC2012/' \
......
...@@ -189,7 +189,7 @@ if __name__ == '__main__': ...@@ -189,7 +189,7 @@ if __name__ == '__main__':
strategy_config=compress_config, strategy_config=compress_config,
train_config=train_config, train_config=train_config,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_callback=eval_function, eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else eval_dataloader,
devices=args.devices) devices=args.devices)
ac.compress() ac.compress()
...@@ -37,10 +37,12 @@ def reader_wrapper(reader): ...@@ -37,10 +37,12 @@ def reader_wrapper(reader):
return gen return gen
def eval_reader(data_dir, batch_size):
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=batch_size)
return val_reader
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = eval_reader(data_dir, batch_size=1)
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=1)
image = paddle.static.data( image = paddle.static.data(
name='x', shape=[None, 3, 224, 224], dtype='float32') name='x', shape=[None, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
...@@ -102,7 +104,7 @@ if __name__ == '__main__': ...@@ -102,7 +104,7 @@ if __name__ == '__main__':
strategy_config=compress_config, strategy_config=compress_config,
train_config=train_config, train_config=train_config,
train_dataloader=train_dataloader, train_dataloader=train_dataloader,
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else None, eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else reader_wrapper(eval_reader(data_dir, 64)),
devices=args.devices) devices=args.devices)
ac.compress() ac.compress()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册