提交 5ec5f453 编写于 作者: D dengkaipeng

fit for dygraph

上级 074a08e5
......@@ -16,6 +16,8 @@ from __future__ import absolute_import
import sys
import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from .metric import Metric
......@@ -32,22 +34,21 @@ OUTFILE = './bbox.json'
class COCOMetric(Metric):
"""
Base class for metric, encapsulates metric logic and APIs
Metrci for MS-COCO dataset, only support update with batch
size as 1.
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
Args:
anno_path(str): path to COCO annotation json file
with_background(bool): whether load category id with
background as 0, default True
"""
def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs)
self.states['bbox'] = []
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
from pycocotools.coco import COCO
self.coco_gt = COCO(anno_path)
cat_ids = self.coco_gt.getCatIds()
self.clsid2catid = dict(
......@@ -56,39 +57,40 @@ class COCOMetric(Metric):
def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds
if bboxes[0].shape[1] != 6:
assert im_ids.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
idx = 0
bboxes, lods = bboxes
for i, (im_id, lod) in enumerate(zip(im_ids, lods[0])):
im_id = int(im_id)
for i in range(lod):
dt = bboxes[idx]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (self.clsid2catid[int(clsid)])
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': catid,
'bbox': bbox,
'score': score
}
self.states['bbox'].append(coco_res)
idx += 1
im_id = int(im_ids)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (self.clsid2catid[int(clsid)])
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': catid,
'bbox': bbox,
'score': score
}
self.bbox_results.append(coco_res)
def reset(self):
self.bbox_results = []
def accumulate(self):
if len(self.states['bbox']) == 0:
if len(self.bbox_results) == 0:
logger.warning("The number of valid bbox detected is zero.\n \
Please use reasonable model and check input data.\n \
stop COCOMetric accumulate!")
return [0.0]
with open(OUTFILE, 'w') as f:
json.dump(self.states['bbox'], f)
json.dump(self.bbox_results, f)
map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
# flush coco evaluation result
......@@ -98,10 +100,8 @@ class COCOMetric(Metric):
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None
from pycocotools.cocoeval import COCOeval
if coco_gt == None:
from pycocotools.coco import COCO
coco_gt = COCO(anno_file)
logger.info("Start evaluate...")
coco_dt = coco_gt.loadRes(jsonfile)
......
......@@ -17,6 +17,11 @@ from __future__ import absolute_import
import six
import abc
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric']
......@@ -32,15 +37,12 @@ class Metric(object):
m.accumulate()
"""
def __init__(self, **kwargs):
self.reset()
@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
self.states = {}
self.result = None
raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args, **kwargs):
......
......@@ -285,17 +285,11 @@ class StaticGraphAdapter(object):
compiled_prog, feed=feed,
fetch_list=fetch_list,
return_numpy=False)
# rets = [(np.array(v), v.recursive_sequence_lengths()) if v.lod() for v in rets]
np_rets = []
for ret in rets:
seq_len = ret.recursive_sequence_lengths()
if len(seq_len) == 0:
np_rets.append(np.array(ret))
else:
np_rets.append((np.array(ret), seq_len))
outputs = np_rets[:num_output]
labels = np_rets[num_output:num_output+num_label]
losses = np_rets[num_output+num_label:]
# LoDTensor cannot be fetch as numpy directly
rets = [np.array(v) for v in rets]
outputs = rets[:num_output]
labels = rets[num_output:num_output+num_label]
losses = rets[num_output+num_label:]
if self.mode == 'test':
return outputs
elif self.mode == 'eval':
......@@ -443,6 +437,8 @@ class DynamicGraphAdapter(object):
labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs[0], labels)
for metric in self.model._metrics:
metric.update([to_numpy(o) for o in outputs[1:]], labels)
return [to_numpy(o) for o in to_list(outputs[0])], \
[to_numpy(l) for l in losses]
......
......@@ -515,7 +515,7 @@ def main():
coco2017(FLAGS.data, 'val'),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size),
batch_size=1),
process_num=2, buffer_size=4)
if not os.path.exists('yolo_checkpoints'):
......@@ -536,14 +536,15 @@ def main():
metrics=COCOMetric(anno_path, with_background=False))
for e in range(epoch):
logger.info("======== train epoch {} ========".format(e))
run(model, train_loader)
model.save('yolo_checkpoints/{:02d}'.format(e))
# logger.info("======== train epoch {} ========".format(e))
# run(model, train_loader)
# model.save('yolo_checkpoints/{:02d}'.format(e))
logger.info("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval')
# should be called in fit()
for metric in model._metrics:
metric.accumulate()
metric.reset()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册