Unverified commit 903d0f7e, authored by Kaipeng Deng, committed by GitHub

Merge pull request #2 from heavengate/add_metric

add metric for mnist.
*.pyc
*.json
output*
*checkpoint*
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import six
import abc
import numpy as np
import paddle.fluid as fluid
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric', 'Accuracy']
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
    """
    Base class for metric, encapsulates metric logic and APIs
    Usage:
        m = SomeMetric()
        for prediction, label in ...:
            m.update(prediction, label)
        m.accumulate()
    """

    @abc.abstractmethod
    def reset(self):
        """
        Reset states and result
        """
        raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))

    @abc.abstractmethod
    def update(self, *args, **kwargs):
        """
        Update states for metric
        """
        raise NotImplementedError("function 'update' not implemented in {}.".format(self.__class__.__name__))

    @abc.abstractmethod
    def accumulate(self):
        """
        Accumulates statistics, computes and returns the metric value
        """
        raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))

    def add_metric_op(self, pred, label):
        """
        Add ops that compute the metric states from network predictions
        and labels; the returned variables are later fed to `update`.
        """
        return pred, label


class Accuracy(Metric):
    """
    Encapsulates accuracy metric logic
    """

    def __init__(self, topk=(1, ), *args, **kwargs):
        super(Accuracy, self).__init__(*args, **kwargs)
        self.topk = topk
        self.maxk = max(topk)
        self.reset()

    def add_metric_op(self, pred, label, *args, **kwargs):
        pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
        correct = pred == label[0]
        return correct

    def update(self, correct, *args, **kwargs):
        accs = []
        for i, k in enumerate(self.topk):
            num_corrects = correct[:, :k].sum()
            num_samples = len(correct)
            accs.append(float(num_corrects) / num_samples)
            self.total[i] += num_corrects
            self.count[i] += num_samples
        return accs

    def reset(self):
        self.total = [0.] * len(self.topk)
        self.count = [0] * len(self.topk)

    def accumulate(self):
        res = []
        for t, c in zip(self.total, self.count):
            res.append(float(t) / c)
        return res
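For reference, a minimal sketch of the intended metric lifecycle (reset, update per batch, accumulate per epoch). It uses plain numpy in place of the fluid ops that add_metric_op would normally insert into the program; the random data and the numpy top-k computation are illustrative only.

import numpy as np
from metrics import Accuracy

pred = np.random.rand(4, 10).astype('float32')             # stand-in for network scores
label = np.random.randint(0, 10, (4, 1)).astype('int64')   # stand-in for labels

acc = Accuracy(topk=(1, 2))
# numpy stand-in for add_metric_op: compare top-maxk predicted indices with the label
correct = np.argsort(-pred, axis=1)[:, :acc.maxk] == label
batch_acc = acc.update(correct)   # per-batch top-1 / top-2 accuracy
epoch_acc = acc.accumulate()      # accuracy over all batches seen since reset()
acc.reset()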
@@ -27,6 +27,7 @@ from paddle.fluid.optimizer import Momentum
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from model import Model, CrossEntropy, Input
+from metrics import Accuracy
 
 class SimpleImgConvPool(fluid.dygraph.Layer):
@@ -145,36 +146,38 @@ def main():
         parameter_list=model.parameters())
 
     inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
     labels = [Input([None, 1], 'int64', name='label')]
-    model.prepare(optim, CrossEntropy(), inputs, labels)
+    model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
     if FLAGS.resume is not None:
         model.load(FLAGS.resume)
 
     for e in range(FLAGS.epoch):
         train_loss = 0.0
-        train_acc = 0.0
         val_loss = 0.0
-        val_acc = 0.0
         print("======== train epoch {} ========".format(e))
         for idx, batch in enumerate(train_loader()):
-            outputs, losses = model.train(batch[0], batch[1])
-            acc = accuracy(outputs[0], batch[1])[0]
+            losses, metrics = model.train(batch[0], batch[1])
             train_loss += np.sum(losses)
-            train_acc += acc
             if idx % 10 == 0:
-                print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
-                    idx, train_loss / (idx + 1), train_acc / (idx + 1)))
+                print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
+                    idx, train_loss / (idx + 1), metrics[0][0], metrics[0][1]))
+        for metric in model._metrics:
+            res = metric.accumulate()
+            print("train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
+            metric.reset()
 
         print("======== eval epoch {} ========".format(e))
         for idx, batch in enumerate(val_loader()):
-            outputs, losses = model.eval(batch[0], batch[1])
-            acc = accuracy(outputs[0], batch[1])[0]
+            losses, metrics = model.eval(batch[0], batch[1])
            val_loss += np.sum(losses)
-            val_acc += acc
            if idx % 10 == 0:
-                print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
-                    idx, val_loss / (idx + 1), val_acc / (idx + 1)))
+                print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
+                    idx, val_loss / (idx + 1), metrics[0][0], metrics[0][1]))
+        for metric in model._metrics:
+            res = metric.accumulate()
+            print("eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
+            metric.reset()
 
         model.save('mnist_checkpoints/{:02d}'.format(e))
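To make the indexing in the training loop above explicit: model.train now returns (losses, metrics), where metrics holds one entry per metric passed to prepare, and each Accuracy entry is a list aligned with its topk tuple. A small sketch of the shapes, assuming the Accuracy(topk=(1, 2)) configured above:

losses, metrics = model.train(batch[0], batch[1])
# losses   -> [cross_entropy_loss]     one numpy value per loss term
# metrics  -> [[top1_acc, top2_acc]]   one list per metric instance
top1, top2 = metrics[0]                # batch accuracies for topk=(1, 2)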
@@ -26,6 +26,7 @@ from paddle.fluid.framework import in_dygraph_mode, Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.io import is_belong_to_optimizer
 from paddle.fluid.dygraph.base import to_variable
+from metrics import Metric
 
 __all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
@@ -46,6 +47,26 @@ def to_numpy(var):
     return np.array(t)
 
 
+def flatten_list(l):
+    assert isinstance(l, list), "not a list"
+    outl = []
+    splits = []
+    for sl in l:
+        assert isinstance(sl, list), "sub content not a list"
+        splits.append(len(sl))
+        outl += sl
+    return outl, splits
+
+
+def restore_flatten_list(l, splits):
+    outl = []
+    for split in splits:
+        assert len(l) >= split, "list length invalid"
+        sl, l = l[:split], l[split:]
+        outl.append(sl)
+    return outl
+
+
 def extract_args(func):
     if hasattr(inspect, 'getfullargspec'):
         return inspect.getfullargspec(func)[0]
@@ -309,15 +330,26 @@ class StaticGraphAdapter(object):
             feed[v.name] = labels[idx]
 
         endpoints = self._endpoints[self.mode]
-        fetch_list = endpoints['output'] + endpoints['loss']
-        num_output = len(endpoints['output'])
-        out = self._executor.run(compiled_prog,
-                                 feed=feed,
-                                 fetch_list=fetch_list)
         if self.mode == 'test':
-            return out[:num_output]
+            fetch_list = endpoints['output']
         else:
-            return out[:num_output], out[num_output:]
+            metric_list, metric_splits = flatten_list(endpoints['metric'])
+            fetch_list = endpoints['loss'] + metric_list
+        num_loss = len(endpoints['loss'])
+        rets = self._executor.run(
+            compiled_prog, feed=feed,
+            fetch_list=fetch_list,
+            return_numpy=False)
+        # LoDTensor cannot be fetched as numpy directly
+        rets = [np.array(v) for v in rets]
+        if self.mode == 'test':
+            return rets[:]
+        losses = rets[:num_loss]
+        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
+        metrics = []
+        for metric, state in zip(self.model._metrics, metric_states):
+            metrics.append(metric.update(*state))
+        return (losses, metrics) if len(metrics) > 0 else losses
 
     def prepare(self):
         modes = ['train', 'eval', 'test']
@@ -345,6 +377,7 @@ class StaticGraphAdapter(object):
             lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
             self.model._optimizer._learning_rate_map[prog] = lr_var
 
         losses = []
+        metrics = []
         with fluid.program_guard(prog, self._startup_prog):
             if isinstance(self.model._inputs, dict):
                 ins = [self.model._inputs[n] \
@@ -358,6 +391,8 @@ class StaticGraphAdapter(object):
             if mode != 'test':
                 if self.model._loss_function:
                     losses = self.model._loss_function(outputs, labels)
+                for metric in self.model._metrics:
+                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))
                 if mode == 'train' and self.model._optimizer:
                     self._loss_endpoint = fluid.layers.sum(losses)
                     self.model._optimizer.minimize(self._loss_endpoint)
@@ -367,7 +402,7 @@ class StaticGraphAdapter(object):
         self._input_vars[mode] = inputs
         self._label_vars[mode] = labels
         self._progs[mode] = prog
-        self._endpoints[mode] = {"output": outputs, "loss": losses}
+        self._endpoints[mode] = {"output": outputs, "loss": losses, "metric": metrics}
 
     def _compile_and_initialize(self, prog, mode):
         compiled_prog = self._compiled_progs.get(mode, None)
@@ -429,33 +464,44 @@ class DynamicGraphAdapter(object):
         self.mode = 'train'
         inputs = to_list(inputs)
         if labels is not None:
-            labels = to_list(labels)
-        outputs = self.model.forward(* [to_variable(x) for x in inputs])
+            labels = [to_variable(l) for l in to_list(labels)]
+        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
         losses = self.model._loss_function(outputs, labels)
         final_loss = fluid.layers.sum(losses)
         final_loss.backward()
         self.model._optimizer.minimize(final_loss)
         self.model.clear_gradients()
-        return [to_numpy(o) for o in to_list(outputs)], \
-            [to_numpy(l) for l in losses]
+        metrics = []
+        for metric in self.model._metrics:
+            metric_outs = metric.add_metric_op(outputs, to_list(labels))
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            metrics.append(m)
+        return ([to_numpy(l) for l in losses], metrics) \
+            if len(metrics) > 0 else [to_numpy(l) for l in losses]
 
     def eval(self, inputs, labels=None):
         super(Model, self.model).eval()
         self.mode = 'eval'
         inputs = to_list(inputs)
         if labels is not None:
-            labels = to_list(labels)
-        outputs = self.model.forward(* [to_variable(x) for x in inputs])
+            labels = [to_variable(l) for l in to_list(labels)]
+        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
         if self.model._loss_function:
             losses = self.model._loss_function(outputs, labels)
         else:
             losses = []
+        metrics = []
+        for metric in self.model._metrics:
+            metric_outs = metric.add_metric_op(outputs, labels)
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            metrics.append(m)
+
         # To be consistent with static graph
         # return empty loss if loss_function is None
-        return [to_numpy(o) for o in to_list(outputs)], \
-            [to_numpy(l) for l in losses]
+        return ([to_numpy(l) for l in losses], metrics) \
+            if len(metrics) > 0 else [to_numpy(l) for l in losses]
 
     def test(self, inputs):
         super(Model, self.model).eval()
@@ -567,6 +613,7 @@ class Model(fluid.dygraph.Layer):
     def prepare(self,
                 optimizer=None,
                 loss_function=None,
+                metrics=None,
                 inputs=None,
                 labels=None,
                 device=None,
@@ -580,6 +627,8 @@ class Model(fluid.dygraph.Layer):
             loss_function (Loss|None): loss function must be set in training
                 and should be a Loss instance. It can be None when there is
                 no loss.
+            metrics (Metric|list of Metric|None): if metrics is set, all
+                metrics will be calculated and reported in train/eval mode.
             inputs (Input|list|dict|None): inputs, entry points of network,
                 could be an Input layer, a list of Input layers,
                 a dict (name: Input), or None. For static graph,
@@ -615,6 +664,13 @@ class Model(fluid.dygraph.Layer):
                     "'inputs' must be list or dict in static graph mode")
             if loss_function and not isinstance(labels, (list, Input)):
                 raise TypeError("'labels' must be list in static graph mode")
+
+        metrics = metrics or []
+        for metric in to_list(metrics):
+            assert isinstance(metric, Metric), \
+                "{} is not sub class of Metric".format(metric.__class__.__name__)
+        self._metrics = to_list(metrics)
+
         self._inputs = inputs
         self._labels = labels
         self._device = device
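The flatten_list / restore_flatten_list helpers introduced above exist because the executor needs a flat fetch list, while each metric can contribute several state variables. A minimal round-trip sketch in plain Python (assuming the two helpers above are in scope; the strings stand in for fetched tensors):

metric_states = [['m0_correct'], ['m1_ids', 'm1_bboxes']]   # per-metric fetch targets
flat, splits = flatten_list(metric_states)
# flat   -> ['m0_correct', 'm1_ids', 'm1_bboxes']   (what actually gets fetched)
# splits -> [1, 2]                                  (how to regroup afterwards)
restored = restore_flatten_list(flat, splits)
assert restored == metric_states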
@@ -33,9 +33,14 @@ from paddle.fluid.dygraph.nn import Conv2D
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.regularizer import L2Decay
 
-from model import Model, Loss, shape_hints
+from model import Model, Loss, Input
 from resnet import ResNet, ConvBNLayer
 
+import logging
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
 # XXX transfer learning
 
 class ResNetBackBone(ResNet):
@@ -102,13 +107,14 @@ class YoloDetectionBlock(fluid.dygraph.Layer):
 
 class YOLOv3(Model):
-    def __init__(self):
+    def __init__(self, num_classes=80):
         super(YOLOv3, self).__init__()
-        self.num_classes = 80
+        self.num_classes = num_classes
         self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
                         59, 119, 116, 90, 156, 198, 373, 326]
         self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
         self.valid_thresh = 0.005
+        self.nms_thresh = 0.45
         self.nms_topk = 400
         self.nms_posk = 100
         self.draw_thresh = 0.5
@@ -146,8 +152,7 @@ class YOLOv3(Model):
                     act='leaky_relu'))
             self.route_blocks.append(route)
 
-    @shape_hints(inputs=[None, 3, None, None], im_shape=[None, 2])
-    def forward(self, inputs, im_shape):
+    def forward(self, inputs, img_info):
         outputs = []
         boxes = []
         scores = []
@@ -161,48 +166,50 @@ class YOLOv3(Model):
             feat = fluid.layers.concat(input=[route, feat], axis=1)
             route, tip = self.yolo_blocks[idx](feat)
             block_out = self.block_outputs[idx](tip)
+            outputs.append(block_out)
 
             if idx < 2:
                 route = self.route_blocks[idx](route)
                 route = fluid.layers.resize_nearest(route, scale=2)
 
-            anchor_mask = self.anchor_masks[idx]
-            mask_anchors = []
-            for m in anchor_mask:
-                mask_anchors.append(self.anchors[2 * m])
-                mask_anchors.append(self.anchors[2 * m + 1])
-            b, s = fluid.layers.yolo_box(
-                x=block_out,
-                img_size=im_shape,
-                anchors=mask_anchors,
-                class_num=self.num_classes,
-                conf_thresh=self.valid_thresh,
-                downsample_ratio=downsample)
-
-            outputs.append(block_out)
-            boxes.append(b)
-            scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
+            if self.mode == 'test':
+                anchor_mask = self.anchor_masks[idx]
+                mask_anchors = []
+                for m in anchor_mask:
+                    mask_anchors.append(self.anchors[2 * m])
+                    mask_anchors.append(self.anchors[2 * m + 1])
+                img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
+                img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
+                b, s = fluid.layers.yolo_box(
+                    x=block_out,
+                    img_size=img_shape,
+                    anchors=mask_anchors,
+                    class_num=self.num_classes,
+                    conf_thresh=self.valid_thresh,
+                    downsample_ratio=downsample)
+                boxes.append(b)
+                scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
+
             downsample //= 2
 
         if self.mode != 'test':
             return outputs
 
-        return fluid.layers.multiclass_nms(
+        return [img_id, fluid.layers.multiclass_nms(
             bboxes=fluid.layers.concat(boxes, axis=1),
             scores=fluid.layers.concat(scores, axis=2),
             score_threshold=self.valid_thresh,
             nms_top_k=self.nms_topk,
             keep_top_k=self.nms_posk,
             nms_threshold=self.nms_thresh,
-            background_label=-1)
+            background_label=-1)]
 
 
 class YoloLoss(Loss):
-    def __init__(self, num_classes=80, num_max_boxes=50):
+    def __init__(self, num_classes=80):
         super(YoloLoss, self).__init__()
         self.num_classes = num_classes
-        self.num_max_boxes = num_max_boxes
         self.ignore_thresh = 0.7
         self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
                         59, 119, 116, 90, 156, 198, 373, 326]
@@ -226,20 +233,11 @@ class YoloLoss(Loss):
                 class_num=self.num_classes,
                 ignore_thresh=self.ignore_thresh,
                 use_label_smooth=True)
+            loss = fluid.layers.reduce_mean(loss)
             losses.append(loss)
             downsample //= 2
         return losses
 
-    def infer_shape(self, _):
-        return [
-            [None, self.num_max_boxes, 4],
-            [None, self.num_max_boxes],
-            [None, self.num_max_boxes]
-        ]
-
-    def infer_dtype(self, _):
-        return ['float32', 'int32', 'float32']
-
 
 def make_optimizer(parameter_list=None):
     base_lr = FLAGS.lr
@@ -293,7 +291,7 @@ def random_crop(inputs):
     thresholds = [.0, .1, .3, .5, .7, .9]
     scaling = [.3, 1.]
 
-    img, gt_box, gt_label = inputs
+    img, img_ids, gt_box, gt_label = inputs
     h, w = img.shape[:2]
 
     if len(gt_box) == 0:
@@ -327,7 +325,7 @@ def random_crop(inputs):
             img = img[y1:y2, x1:x2, :]
             gt_box = np.take(cropped_box, valid_ids, axis=0)
             gt_label = np.take(gt_label, valid_ids, axis=0)
-            return img, gt_box, gt_label
+            return img, img_ids, gt_box, gt_label
 
     return inputs
@@ -335,9 +333,9 @@ def random_crop(inputs):
 # XXX mix up, color distort and random expand are skipped for simplicity
 def sample_transform(inputs, mode='train', num_max_boxes=50):
     if mode == 'train':
-        img, gt_box, gt_label = random_crop(inputs)
+        img, img_id, gt_box, gt_label = random_crop(inputs)
     else:
-        img, gt_box, gt_label = inputs
+        img, img_id, gt_box, gt_label = inputs
 
     h, w = img.shape[:2]
     # random flip
@@ -350,7 +348,7 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
     if len(gt_label) == 0:
         gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
-        gt_label = np.zeros([num_max_boxes, 1], dtype=np.int32)
+        gt_label = np.zeros([num_max_boxes], dtype=np.int32)
         return img, gt_box, gt_label
 
     gt_box = gt_box[:num_max_boxes, :]
@@ -362,9 +360,9 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
     pad = num_max_boxes - gt_label.size
     gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
-    gt_label = np.pad(gt_label, [(0, pad)], mode='constant')
+    gt_label = np.pad(gt_label, ((0, pad)), mode='constant')
 
-    return img, gt_box, gt_label
+    return img, img_id, gt_box, gt_label
 
 
 def batch_transform(batch, mode='train'):
@@ -376,7 +374,8 @@ def batch_transform(batch, mode='train'):
         d = 608
         interp = cv2.INTER_CUBIC
 
     # transpose batch
-    imgs, gt_boxes, gt_labels = list(zip(*batch))
+    imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch))
+    img_shapes = np.array([[im.shape[0], im.shape[1]] for im in imgs]).astype('int32')
     imgs = np.array([cv2.resize(
         img, (d, d), interpolation=interp) for img in imgs])
@@ -389,12 +388,13 @@ def batch_transform(batch, mode='train'):
     imgs *= invstd
     imgs = imgs.transpose((0, 3, 1, 2))
 
-    im_shapes = np.full([len(imgs), 2], d, dtype=np.int32)
+    img_ids = np.array(img_ids)
+    img_info = np.concatenate([img_ids, img_shapes], axis=1)
     gt_boxes = np.array(gt_boxes)
     gt_labels = np.array(gt_labels)
     # XXX since mix up is not used, scores are all ones
     gt_scores = np.ones_like(gt_labels, dtype=np.float32)
-    return [imgs, im_shapes], [gt_boxes, gt_labels, gt_scores]
+    return [imgs, img_info], [gt_boxes, gt_labels, gt_scores]
 
 
 def coco2017(root_dir, mode='train'):
@@ -434,17 +434,18 @@ def coco2017(root_dir, mode='train'):
         gt_box = np.array(gt_box, dtype=np.float32)
         gt_label = np.array([class_map[cls] for cls in gt_label],
                             dtype=np.int32)[:, np.newaxis]
+        im_id = np.array([img['id']], dtype=np.int32)
         if gt_label.size == 0 and not mode == 'train':
             continue
-        samples.append((file_path, gt_box.copy(), gt_label.copy()))
+        samples.append((file_path, im_id.copy(), gt_box.copy(), gt_label.copy()))
 
     def iterator():
         if mode == 'train':
-            random.shuffle(samples)
-        for file_path, gt_box, gt_label in samples:
+            np.random.shuffle(samples)
+        for file_path, im_id, gt_box, gt_label in samples:
             img = cv2.imread(file_path)
-            yield img, gt_box, gt_label
+            yield img, im_id, gt_box, gt_label
 
     return iterator
@@ -457,14 +458,13 @@ def run(model, loader, mode='train'):
     start = time.time()
     for idx, batch in enumerate(loader()):
-        outputs, losses = getattr(model, mode)(
-            batch[0], batch[1], device='gpu', device_ids=device_ids)
+        losses = getattr(model, mode)(batch[0], batch[1])
 
         total_loss += np.sum(losses)
         if idx > 1:  # skip first two steps
             total_time += time.time() - start
         if idx % 10 == 0:
-            print("{:04d}: loss {:0.3f} time: {:0.3f}".format(
+            logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
                 idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
         start = time.time()
@@ -501,26 +501,46 @@ def main():
             coco2017(FLAGS.data, 'val'),
             process_num=8,
             buffer_size=4 * batch_size),
-        batch_size=batch_size),
+        batch_size=1),
         process_num=2, buffer_size=4)
 
     if not os.path.exists('yolo_checkpoints'):
         os.mkdir('yolo_checkpoints')
 
     with guard:
-        model = YOLOv3()
+        NUM_CLASSES = 7
+        NUM_MAX_BOXES = 50
+        model = YOLOv3(num_classes=NUM_CLASSES)
         # XXX transfer learning
+        if FLAGS.pretrain_weights is not None:
+            model.backbone.load(FLAGS.pretrain_weights)
         if FLAGS.weights is not None:
-            model.backbone.load(FLAGS.weights)
+            model.load(FLAGS.weights)
         optim = make_optimizer(parameter_list=model.parameters())
-        model.prepare(optim, YoloLoss())
+        anno_path = os.path.join(FLAGS.data, 'annotations', 'instances_val2017.json')
+        inputs = [Input([None, 3, None, None], 'float32', name='image'),
+                  Input([None, 3], 'int32', name='img_info')]
+        labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
+                  Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
+                  Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
+        model.prepare(optim,
+                      YoloLoss(num_classes=NUM_CLASSES),
+                      # For YOLOv3, output variable in train/eval is different,
+                      # which is not supported by metric, add by callback later?
+                      # metrics=COCOMetric(anno_path, with_background=False)
+                      inputs=inputs,
+                      labels=labels)
 
         for e in range(epoch):
-            print("======== train epoch {} ========".format(e))
+            logger.info("======== train epoch {} ========".format(e))
             run(model, train_loader)
             model.save('yolo_checkpoints/{:02d}'.format(e))
-            print("======== eval epoch {} ========".format(e))
+            logger.info("======== eval epoch {} ========".format(e))
            run(model, val_loader, mode='eval')
+            # should be called in fit()
+            for metric in model._metrics:
+                metric.accumulate()
+                metric.reset()
 
 
 if __name__ == '__main__':
@@ -538,8 +558,11 @@ if __name__ == '__main__':
     parser.add_argument(
         "-n", "--num_devices", default=8, type=int, help="number of devices")
     parser.add_argument(
-        "-w", "--weights", default=None, type=str,
+        "-p", "--pretrain_weights", default=None, type=str,
         help="path to pretrained weights")
+    parser.add_argument(
+        "-w", "--weights", default=None, type=str,
+        help="path to model weights")
     FLAGS = parser.parse_args()
     assert FLAGS.data, "error: must provide data path"
     main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from metrics import Metric
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['COCOMetric']
OUTFILE = './bbox.json'
# NOTE: this may be changed to a callback later
class COCOMetric(Metric):
    """
    Metric for the MS-COCO dataset, only supports update with batch
    size 1.

    Args:
        anno_path(str): path to COCO annotation json file
        with_background(bool): whether to load category ids with
            background as 0, default True
    """

    def __init__(self, anno_path, with_background=True, **kwargs):
        super(COCOMetric, self).__init__(**kwargs)
        self.anno_path = anno_path
        self.with_background = with_background
        self.bbox_results = []

        self.coco_gt = COCO(anno_path)
        cat_ids = self.coco_gt.getCatIds()
        self.clsid2catid = dict(
            {i + int(with_background): catid
             for i, catid in enumerate(cat_ids)})

    def update(self, preds, *args, **kwargs):
        im_ids, bboxes = preds
        assert im_ids.shape[0] == 1, \
            "COCOMetric can only update with batch size = 1"
        if bboxes.shape[1] != 6:
            # no bbox detected in this batch
            return

        im_id = int(im_ids)
        for i in range(bboxes.shape[0]):
            dt = bboxes[i, :]
            clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
            catid = (self.clsid2catid[int(clsid)])

            w = xmax - xmin + 1
            h = ymax - ymin + 1
            bbox = [xmin, ymin, w, h]
            coco_res = {
                'image_id': im_id,
                'category_id': catid,
                'bbox': bbox,
                'score': score
            }
            self.bbox_results.append(coco_res)

    def reset(self):
        self.bbox_results = []

    def accumulate(self):
        if len(self.bbox_results) == 0:
            logger.warning("The number of valid bboxes detected is zero.\n"
                           "Please use a reasonable model and check the input data.\n"
                           "Stopping COCOMetric accumulate!")
            return [0.0]
        with open(OUTFILE, 'w') as f:
            json.dump(self.bbox_results, f)

        map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
        # flush coco evaluation result
        sys.stdout.flush()
        self.result = map_stats[0]
        return self.result

    def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
        assert coco_gt is not None or anno_file is not None
        if coco_gt is None:
            coco_gt = COCO(anno_file)
        logger.info("Start evaluating...")
        coco_dt = coco_gt.loadRes(jsonfile)
        coco_eval = COCOeval(coco_gt, coco_dt, style)
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        return coco_eval.stats
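As a rough usage sketch of COCOMetric (the annotation path is illustrative, and the prediction format follows the update() signature above: preds is [im_ids, bboxes] for a single image, each bbox row being [class_id, score, xmin, ymin, xmax, ymax] as produced by multiclass_nms):

import numpy as np
# assuming the class above is importable, e.g. from a coco_metric module

metric = COCOMetric('annotations/instances_val2017.json', with_background=False)

im_ids = np.array([[42]], dtype=np.int32)                                # one image id
bboxes = np.array([[0, 0.92, 10., 20., 110., 220.]], dtype=np.float32)  # N x 6 detections

metric.update([im_ids, bboxes])   # collect results image by image (batch size 1)
# mAP = metric.accumulate()       # writes ./bbox.json and runs the COCO eval API
metric.reset()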