未验证 提交 9ca38e72 编写于 作者: B Bai Yifan 提交者: GitHub

fix dygraph quant demo dataset issue (#555)

* fix dygraph quant demo dataset issue

* fix dygraph quant demo dataset issue
上级 dacc4218
......@@ -22,7 +22,7 @@
### 配置量化参数
```
```python
quant_config = {
'weight_preprocess_type': None,
'activation_preprocess_type': None,
......@@ -70,9 +70,9 @@ quanter.save_quantized_model(net, 'save_dir', input_spec=[paddle.static.InputSpe
```bash
# 单卡训练
python train.py --model='mobilenet_v1'
python train.py --model=mobilenet_v1
# 多卡训练,以0到3号卡为例
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model='mobilenet_v1'
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py --model=mobilenet_v1
```
- MobileNetV3
......
......@@ -28,6 +28,7 @@ import numpy as np
from paddle.distributed import ParallelEnv
from paddle.static import load_program_state
from paddle.vision.models import mobilenet_v1
import paddle.vision.transforms as T
from paddleslim.common import get_logger
from paddleslim.dygraph.quant import QAT
......@@ -55,7 +56,7 @@ add_arg('use_pact', bool, False,
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('model_save_dir', str, "./output_models", "model save directory.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[10, 20, 30], help="piecewise decay step")
......@@ -87,12 +88,16 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
def compress(args):
if args.data == "mnist":
train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_dataset = paddle.vision.datasets.MNIST(mode='test')
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "1,28,28"
args.total_images = 60000
image_shape = [3, 32, 32]
pretrain = False
args.total_images = 50000
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train')
......@@ -199,6 +204,8 @@ def compress(args):
eval_reader_cost += time.time() - reader_start
image = data[0]
label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
eval_start = time.time()
......@@ -262,6 +269,8 @@ def compress(args):
image = data[0]
label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
train_start = time.time()
out = net(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册