未验证 提交 f07e2124 编写于 作者: C ceci3 提交者: GitHub

fix dataloader (#592)

上级 197abda0
......@@ -8,6 +8,7 @@ import time
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.vision.transforms as T
import paddle.static as static
from paddle import ParamAttr
from paddleslim.analysis import flops
......@@ -51,9 +52,12 @@ def conv_bn_layer(input,
def search_mobilenetv2_block(config, args, image_size):
image_shape = [3, image_size, image_size]
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
......@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.nas import RLNAS
from paddleslim.common import get_logger
from optimizer import create_optimizer
......@@ -94,8 +95,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size]
if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
......@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddleslim.nas import RLNAS
from paddleslim.common import get_logger
from optimizer import create_optimizer
......@@ -104,8 +105,11 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size]
if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
......@@ -11,6 +11,7 @@ import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.vision.transforms as T
from paddle import ParamAttr
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
......@@ -76,8 +77,11 @@ def build_program(main_program,
def search_mobilenetv2(config, args, image_size, is_server=True):
image_shape = [3, image_size, image_size]
if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......@@ -182,8 +186,11 @@ def test_search_result(tokens, image_size, args, config):
image_shape = [3, image_size, image_size]
if args.data == 'cifar10':
train_dataset = paddle.vision.datasets.Cifar10(mode='train')
val_dataset = paddle.vision.datasets.Cifar10(mode='test')
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode='train', transform=transform, backend='cv2')
val_dataset = paddle.vision.datasets.Cifar10(
mode='test', transform=transform, backend='cv2')
elif args.data == 'imagenet':
train_dataset = imagenet_reader.ImageNetDataset(mode='train')
......
......@@ -74,8 +74,8 @@ def build_program(archs):
import paddle.vision.transforms as T
def input_data(image, label):
transform = T.Compose([T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform)
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform, backend='cv2')
train_loader = paddle.io.DataLoader(train_dataset,
places=paddle.CPUPlace(),
feed_list=[image, label],
......@@ -83,7 +83,7 @@ def input_data(image, label):
batch_size=64,
return_list=False,
shuffle=True)
eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform)
eval_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform, backend='cv2')
eval_loader = paddle.io.DataLoader(eval_dataset,
places=paddle.CPUPlace(),
feed_list=[image, label],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册