diff --git a/dygraph/mobilenet/imagenet_dataset.py b/dygraph/mobilenet/imagenet_dataset.py index 1b39ac5690ae8cb78f7ca05c685c01d5b7a84d9c..4c00fcd4a1642975f676b07ba844eefc645b713f 100644 --- a/dygraph/mobilenet/imagenet_dataset.py +++ b/dygraph/mobilenet/imagenet_dataset.py @@ -38,13 +38,13 @@ class ImageNetDataset(DatasetFolder): self.transform = transforms.Compose([ transforms.RandomResizedCrop(image_size), transforms.RandomHorizontalFlip(), - transforms.Permute(mode='CHW'), normalize + transforms.Transpose(order=(2, 0, 1)), normalize ]) else: self.transform = transforms.Compose([ transforms.Resize(resize_short_size), transforms.CenterCrop(image_size), - transforms.Permute(mode='CHW'), normalize + transforms.Transpose(order=(2, 0, 1)), normalize ]) def __getitem__(self, idx):