未验证 提交 9ca38e72 编写于 作者: B Bai Yifan 提交者: GitHub

fix dygraph quant demo dataset issue (#555)

* fix dygraph quant demo dataset issue

* fix dygraph quant demo dataset issue
上级 dacc4218
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
### 配置量化参数 ### 配置量化参数
``` ```python
quant_config = { quant_config = {
'weight_preprocess_type': None, 'weight_preprocess_type': None,
'activation_preprocess_type': None, 'activation_preprocess_type': None,
...@@ -70,9 +70,9 @@ quanter.save_quantized_model(net, 'save_dir', input_spec=[paddle.static.InputSpe ...@@ -70,9 +70,9 @@ quanter.save_quantized_model(net, 'save_dir', input_spec=[paddle.static.InputSpe
```bash ```bash
# 单卡训练 # 单卡训练
python train.py --model='mobilenet_v1' python train.py --model=mobilenet_v1
# 多卡训练,以0到3号卡为例 # 多卡训练,以0到3号卡为例
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model='mobilenet_v1' python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model=mobilenet_v1
``` ```
- MobileNetV3 - MobileNetV3
......
...@@ -28,6 +28,7 @@ import numpy as np ...@@ -28,6 +28,7 @@ import numpy as np
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
from paddle.static import load_program_state from paddle.static import load_program_state
from paddle.vision.models import mobilenet_v1 from paddle.vision.models import mobilenet_v1
import paddle.vision.transforms as T
from paddleslim.common import get_logger from paddleslim.common import get_logger
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
...@@ -55,7 +56,7 @@ add_arg('use_pact', bool, False, ...@@ -55,7 +56,7 @@ add_arg('use_pact', bool, False,
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.") add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.") add_arg('total_images', int, 1281167, "The number of total training images.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'") add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.") add_arg('log_period', int, 10, "Log period in batches.")
add_arg('model_save_dir', str, "./output_models", "model save directory.") add_arg('model_save_dir', str, "./output_models", "model save directory.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step") parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step")
...@@ -87,12 +88,16 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False): ...@@ -87,12 +88,16 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
def compress(args): def compress(args):
if args.data == "mnist": if args.data == "cifar10":
train_dataset = paddle.vision.datasets.MNIST(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.MNIST(mode='test') train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = [3, 32, 32]
args.total_images = 60000 pretrain = False
args.total_images = 50000
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train') train_dataset = reader.ImageNetDataset(mode='train')
...@@ -199,6 +204,8 @@ def compress(args): ...@@ -199,6 +204,8 @@ def compress(args):
eval_reader_cost += time.time() - reader_start eval_reader_cost += time.time() - reader_start
image = data[0] image = data[0]
label = data[1] label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
eval_start = time.time() eval_start = time.time()
...@@ -262,6 +269,8 @@ def compress(args): ...@@ -262,6 +269,8 @@ def compress(args):
image = data[0] image = data[0]
label = data[1] label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
train_start = time.time() train_start = time.time()
out = net(image) out = net(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册