未验证 提交 c3ae8396 编写于 作者: Q qingqing01 提交者: GitHub

Fix bug by using coco dataset. (#1582)

* Fix bug for coco dataset.
上级 13af25f0
......@@ -68,6 +68,7 @@ class GeneratorEnqueuer(object):
try:
task()
except Exception:
traceback.print_exc()
self._stop_event.set()
break
else:
......@@ -75,6 +76,7 @@ class GeneratorEnqueuer(object):
try:
task()
except Exception:
traceback.print_exc()
self._stop_event.set()
break
......
......@@ -176,10 +176,17 @@ def coco(settings, file_list, mode, batch_size, shuffle):
if mode == 'train' and shuffle:
np.random.shuffle(images)
batch_out = []
if '2014' in file_list:
sub_dir = "train2014" if model == "train" else "val2014"
elif '2017' in file_list:
sub_dir = "train2017" if mode == "train" else "val2017"
data_dir = os.path.join(settings.data_dir, sub_dir)
for image in images:
image_name = image['file_name']
image_path = os.path.join(settings.data_dir, image_name)
image_path = os.path.join(data_dir, image_name)
if not os.path.exists(image_path):
raise ValueError("%s is not exist, you should specify "
"data path correctly." % image_path)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
......@@ -242,7 +249,9 @@ def pascalvoc(settings, file_list, mode, batch_size, shuffle):
image_path, label_path = image.split()
image_path = os.path.join(settings.data_dir, image_path)
label_path = os.path.join(settings.data_dir, label_path)
if not os.path.exists(image_path):
raise ValueError("%s is not exist, you should specify "
"data path correctly." % image_path)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
......@@ -295,7 +304,6 @@ def train(settings,
max_queue=24,
enable_ce=False):
file_list = os.path.join(settings.data_dir, file_list)
if 'coco' in settings.dataset:
generator = coco(settings, file_list, "train", batch_size, shuffle)
else:
......@@ -341,6 +349,9 @@ def test(settings, file_list, batch_size):
def infer(settings, image_path):
def reader():
if not os.path.exists(image_path):
raise ValueError("%s is not exist, you should specify "
"data path correctly." % image_path)
img = Image.open(image_path)
if img.mode == 'L':
img = im.convert('RGB')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册