diff --git a/python_module/megengine/data/dataset/vision/cifar.py b/python_module/megengine/data/dataset/vision/cifar.py index e160891deaf07bc1d14ed55d318f67ccc35df0b5..9ce73688969d707c48245a83dce30759c33bc561 100644 --- a/python_module/megengine/data/dataset/vision/cifar.py +++ b/python_module/megengine/data/dataset/vision/cifar.py @@ -57,7 +57,14 @@ class CIFAR10(VisionDataset): else: self.root = root if not os.path.exists(self.root): - raise ValueError("dir %s does not exist" % self.root) + if download: + logger.debug( + "dir %s does not exist, will be automatically created", + self.root, + ) + os.makedirs(self.root) + else: + raise ValueError("dir %s does not exist" % self.root) self.target_file = os.path.join(self.root, self.raw_file_dir) @@ -77,8 +84,7 @@ class CIFAR10(VisionDataset): self.arrays = self.bytes2array(self.test_batch) else: raise ValueError( - "dir does not contain target file\ - %s, please set download=True" + "dir does not contain target file %s, please set download=True" % (self.target_file) ) @@ -160,6 +166,6 @@ class CIFAR100(CIFAR10): data.extend(list(batch_data[..., [2, 1, 0]])) fine_label.extend(dic[b"fine_labels"]) coarse_label.extend(dic[b"coarse_labels"]) - fine_label = np.array(fine_label) - coarse_label = np.array(coarse_label) + fine_label = np.array(fine_label, dtype=np.int32) + coarse_label = np.array(coarse_label, dtype=np.int32) return data, fine_label, coarse_label diff --git a/python_module/megengine/data/dataset/vision/mnist.py b/python_module/megengine/data/dataset/vision/mnist.py index 8ad82b1c49c67f5ff0234608e440b95d6369c92b..5e89a3140556bf9449f4fdadf1bb6e6b73b1f6ad 100644 --- a/python_module/megengine/data/dataset/vision/mnist.py +++ b/python_module/megengine/data/dataset/vision/mnist.py @@ -75,7 +75,14 @@ class MNIST(VisionDataset): else: self.root = root if not os.path.exists(self.root): - raise ValueError("dir %s does not exist" % self.root) + if download: + logger.debug( + "dir %s does not exist, will be automatically created", + self.root, + ) + os.makedirs(self.root) + else: + raise ValueError("dir %s does not exist" % self.root) if self._check_raw_files(): self.process(train)