提交 6478fcce 编写于 作者: G guosheng

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into add-load-finetune

*.pyc
*.json
output*
*checkpoint*
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import six
import abc
import numpy as np
import paddle.fluid as fluid
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric', 'Accuracy']
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
"""
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
"""
@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args, **kwargs):
"""
Update states for metric
"""
raise NotImplementedError("function 'update' not implemented in {}.".format(self.__class__.__name__))
@abc.abstractmethod
def accumulate(self):
"""
Accumulates statistics, computes and returns the metric value
"""
raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))
def add_metric_op(self, pred, label):
"""
Add process op for metric in program
"""
return pred, label
class Accuracy(Metric):
"""
Encapsulates accuracy metric logic
"""
def __init__(self, topk=(1, ), *args, **kwargs):
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self.reset()
def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
return correct
def update(self, correct, *args, **kwargs):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
def reset(self):
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)
def accumulate(self):
res = []
for t, c in zip(self.total, self.count):
res.append(float(t) / c)
return res
......@@ -27,6 +27,7 @@ from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy, Input
from metrics import Accuracy
class SimpleImgConvPool(fluid.dygraph.Layer):
......@@ -145,36 +146,38 @@ def main():
parameter_list=model.parameters())
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), inputs, labels)
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
if FLAGS.resume is not None:
model.load(FLAGS.resume)
for e in range(FLAGS.epoch):
train_loss = 0.0
train_acc = 0.0
val_loss = 0.0
val_acc = 0.0
print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(batch[0], batch[1])
losses, metrics = model.train(batch[0], batch[1])
acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses)
train_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, train_loss / (idx + 1), train_acc / (idx + 1)))
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, train_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()
print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(batch[0], batch[1])
losses, metrics = model.eval(batch[0], batch[1])
acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses)
val_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, val_loss / (idx + 1), val_acc / (idx + 1)))
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, val_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()
model.save('mnist_checkpoints/{:02d}'.format(e))
......
......@@ -27,6 +27,7 @@ from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
from metrics import Metric
__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
......@@ -47,6 +48,26 @@ def to_numpy(var):
return np.array(t)
def flatten_list(l):
assert isinstance(l, list), "not a list"
outl = []
splits = []
for sl in l:
assert isinstance(sl, list), "sub content not a list"
splits.append(len(sl))
outl += sl
return outl, splits
def restore_flatten_list(l, splits):
outl = []
for split in splits:
assert len(l) >= split, "list length invalid"
sl, l = l[:split], l[split:]
outl.append(sl)
return outl
def extract_args(func):
if hasattr(inspect, 'getfullargspec'):
return inspect.getfullargspec(func)[0]
......@@ -56,6 +77,7 @@ def extract_args(func):
class Input(fluid.dygraph.Layer):
def __init__(self, shape=None, dtype=None, name=None):
super(Input, self).__init__()
self.shape = shape
self.dtype = dtype
self.name = name
......@@ -316,15 +338,26 @@ class StaticGraphAdapter(object):
feed[v.name] = labels[idx]
endpoints = self._endpoints[self.mode]
fetch_list = endpoints['output'] + endpoints['loss']
num_output = len(endpoints['output'])
out = self._executor.run(compiled_prog,
feed=feed,
fetch_list=fetch_list)
if self.mode == 'test':
return out[:num_output]
fetch_list = endpoints['output']
else:
return out[:num_output], out[num_output:]
metric_list, metric_splits = flatten_list(endpoints['metric'])
fetch_list = endpoints['loss'] + metric_list
num_loss = len(endpoints['loss'])
rets = self._executor.run(compiled_prog,
feed=feed,
fetch_list=fetch_list,
return_numpy=False)
# LoDTensor cannot be fetch as numpy directly
rets = [np.array(v) for v in rets]
if self.mode == 'test':
return rets[:]
losses = rets[:num_loss]
metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
metrics = []
for metric, state in zip(self.model._metrics, metric_states):
metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses
def prepare(self):
modes = ['train', 'eval', 'test']
......@@ -352,6 +385,7 @@ class StaticGraphAdapter(object):
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
losses = []
metrics = []
with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n] \
......@@ -365,6 +399,9 @@ class StaticGraphAdapter(object):
if mode != 'test':
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
for metric in self.model._metrics:
metrics.append(
to_list(metric.add_metric_op(outputs, labels)))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
self.model._optimizer.minimize(self._loss_endpoint)
......@@ -374,7 +411,11 @@ class StaticGraphAdapter(object):
self._input_vars[mode] = inputs
self._label_vars[mode] = labels
self._progs[mode] = prog
self._endpoints[mode] = {"output": outputs, "loss": losses}
self._endpoints[mode] = {
"output": outputs,
"loss": losses,
"metric": metrics
}
def _compile_and_initialize(self, prog, mode):
compiled_prog = self._compiled_progs.get(mode, None)
......@@ -436,33 +477,46 @@ class DynamicGraphAdapter(object):
self.mode = 'train'
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(* [to_variable(x) for x in inputs])
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(
self.model.forward(* [to_variable(x) for x in inputs]))
losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
self.model._optimizer.minimize(final_loss)
self.model.clear_gradients()
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, to_list(labels))
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def eval(self, inputs, labels=None):
super(Model, self.model).eval()
self.mode = 'eval'
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(* [to_variable(x) for x in inputs])
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(
self.model.forward(* [to_variable(x) for x in inputs]))
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
else:
losses = []
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, labels)
m = metric.update(* [to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
# To be consistent with static graph
# return empty loss if loss_function is None
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def test(self, inputs):
super(Model, self.model).eval()
......@@ -657,6 +711,7 @@ class Model(fluid.dygraph.Layer):
def prepare(self,
optimizer=None,
loss_function=None,
metrics=None,
inputs=None,
labels=None,
device=None,
......@@ -670,6 +725,8 @@ class Model(fluid.dygraph.Layer):
loss_function (Loss|None): loss function must be set in training
and should be a Loss instance. It can be None when there is
no loss.
metrics (Metric|list of Metric|None): if metrics is set, all
metric will be calculate and output in train/eval mode.
inputs (Input|list|dict|None): inputs, entry points of network,
could be a Input layer, or lits of Input layers,
or dict (name: Input), or None. For static graph,
......@@ -705,6 +762,13 @@ class Model(fluid.dygraph.Layer):
"'inputs' must be list or dict in static graph mode")
if loss_function and not isinstance(labels, (list, Input)):
raise TypeError("'labels' must be list in static graph mode")
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(metric.__class__.__name__)
self._metrics = to_list(metrics)
self._inputs = inputs
self._labels = labels
self._device = device
......
......@@ -33,9 +33,14 @@ from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from model import Model, Loss, shape_hints
from model import Model, Loss, Input
from resnet import ResNet, ConvBNLayer
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
# XXX transfer learning
class ResNetBackBone(ResNet):
......@@ -102,13 +107,14 @@ class YoloDetectionBlock(fluid.dygraph.Layer):
class YOLOv3(Model):
def __init__(self):
def __init__(self, num_classes=80):
super(YOLOv3, self).__init__()
self.num_classes = 80
self.num_classes = num_classes
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.valid_thresh = 0.005
self.nms_thresh = 0.45
self.nms_topk = 400
self.nms_posk = 100
self.draw_thresh = 0.5
......@@ -146,8 +152,7 @@ class YOLOv3(Model):
act='leaky_relu'))
self.route_blocks.append(route)
@shape_hints(inputs=[None, 3, None, None], im_shape=[None, 2])
def forward(self, inputs, im_shape):
def forward(self, inputs, img_info):
outputs = []
boxes = []
scores = []
......@@ -161,48 +166,50 @@ class YOLOv3(Model):
feat = fluid.layers.concat(input=[route, feat], axis=1)
route, tip = self.yolo_blocks[idx](feat)
block_out = self.block_outputs[idx](tip)
outputs.append(block_out)
if idx < 2:
route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2)
anchor_mask = self.anchor_masks[idx]
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.num_classes,
conf_thresh=self.valid_thresh,
downsample_ratio=downsample)
outputs.append(block_out)
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
if self.mode == 'test':
anchor_mask = self.anchor_masks[idx]
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=img_shape,
anchors=mask_anchors,
class_num=self.num_classes,
conf_thresh=self.valid_thresh,
downsample_ratio=downsample)
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
downsample //= 2
if self.mode != 'test':
return outputs
return fluid.layers.multiclass_nms(
return [img_id, fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
background_label=-1)]
class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50):
def __init__(self, num_classes=80):
super(YoloLoss, self).__init__()
self.num_classes = num_classes
self.num_max_boxes = num_max_boxes
self.ignore_thresh = 0.7
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
......@@ -226,20 +233,11 @@ class YoloLoss(Loss):
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
use_label_smooth=True)
loss = fluid.layers.reduce_mean(loss)
losses.append(loss)
downsample //= 2
return losses
def infer_shape(self, _):
return [
[None, self.num_max_boxes, 4],
[None, self.num_max_boxes],
[None, self.num_max_boxes]
]
def infer_dtype(self, _):
return ['float32', 'int32', 'float32']
def make_optimizer(parameter_list=None):
base_lr = FLAGS.lr
......@@ -293,7 +291,7 @@ def random_crop(inputs):
thresholds = [.0, .1, .3, .5, .7, .9]
scaling = [.3, 1.]
img, gt_box, gt_label = inputs
img, img_ids, gt_box, gt_label = inputs
h, w = img.shape[:2]
if len(gt_box) == 0:
......@@ -327,7 +325,7 @@ def random_crop(inputs):
img = img[y1:y2, x1:x2, :]
gt_box = np.take(cropped_box, valid_ids, axis=0)
gt_label = np.take(gt_label, valid_ids, axis=0)
return img, gt_box, gt_label
return img, img_ids, gt_box, gt_label
return inputs
......@@ -335,9 +333,9 @@ def random_crop(inputs):
# XXX mix up, color distort and random expand are skipped for simplicity
def sample_transform(inputs, mode='train', num_max_boxes=50):
if mode == 'train':
img, gt_box, gt_label = random_crop(inputs)
img, img_id, gt_box, gt_label = random_crop(inputs)
else:
img, gt_box, gt_label = inputs
img, img_id, gt_box, gt_label = inputs
h, w = img.shape[:2]
# random flip
......@@ -350,7 +348,7 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
if len(gt_label) == 0:
gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
gt_label = np.zeros([num_max_boxes, 1], dtype=np.int32)
gt_label = np.zeros([num_max_boxes], dtype=np.int32)
return img, gt_box, gt_label
gt_box = gt_box[:num_max_boxes, :]
......@@ -362,9 +360,9 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
pad = num_max_boxes - gt_label.size
gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
gt_label = np.pad(gt_label, [(0, pad)], mode='constant')
gt_label = np.pad(gt_label, ((0, pad)), mode='constant')
return img, gt_box, gt_label
return img, img_id, gt_box, gt_label
def batch_transform(batch, mode='train'):
......@@ -376,7 +374,8 @@ def batch_transform(batch, mode='train'):
d = 608
interp = cv2.INTER_CUBIC
# transpose batch
imgs, gt_boxes, gt_labels = list(zip(*batch))
imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch))
img_shapes = np.array([[im.shape[0], im.shape[1]] for im in imgs]).astype('int32')
imgs = np.array([cv2.resize(
img, (d, d), interpolation=interp) for img in imgs])
......@@ -389,12 +388,13 @@ def batch_transform(batch, mode='train'):
imgs *= invstd
imgs = imgs.transpose((0, 3, 1, 2))
im_shapes = np.full([len(imgs), 2], d, dtype=np.int32)
img_ids = np.array(img_ids)
img_info = np.concatenate([img_ids, img_shapes], axis=1)
gt_boxes = np.array(gt_boxes)
gt_labels = np.array(gt_labels)
# XXX since mix up is not used, scores are all ones
gt_scores = np.ones_like(gt_labels, dtype=np.float32)
return [imgs, im_shapes], [gt_boxes, gt_labels, gt_scores]
return [imgs, img_info], [gt_boxes, gt_labels, gt_scores]
def coco2017(root_dir, mode='train'):
......@@ -434,17 +434,18 @@ def coco2017(root_dir, mode='train'):
gt_box = np.array(gt_box, dtype=np.float32)
gt_label = np.array([class_map[cls] for cls in gt_label],
dtype=np.int32)[:, np.newaxis]
im_id = np.array([img['id']], dtype=np.int32)
if gt_label.size == 0 and not mode == 'train':
continue
samples.append((file_path, gt_box.copy(), gt_label.copy()))
samples.append((file_path, im_id.copy(), gt_box.copy(), gt_label.copy()))
def iterator():
if mode == 'train':
random.shuffle(samples)
for file_path, gt_box, gt_label in samples:
np.random.shuffle(samples)
for file_path, im_id, gt_box, gt_label in samples:
img = cv2.imread(file_path)
yield img, gt_box, gt_label
yield img, im_id, gt_box, gt_label
return iterator
......@@ -457,14 +458,13 @@ def run(model, loader, mode='train'):
start = time.time()
for idx, batch in enumerate(loader()):
outputs, losses = getattr(model, mode)(
batch[0], batch[1], device='gpu', device_ids=device_ids)
losses = getattr(model, mode)(batch[0], batch[1])
total_loss += np.sum(losses)
if idx > 1: # skip first two steps
total_time += time.time() - start
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} time: {:0.3f}".format(
logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
start = time.time()
......@@ -501,26 +501,46 @@ def main():
coco2017(FLAGS.data, 'val'),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size),
batch_size=1),
process_num=2, buffer_size=4)
if not os.path.exists('yolo_checkpoints'):
os.mkdir('yolo_checkpoints')
with guard:
model = YOLOv3()
NUM_CLASSES = 7
NUM_MAX_BOXES = 50
model = YOLOv3(num_classes=NUM_CLASSES)
# XXX transfer learning
if FLAGS.pretrain_weights is not None:
model.backbone.load(FLAGS.pretrain_weights)
if FLAGS.weights is not None:
model.backbone.load(FLAGS.weights)
model.load(FLAGS.weights)
optim = make_optimizer(parameter_list=model.parameters())
model.prepare(optim, YoloLoss())
anno_path = os.path.join(FLAGS.data, 'annotations', 'instances_val2017.json')
inputs = [Input([None, 3, None, None], 'float32', name='image'),
Input([None, 3], 'int32', name='img_info')]
labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
model.prepare(optim,
YoloLoss(num_classes=NUM_CLASSES),
# For YOLOv3, output variable in train/eval is different,
# which is not supported by metric, add by callback later?
# metrics=COCOMetric(anno_path, with_background=False)
inputs=inputs,
labels = labels)
for e in range(epoch):
print("======== train epoch {} ========".format(e))
logger.info("======== train epoch {} ========".format(e))
run(model, train_loader)
model.save('yolo_checkpoints/{:02d}'.format(e))
print("======== eval epoch {} ========".format(e))
logger.info("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval')
# should be called in fit()
for metric in model._metrics:
metric.accumulate()
metric.reset()
if __name__ == '__main__':
......@@ -538,8 +558,11 @@ if __name__ == '__main__':
parser.add_argument(
"-n", "--num_devices", default=8, type=int, help="number of devices")
parser.add_argument(
"-w", "--weights", default=None, type=str,
"-p", "--pretrain_weights", default=None, type=str,
help="path to pretrained weights")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to model weights")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from metrics import Metric
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['COCOMetric']
OUTFILE = './bbox.json'
# considered to change to a callback later
class COCOMetric(Metric):
"""
Metrci for MS-COCO dataset, only support update with batch
size as 1.
Args:
anno_path(str): path to COCO annotation json file
with_background(bool): whether load category id with
background as 0, default True
"""
def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs)
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
self.coco_gt = COCO(anno_path)
cat_ids = self.coco_gt.getCatIds()
self.clsid2catid = dict(
{i + int(with_background): catid
for i, catid in enumerate(cat_ids)})
def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds
assert im_ids.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
im_id = int(im_ids)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (self.clsid2catid[int(clsid)])
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': catid,
'bbox': bbox,
'score': score
}
self.bbox_results.append(coco_res)
def reset(self):
self.bbox_results = []
def accumulate(self):
if len(self.bbox_results) == 0:
logger.warning("The number of valid bbox detected is zero.\n \
Please use reasonable model and check input data.\n \
stop COCOMetric accumulate!")
return [0.0]
with open(OUTFILE, 'w') as f:
json.dump(self.bbox_results, f)
map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
# flush coco evaluation result
sys.stdout.flush()
self.result = map_stats[0]
return self.result
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None
if coco_gt == None:
coco_gt = COCO(anno_file)
logger.info("Start evaluate...")
coco_dt = coco_gt.loadRes(jsonfile)
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval.stats
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册