提交 d495dfa9 编写于 作者: W weishengyu

update delimiter

上级 8765dadf
......@@ -21,6 +21,15 @@ from .common_dataset import CommonDataset
class ImageNetDataset(CommonDataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None,
delimiter=None):
super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops)
self.delimiter = delimiter if delimiter is not None else " "
def _load_anno(self, seed=None):
assert os.path.exists(self._cls_path)
assert os.path.exists(self._img_root)
......@@ -32,7 +41,7 @@ class ImageNetDataset(CommonDataset):
if seed is not None:
np.random.RandomState(seed).shuffle(lines)
for l in lines:
l = l.strip().split(" ")
l = l.strip().split(self.delimiter)
self.images.append(os.path.join(self._img_root, l[0]))
self.labels.append(np.int64(l[1]))
assert os.path.exists(self.images[-1])
......@@ -19,10 +19,11 @@ import paddle.nn.functional as F
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None):
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
assert isinstance(topk, (int, ))
self.class_id_map = self.parse_class_id_map(class_id_map_file)
self.topk = topk
self.delimiter = delimiter if delimiter is not None else " "
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
......@@ -38,7 +39,7 @@ class Topk(object):
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(" ")
partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册