提交 aa2bfd2d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/data): create dir automatically when download=True

GitOrigin-RevId: f3cb1b7d50af84c8c03b048ad3a0ad02e3c463f2
上级 3c616386
......@@ -57,7 +57,14 @@ class CIFAR10(VisionDataset):
else:
self.root = root
if not os.path.exists(self.root):
raise ValueError("dir %s does not exist" % self.root)
if download:
logger.debug(
"dir %s does not exist, will be automatically created",
self.root,
)
os.makedirs(self.root)
else:
raise ValueError("dir %s does not exist" % self.root)
self.target_file = os.path.join(self.root, self.raw_file_dir)
......@@ -77,8 +84,7 @@ class CIFAR10(VisionDataset):
self.arrays = self.bytes2array(self.test_batch)
else:
raise ValueError(
"dir does not contain target file\
%s, please set download=True"
"dir does not contain target file %s, please set download=True"
% (self.target_file)
)
......@@ -160,6 +166,6 @@ class CIFAR100(CIFAR10):
data.extend(list(batch_data[..., [2, 1, 0]]))
fine_label.extend(dic[b"fine_labels"])
coarse_label.extend(dic[b"coarse_labels"])
fine_label = np.array(fine_label)
coarse_label = np.array(coarse_label)
fine_label = np.array(fine_label, dtype=np.int32)
coarse_label = np.array(coarse_label, dtype=np.int32)
return data, fine_label, coarse_label
......@@ -75,7 +75,14 @@ class MNIST(VisionDataset):
else:
self.root = root
if not os.path.exists(self.root):
raise ValueError("dir %s does not exist" % self.root)
if download:
logger.debug(
"dir %s does not exist, will be automatically created",
self.root,
)
os.makedirs(self.root)
else:
raise ValueError("dir %s does not exist" % self.root)
if self._check_raw_files():
self.process(train)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册