未验证 提交 765e80c0 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix voc difficult not found & map_res empty (#2714)

上级 db13eddf
...@@ -117,7 +117,12 @@ class VOCDataSet(DetDataset): ...@@ -117,7 +117,12 @@ class VOCDataSet(DetDataset):
difficult = [] difficult = []
for i, obj in enumerate(objs): for i, obj in enumerate(objs):
cname = obj.find('name').text cname = obj.find('name').text
_difficult = int(obj.find('difficult').text)
# user dataset may not contain difficult field
_difficult = obj.find('difficult')
_difficult = int(
_difficult.text) if _difficult is not None else 0
x1 = float(obj.find('bndbox').find('xmin').text) x1 = float(obj.find('bndbox').find('xmin').text)
y1 = float(obj.find('bndbox').find('ymin').text) y1 = float(obj.find('bndbox').find('ymin').text)
x2 = float(obj.find('bndbox').find('xmax').text) x2 = float(obj.find('bndbox').find('xmax').text)
......
...@@ -179,6 +179,11 @@ class Checkpointer(Callback): ...@@ -179,6 +179,11 @@ class Checkpointer(Callback):
for metric in self.model._metrics: for metric in self.model._metrics:
map_res = metric.get_results() map_res = metric.get_results()
key = 'bbox' if 'bbox' in map_res else 'mask' key = 'bbox' if 'bbox' in map_res else 'mask'
if key not in map_res:
logger.warn("Evaluation results empty, this may be due to " \
"training iterations being too few or not " \
"loading the correct weights.")
return
if map_res[key][0] > self.best_ap: if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0] self.best_ap = map_res[key][0]
save_name = 'best_model' save_name = 'best_model'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册