未验证 提交 4f732271 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix preset (#5454)

上级 fce73d50
......@@ -17,8 +17,13 @@ from . import functional as F
from .functional import InterpolationMode, _interpolation_modes_from_int
__all__ = [
"Compose", "ToTensor", "Normalize", "Resize", "CenterCrop",
"RandomResizedCrop", "RandomHorizontalFlip"
"Compose",
"ToTensor",
"Normalize",
"Resize",
"CenterCrop",
"RandomResizedCrop",
"RandomHorizontalFlip",
]
......
import paddle
from paddlevision.transforms import autoaugment, transforms
......@@ -12,16 +14,11 @@ class ClassificationPresetTrain:
trans = [transforms.RandomResizedCrop(crop_size)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
#if auto_augment_policy is not None:
# aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
# trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.Normalize(
mean=mean, std=std),
])
#if random_erase_prob > 0:
# trans.append(transforms.RandomErasing(p=random_erase_prob))
self.transforms = transforms.Compose(trans)
......@@ -35,12 +32,14 @@ class ClassificationPresetEval:
resize_size=256,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225)):
mean = tuple([m * 255 for m in mean])
std = tuple([s * 255 for s in std])
self.transforms = transforms.Compose([
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(
# fix to support pt-quant
paddle.vision.transforms.Transpose((2, 0, 1)),
paddle.vision.transforms.Normalize(
mean=mean, std=std),
])
......
......@@ -49,6 +49,7 @@ def main(args):
img = Image.open(f).convert('RGB')
img = eval_transforms(img)
img = paddle.to_tensor(img)
img = img.expand([1] + img.shape)
output = model(img).numpy()[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册