未验证 提交 f07e2124 编写于 作者: C ceci3 提交者: GitHub

fix dataloader (#592)

上级 197abda0
...@@ -8,6 +8,7 @@ import time ...@@ -8,6 +8,7 @@ import time
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.vision.transforms as T
import paddle.static as static import paddle.static as static
from paddle import ParamAttr from paddle import ParamAttr
from paddleslim.analysis import flops from paddleslim.analysis import flops
...@@ -51,9 +52,12 @@ def conv_bn_layer(input, ...@@ -51,9 +52,12 @@ def conv_bn_layer(input,
def search_mobilenetv2_block(config, args, image_size): def search_mobilenetv2_block(config, args, image_size):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
if args.data == 'cifar10': if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train') train_dataset = paddle.vision.datasets.Cifar10(
val_dataset = paddle.vision.datasets.Cifar10(mode='test') mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet': elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train') train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
...@@ -11,6 +11,7 @@ import paddle ...@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.static as static import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
from paddleslim.common import get_logger from paddleslim.common import get_logger
from optimizer import create_optimizer from optimizer import create_optimizer
...@@ -94,8 +95,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -94,8 +95,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
if args.data == 'cifar10': if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(mode='test') train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet': elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train') train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
...@@ -11,6 +11,7 @@ import paddle ...@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.static as static import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.nas import RLNAS from paddleslim.nas import RLNAS
from paddleslim.common import get_logger from paddleslim.common import get_logger
from optimizer import create_optimizer from optimizer import create_optimizer
...@@ -104,8 +105,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ...@@ -104,8 +105,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
if args.data == 'cifar10': if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(mode='test') train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet': elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train') train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
...@@ -11,6 +11,7 @@ import paddle ...@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.static as static import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddle import ParamAttr from paddle import ParamAttr
from paddleslim.analysis import flops from paddleslim.analysis import flops
from paddleslim.nas import SANAS from paddleslim.nas import SANAS
...@@ -76,8 +77,11 @@ def build_program(main_program, ...@@ -76,8 +77,11 @@ def build_program(main_program,
def search_mobilenetv2(config, args, image_size, is_server=True): def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
if args.data == 'cifar10': if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(mode='test') train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet': elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train') train_dataset = imagenet_reader.ImageNetDataset(mode='train')
...@@ -182,8 +186,11 @@ def test_search_result(tokens, image_size, args, config): ...@@ -182,8 +186,11 @@ def test_search_result(tokens, image_size, args, config):
image_shape = [3, image_size, image_size] image_shape = [3, image_size, image_size]
if args.data == 'cifar10': if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train') transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
val_dataset = paddle.vision.datasets.Cifar10(mode='test') train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet': elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train') train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
...@@ -74,8 +74,8 @@ def build_program(archs): ...@@ -74,8 +74,8 @@ def build_program(archs):
import paddle.vision.transforms as T import paddle.vision.transforms as T
def input_data(image, label): def input_data(image, label):
transform = T.Compose([T.Normalize([127.5], [127.5])]) transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform) train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform, backend='cv2')
train_loader = paddle.io.DataLoader(train_dataset, train_loader = paddle.io.DataLoader(train_dataset,
places=paddle.CPUPlace(), places=paddle.CPUPlace(),
feed_list=[image, label], feed_list=[image, label],
...@@ -83,7 +83,7 @@ def input_data(image, label): ...@@ -83,7 +83,7 @@ def input_data(image, label):
batch_size=64, batch_size=64,
return_list=False, return_list=False,
shuffle=True) shuffle=True)
eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform) eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform, backend='cv2')
eval_loader = paddle.io.DataLoader(eval_dataset, eval_loader = paddle.io.DataLoader(eval_dataset,
places=paddle.CPUPlace(), places=paddle.CPUPlace(),
feed_list=[image, label], feed_list=[image, label],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册