From 9ca38e720f53668f4544c026e386850fd392d70c Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Thu, 17 Dec 2020 20:37:24 +0800 Subject: [PATCH] fix dygraph quant demo dataset issue (#555) * fix dygraph quant demo dataset issue * fix dygraph quant demo dataset issue --- demo/dygraph/quant/README.md | 6 +++--- demo/dygraph/quant/train.py | 21 +++++++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/demo/dygraph/quant/README.md b/demo/dygraph/quant/README.md index c10068b4..17693014 100755 --- a/demo/dygraph/quant/README.md +++ b/demo/dygraph/quant/README.md @@ -22,7 +22,7 @@ ### 配置量化参数 -``` +```python quant_config = { 'weight_preprocess_type': None, 'activation_preprocess_type': None, @@ -70,9 +70,9 @@ quanter.save_quantized_model(net, 'save_dir', input_spec=[paddle.static.InputSpe ```bash # 单卡训练 - python train.py --model='mobilenet_v1' + python train.py --model=mobilenet_v1 # 多卡训练,以0到3号卡为例 - python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model='mobilenet_v1' + python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model=mobilenet_v1 ``` - MobileNetV3 diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py index 43be2c38..25c0a3d1 100644 --- a/demo/dygraph/quant/train.py +++ b/demo/dygraph/quant/train.py @@ -28,6 +28,7 @@ import numpy as np from paddle.distributed import ParallelEnv from paddle.static import load_program_state from paddle.vision.models import mobilenet_v1 +import paddle.vision.transforms as T from paddleslim.common import get_logger from paddleslim.dygraph.quant import QAT @@ -55,7 +56,7 @@ add_arg('use_pact', bool, False, add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") add_arg('num_epochs', int, 1, "The number of total epochs.") add_arg('total_images', int, 1281167, "The number of total training images.") -add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'") +add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'") add_arg('log_period', int, 10, "Log period in batches.") add_arg('model_save_dir', str, "./output_models", "model save directory.") parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step") @@ -87,12 +88,16 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False): def compress(args): - if args.data == "mnist": - train_dataset = paddle.vision.datasets.MNIST(mode='train') - val_dataset = paddle.vision.datasets.MNIST(mode='test') + if args.data == "cifar10": + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10( + mode="train", backend="cv2", transform=transform) + val_dataset = paddle.vision.datasets.Cifar10( + mode="test", backend="cv2", transform=transform) class_dim = 10 - image_shape = "1,28,28" - args.total_images = 60000 + image_shape = [3, 32, 32] + pretrain = False + args.total_images = 50000 elif args.data == "imagenet": import imagenet_reader as reader train_dataset = reader.ImageNetDataset(mode='train') @@ -199,6 +204,8 @@ def compress(args): eval_reader_cost += time.time() - reader_start image = data[0] label = data[1] + if args.data == "cifar10": + label = paddle.reshape(label, [-1, 1]) eval_start = time.time() @@ -262,6 +269,8 @@ def compress(args): image = data[0] label = data[1] + if args.data == "cifar10": + label = paddle.reshape(label, [-1, 1]) train_start = time.time() out = net(image) -- GitLab