diff --git a/demo/nas/block_sa_nas_mobilenetv2.py b/demo/nas/block_sa_nas_mobilenetv2.py index 9fc9f9f50f221ef647bd4373fdd112c54b785846..9bec555ffec8534249e8d60f1fd4afa02eba06b8 100644 --- a/demo/nas/block_sa_nas_mobilenetv2.py +++ b/demo/nas/block_sa_nas_mobilenetv2.py @@ -8,6 +8,7 @@ import time import paddle import paddle.nn as nn import paddle.nn.functional as F +import paddle.vision.transforms as T import paddle.static as static from paddle import ParamAttr from paddleslim.analysis import flops @@ -51,9 +52,12 @@ def conv_bn_layer(input, def search_mobilenetv2_block(config, args, image_size): image_shape = [3, image_size, image_size] + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) if args.data == 'cifar10': - train_dataset = paddle.vision.datasets.Cifar10(mode='train') - val_dataset = paddle.vision.datasets.Cifar10(mode='test') + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', transform=transform, backend='cv2') + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', transform=transform, backend='cv2') elif args.data == 'imagenet': train_dataset = imagenet_reader.ImageNetDataset(mode='train') diff --git a/demo/nas/parl_nas_mobilenetv2.py b/demo/nas/parl_nas_mobilenetv2.py index 732c8f28dc6aef4e8848eeb9c0d2e0e1b02d7674..b07dc83fbeb509749a120541d611fd29c51ec650 100644 --- a/demo/nas/parl_nas_mobilenetv2.py +++ b/demo/nas/parl_nas_mobilenetv2.py @@ -11,6 +11,7 @@ import paddle import paddle.nn as nn import paddle.static as static import paddle.nn.functional as F +import paddle.vision.transforms as T from paddleslim.nas import RLNAS from paddleslim.common import get_logger from optimizer import create_optimizer @@ -94,8 +95,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True): image_shape = [3, image_size, image_size] if args.data == 'cifar10': - train_dataset = paddle.vision.datasets.Cifar10(mode='train') - val_dataset = paddle.vision.datasets.Cifar10(mode='test') + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', transform=transform, backend='cv2') + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', transform=transform, backend='cv2') elif args.data == 'imagenet': train_dataset = imagenet_reader.ImageNetDataset(mode='train') diff --git a/demo/nas/rl_nas_mobilenetv2.py b/demo/nas/rl_nas_mobilenetv2.py index 27445d4d742883e1ff01ac825e62d726bf104150..904a65d8c82a1b46119e8ebf5d4aa78fcd0d4f8b 100644 --- a/demo/nas/rl_nas_mobilenetv2.py +++ b/demo/nas/rl_nas_mobilenetv2.py @@ -11,6 +11,7 @@ import paddle import paddle.nn as nn import paddle.static as static import paddle.nn.functional as F +import paddle.vision.transforms as T from paddleslim.nas import RLNAS from paddleslim.common import get_logger from optimizer import create_optimizer @@ -104,8 +105,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True): image_shape = [3, image_size, image_size] if args.data == 'cifar10': - train_dataset = paddle.vision.datasets.Cifar10(mode='train') - val_dataset = paddle.vision.datasets.Cifar10(mode='test') + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', transform=transform, backend='cv2') + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', transform=transform, backend='cv2') elif args.data == 'imagenet': train_dataset = imagenet_reader.ImageNetDataset(mode='train') diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py index 64e4748f484d8449da33ecff5b2417c0c077eba5..6eb557d16ded2f09f21c5580ba6e9756311691be 100644 --- a/demo/nas/sa_nas_mobilenetv2.py +++ b/demo/nas/sa_nas_mobilenetv2.py @@ -11,6 +11,7 @@ import paddle import paddle.nn as nn import paddle.static as static import paddle.nn.functional as F +import paddle.vision.transforms as T from paddle import ParamAttr from paddleslim.analysis import flops from paddleslim.nas import SANAS @@ -76,8 +77,11 @@ def build_program(main_program, def search_mobilenetv2(config, args, image_size, is_server=True): image_shape = [3, image_size, image_size] if args.data == 'cifar10': - train_dataset = paddle.vision.datasets.Cifar10(mode='train') - val_dataset = paddle.vision.datasets.Cifar10(mode='test') + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', transform=transform, backend='cv2') + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', transform=transform, backend='cv2') elif args.data == 'imagenet': train_dataset = imagenet_reader.ImageNetDataset(mode='train') @@ -182,8 +186,11 @@ def test_search_result(tokens, image_size, args, config): image_shape = [3, image_size, image_size] if args.data == 'cifar10': - train_dataset = paddle.vision.datasets.Cifar10(mode='train') - val_dataset = paddle.vision.datasets.Cifar10(mode='test') + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10( + mode='train', transform=transform, backend='cv2') + val_dataset = paddle.vision.datasets.Cifar10( + mode='test', transform=transform, backend='cv2') elif args.data == 'imagenet': train_dataset = imagenet_reader.ImageNetDataset(mode='train') diff --git a/docs/zh_cn/quick_start/nas_tutorial.md b/docs/zh_cn/quick_start/nas_tutorial.md index 7037187e178a2f8014a093a93792e16a99aae196..a7225c75fbc988a7e47ac1b3fdf088da3c3ea2bb 100644 --- a/docs/zh_cn/quick_start/nas_tutorial.md +++ b/docs/zh_cn/quick_start/nas_tutorial.md @@ -74,8 +74,8 @@ def build_program(archs): import paddle.vision.transforms as T def input_data(image, label): - transform = T.Compose([T.Normalize([127.5], [127.5])]) - train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform) + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform, backend='cv2') train_loader = paddle.io.DataLoader(train_dataset, places=paddle.CPUPlace(), feed_list=[image, label], @@ -83,7 +83,7 @@ def input_data(image, label): batch_size=64, return_list=False, shuffle=True) - eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform) + eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform, backend='cv2') eval_loader = paddle.io.DataLoader(eval_dataset, places=paddle.CPUPlace(), feed_list=[image, label],