Unverified commit 5b717b1b authored by Zhen Wang, committed by GitHub

change for quant demo in lite tutorial. (#3250)

Parent 8ed8df0b
@@ -24,13 +24,15 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('batch_size', int, 64*4, "Minibatch size.")
 add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('total_images', int, 1281167, "Training image number.")
 add_arg('class_dim', int, 1000, "Class number.")
 add_arg('image_shape', str, "3,224,224", "Input image size")
 add_arg('model', str, "MobileNet", "Set the network to use.")
 add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('teacher_model', str, None, "Set the teacher network to use.")
 add_arg('teacher_pretrained_model', str, None, "Whether to use pretrained model.")
 add_arg('compress_config', str, None, "The config file for compression with yaml format.")
+add_arg('quant_only', bool, False, "Only do quantization-aware training.")
 # yapf: enable
 model_list = [m for m in dir(models) if "__" not in m]
@@ -64,12 +66,20 @@ def compress(args):
     acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
     acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
     val_program = fluid.default_main_program().clone()
+    if args.quant_only:
+        boundaries=[args.total_images / args.batch_size * 10,
+                    args.total_images / args.batch_size * 16]
+        values=[1e-4, 1e-5, 1e-6]
+    else:
+        boundaries=[args.total_images / args.batch_size * 30,
+                    args.total_images / args.batch_size * 60,
+                    args.total_images / args.batch_size * 90]
+        values=[0.1, 0.01, 0.001, 0.0001]
     opt = fluid.optimizer.Momentum(
         momentum=0.9,
         learning_rate=fluid.layers.piecewise_decay(
-            boundaries=[5000 * 30, 5000 * 60, 5000 * 90],
-            values=[0.1, 0.01, 0.001, 0.0001]),
+            boundaries=boundaries,
+            values=values),
         regularization=fluid.regularizer.L2Decay(4e-5))
     place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
...
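For reference, the new boundaries reproduce the old hardcoded schedule: with the default total_images=1281167 and batch_size=64*4, one epoch is about 5004 steps, which the previous 5000 * 30 / 60 / 90 step boundaries only approximated. A minimal sketch of the arithmetic (illustrative only; values taken from the argument defaults above):

```python
# Sketch of the boundary arithmetic introduced above, using the
# default arguments (total_images=1281167, batch_size=64*4).
total_images = 1281167
batch_size = 64 * 4

# Steps per epoch: ~5004.6, which the old hardcoded "5000 * 30"
# style boundaries approximated.
steps_per_epoch = total_images / batch_size

quant_only = False
if quant_only:
    # Short fine-tuning schedule starting from a small LR.
    boundaries = [steps_per_epoch * 10, steps_per_epoch * 16]
    values = [1e-4, 1e-5, 1e-6]
else:
    # Full-training schedule: decay 10x at epochs 30, 60 and 90.
    boundaries = [steps_per_epoch * e for e in (30, 60, 90)]
    values = [0.1, 0.01, 0.001, 0.0001]

print([round(b) for b in boundaries])  # [150137, 300274, 450410]
```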
@@ -3,7 +3,7 @@ strategies:
     quantization_strategy:
         class: 'QuantizationStrategy'
         start_epoch: 0
-        end_epoch: 20
+        end_epoch: 19
         float_model_save_path: './output/float'
         mobile_model_save_path: './output/mobile'
         int8_model_save_path: './output/int8'
@@ -14,7 +14,7 @@ strategies:
         save_in_nodes: ['image']
         save_out_nodes: ['fc_0.tmp_2']
 compressor:
-    epoch: 21
+    epoch: 20
     checkpoint_path: './checkpoints_quan/'
     strategies:
         - quantization_strategy
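The paired changes in this YAML (end_epoch 20 → 19, epoch 21 → 20) read as an off-by-one fix rather than a shorter run: assuming the compressor counts epochs from 0 and executes epochs 0 through epoch - 1, end_epoch: 19 keeps the quantization strategy active through the final executed epoch. A tiny sketch of that assumed indexing:

```python
# Assumed zero-based epoch loop: the compressor runs epochs
# 0 .. epoch-1, so epoch=20 means the last epoch index is 19.
epoch = 20            # compressor.epoch from the YAML above
end_epoch = 19        # quantization_strategy.end_epoch

epochs_run = list(range(epoch))       # [0, 1, ..., 19]
assert epochs_run[-1] == end_epoch    # strategy covers the last epoch
```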
@@ -82,7 +82,8 @@ cd -
 #--batch_size 64 \
 #--model "MobileNet" \
 #--pretrained_model ./pretrain/MobileNetV1_pretrained \
-#--compress_config ./configs/quantization.yaml
+#--compress_config ./configs/quantization.yaml \
+#--quant_only True
 # for distillation with quantization
 #-----------------------------------
...
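Note the trailing backslash added after --compress_config ./configs/quantization.yaml: without it, the new --quant_only True line would not join the same commented-out command. As a hedged aside on the Python side, boolean flags in the PaddlePaddle demos are usually parsed via distutils.util.strtobool rather than a bare type=bool (which would treat any non-empty string, even "False", as True); a hypothetical sketch:

```python
import argparse
from distutils.util import strtobool

# Hypothetical reconstruction of how a bool arg like quant_only is
# handled: strtobool makes "--quant_only False" parse to False,
# whereas a bare type=bool would turn every non-empty string into True.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--quant_only',
    type=lambda v: bool(strtobool(v)),
    default=False,
    help="Only do quantization-aware training.")

args = parser.parse_args(['--quant_only', 'True'])
print(args.quant_only)  # True
```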