提交 5ec5f453 编写于 作者: D dengkaipeng

fit for dygraph

上级 074a08e5
...@@ -16,6 +16,8 @@ from __future__ import absolute_import ...@@ -16,6 +16,8 @@ from __future__ import absolute_import
import sys import sys
import json import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from .metric import Metric from .metric import Metric
...@@ -32,22 +34,21 @@ OUTFILE = './bbox.json' ...@@ -32,22 +34,21 @@ OUTFILE = './bbox.json'
class COCOMetric(Metric): class COCOMetric(Metric):
""" """
Base class for metric, encapsulates metric logic and APIs Metrci for MS-COCO dataset, only support update with batch
size as 1.
Usage: Args:
m = SomeMetric() anno_path(str): path to COCO annotation json file
for prediction, label in ...: with_background(bool): whether load category id with
m.update(prediction, label) background as 0, default True
m.accumulate()
""" """
def __init__(self, anno_path, with_background=True, **kwargs): def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs) super(COCOMetric, self).__init__(**kwargs)
self.states['bbox'] = []
self.anno_path = anno_path self.anno_path = anno_path
self.with_background = with_background self.with_background = with_background
self.bbox_results = []
from pycocotools.coco import COCO
self.coco_gt = COCO(anno_path) self.coco_gt = COCO(anno_path)
cat_ids = self.coco_gt.getCatIds() cat_ids = self.coco_gt.getCatIds()
self.clsid2catid = dict( self.clsid2catid = dict(
...@@ -56,39 +57,40 @@ class COCOMetric(Metric): ...@@ -56,39 +57,40 @@ class COCOMetric(Metric):
def update(self, preds, *args, **kwargs): def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds im_ids, bboxes = preds
if bboxes[0].shape[1] != 6: assert im_ids.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch # no bbox detected in this batch
return return
idx = 0 im_id = int(im_ids)
bboxes, lods = bboxes for i in range(bboxes.shape[0]):
for i, (im_id, lod) in enumerate(zip(im_ids, lods[0])): dt = bboxes[i, :]
im_id = int(im_id) clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
for i in range(lod): catid = (self.clsid2catid[int(clsid)])
dt = bboxes[idx]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist() w = xmax - xmin + 1
catid = (self.clsid2catid[int(clsid)]) h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
w = xmax - xmin + 1 coco_res = {
h = ymax - ymin + 1 'image_id': im_id,
bbox = [xmin, ymin, w, h] 'category_id': catid,
coco_res = { 'bbox': bbox,
'image_id': im_id, 'score': score
'category_id': catid, }
'bbox': bbox, self.bbox_results.append(coco_res)
'score': score
} def reset(self):
self.states['bbox'].append(coco_res) self.bbox_results = []
idx += 1
def accumulate(self): def accumulate(self):
if len(self.states['bbox']) == 0: if len(self.bbox_results) == 0:
logger.warning("The number of valid bbox detected is zero.\n \ logger.warning("The number of valid bbox detected is zero.\n \
Please use reasonable model and check input data.\n \ Please use reasonable model and check input data.\n \
stop COCOMetric accumulate!") stop COCOMetric accumulate!")
return [0.0] return [0.0]
with open(OUTFILE, 'w') as f: with open(OUTFILE, 'w') as f:
json.dump(self.states['bbox'], f) json.dump(self.bbox_results, f)
map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt) map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
# flush coco evaluation result # flush coco evaluation result
...@@ -98,10 +100,8 @@ class COCOMetric(Metric): ...@@ -98,10 +100,8 @@ class COCOMetric(Metric):
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None): def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None assert coco_gt != None or anno_file != None
from pycocotools.cocoeval import COCOeval
if coco_gt == None: if coco_gt == None:
from pycocotools.coco import COCO
coco_gt = COCO(anno_file) coco_gt = COCO(anno_file)
logger.info("Start evaluate...") logger.info("Start evaluate...")
coco_dt = coco_gt.loadRes(jsonfile) coco_dt = coco_gt.loadRes(jsonfile)
......
...@@ -17,6 +17,11 @@ from __future__ import absolute_import ...@@ -17,6 +17,11 @@ from __future__ import absolute_import
import six import six
import abc import abc
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric'] __all__ = ['Metric']
...@@ -32,15 +37,12 @@ class Metric(object): ...@@ -32,15 +37,12 @@ class Metric(object):
m.accumulate() m.accumulate()
""" """
def __init__(self, **kwargs): @abc.abstractmethod
self.reset()
def reset(self): def reset(self):
""" """
Reset states and result Reset states and result
""" """
self.states = {} raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))
self.result = None
@abc.abstractmethod @abc.abstractmethod
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
......
...@@ -285,17 +285,11 @@ class StaticGraphAdapter(object): ...@@ -285,17 +285,11 @@ class StaticGraphAdapter(object):
compiled_prog, feed=feed, compiled_prog, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
return_numpy=False) return_numpy=False)
# rets = [(np.array(v), v.recursive_sequence_lengths()) if v.lod() for v in rets] # LoDTensor cannot be fetch as numpy directly
np_rets = [] rets = [np.array(v) for v in rets]
for ret in rets: outputs = rets[:num_output]
seq_len = ret.recursive_sequence_lengths() labels = rets[num_output:num_output+num_label]
if len(seq_len) == 0: losses = rets[num_output+num_label:]
np_rets.append(np.array(ret))
else:
np_rets.append((np.array(ret), seq_len))
outputs = np_rets[:num_output]
labels = np_rets[num_output:num_output+num_label]
losses = np_rets[num_output+num_label:]
if self.mode == 'test': if self.mode == 'test':
return outputs return outputs
elif self.mode == 'eval': elif self.mode == 'eval':
...@@ -443,6 +437,8 @@ class DynamicGraphAdapter(object): ...@@ -443,6 +437,8 @@ class DynamicGraphAdapter(object):
labels = to_list(labels) labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs]) outputs = self.model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs[0], labels) losses = self.model._loss_function(outputs[0], labels)
for metric in self.model._metrics:
metric.update([to_numpy(o) for o in outputs[1:]], labels)
return [to_numpy(o) for o in to_list(outputs[0])], \ return [to_numpy(o) for o in to_list(outputs[0])], \
[to_numpy(l) for l in losses] [to_numpy(l) for l in losses]
......
...@@ -515,7 +515,7 @@ def main(): ...@@ -515,7 +515,7 @@ def main():
coco2017(FLAGS.data, 'val'), coco2017(FLAGS.data, 'val'),
process_num=8, process_num=8,
buffer_size=4 * batch_size), buffer_size=4 * batch_size),
batch_size=batch_size), batch_size=1),
process_num=2, buffer_size=4) process_num=2, buffer_size=4)
if not os.path.exists('yolo_checkpoints'): if not os.path.exists('yolo_checkpoints'):
...@@ -536,14 +536,15 @@ def main(): ...@@ -536,14 +536,15 @@ def main():
metrics=COCOMetric(anno_path, with_background=False)) metrics=COCOMetric(anno_path, with_background=False))
for e in range(epoch): for e in range(epoch):
logger.info("======== train epoch {} ========".format(e)) # logger.info("======== train epoch {} ========".format(e))
run(model, train_loader) # run(model, train_loader)
model.save('yolo_checkpoints/{:02d}'.format(e)) # model.save('yolo_checkpoints/{:02d}'.format(e))
logger.info("======== eval epoch {} ========".format(e)) logger.info("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval') run(model, val_loader, mode='eval')
# should be called in fit() # should be called in fit()
for metric in model._metrics: for metric in model._metrics:
metric.accumulate() metric.accumulate()
metric.reset()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册