未验证 提交 c9e173e5 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #1894 from weisy11/cp_1854

Cherry-pick #1854
...@@ -21,6 +21,15 @@ from .common_dataset import CommonDataset ...@@ -21,6 +21,15 @@ from .common_dataset import CommonDataset
class ImageNetDataset(CommonDataset): class ImageNetDataset(CommonDataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None,
delimiter=None):
self.delimiter = delimiter if delimiter is not None else " "
super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops)
def _load_anno(self, seed=None): def _load_anno(self, seed=None):
assert os.path.exists(self._cls_path) assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root) assert os.path.exists(self._img_root)
...@@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset): ...@@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset):
if seed is not None: if seed is not None:
np.random.RandomState(seed).shuffle(lines) np.random.RandomState(seed).shuffle(lines)
for l in lines: for l in lines:
l = l.strip().split(" ") l = l.strip().split(self.delimiter)
self.images.append(os.path.join(self._img_root, l[0])) self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(np.int64(l[1])) self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
...@@ -19,10 +19,11 @@ import paddle.nn.functional as F ...@@ -19,10 +19,11 @@ import paddle.nn.functional as F
class Topk(object): class Topk(object):
def __init__(self, topk=1, class_id_map_file=None): def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
assert isinstance(topk, (int, )) assert isinstance(topk, (int, ))
self.class_id_map = self.parse_class_id_map(class_id_map_file) self.class_id_map = self.parse_class_id_map(class_id_map_file)
self.topk = topk self.topk = topk
self.delimiter = delimiter if delimiter is not None else " "
def parse_class_id_map(self, class_id_map_file): def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None: if class_id_map_file is None:
...@@ -38,7 +39,7 @@ class Topk(object): ...@@ -38,7 +39,7 @@ class Topk(object):
with open(class_id_map_file, "r") as fin: with open(class_id_map_file, "r") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
partition = line.split("\n")[0].partition(" ") partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1]) class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex: except Exception as ex:
print(ex) print(ex)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册