提交 90134440 编写于 作者: L LielinJiang

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into multiple-gpus-append-op

*.pyc
*.json
output*
*checkpoint*
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.(md|yml)$
- id: trailing-whitespace
files: \.(md|yml)$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.(md|yml)$
- id: remove-crlf
files: \.(md|yml)$
- id: forbid-tabs
files: \.(md|yml)$
- id: remove-tabs
files: \.(md|yml)$
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import six
import abc
import numpy as np
import paddle.fluid as fluid
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['Metric', 'Accuracy']
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
"""
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
"""
@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args, **kwargs):
"""
Update states for metric
"""
raise NotImplementedError("function 'update' not implemented in {}.".format(self.__class__.__name__))
@abc.abstractmethod
def accumulate(self):
"""
Accumulates statistics, computes and returns the metric value
"""
raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))
def add_metric_op(self, pred, label):
"""
Add process op for metric in program
"""
return pred, label
class Accuracy(Metric):
"""
Encapsulates accuracy metric logic
"""
def __init__(self, topk=(1, ), *args, **kwargs):
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self.reset()
def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
return correct
def update(self, correct, *args, **kwargs):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs
def reset(self):
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)
def accumulate(self):
res = []
for t, c in zip(self.total, self.count):
res.append(float(t) / c)
return res
...@@ -26,7 +26,8 @@ from paddle import fluid ...@@ -26,7 +26,8 @@ from paddle import fluid
from paddle.fluid.optimizer import Momentum from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy from model import Model, CrossEntropy, Input
from metrics import Accuracy
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
...@@ -78,7 +79,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer): ...@@ -78,7 +79,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
class MNIST(Model): class MNIST(Model):
def __init__(self): def __init__(self):
super(MNIST, self).__init__() super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConvPool( self._simple_img_conv_pool_1 = SimpleImgConvPool(
1, 20, 5, 2, 2, act="relu") 1, 20, 5, 2, 2, act="relu")
...@@ -88,12 +88,13 @@ class MNIST(Model): ...@@ -88,12 +88,13 @@ class MNIST(Model):
pool_2_shape = 50 * 4 * 4 pool_2_shape = 50 * 4 * 4
SIZE = 10 SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(800, self._fc = Linear(
10, 800,
param_attr=fluid.param_attr.ParamAttr( 10,
initializer=fluid.initializer.NormalInitializer( param_attr=fluid.param_attr.ParamAttr(
loc=0.0, scale=scale)), initializer=fluid.initializer.NormalInitializer(
act="softmax") loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs): def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
...@@ -137,44 +138,46 @@ def main(): ...@@ -137,44 +138,46 @@ def main():
paddle.batch(paddle.dataset.mnist.test(), paddle.batch(paddle.dataset.mnist.test(),
batch_size=FLAGS.batch_size, drop_last=True), 1, 1) batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
device_ids = list(range(FLAGS.num_devices))
with guard: with guard:
model = MNIST() model = MNIST()
optim = Momentum(learning_rate=FLAGS.lr, momentum=.9, optim = Momentum(
parameter_list=model.parameters()) learning_rate=FLAGS.lr,
model.prepare(optim, CrossEntropy()) momentum=.9,
parameter_list=model.parameters())
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
if FLAGS.resume is not None: if FLAGS.resume is not None:
model.load(FLAGS.resume) model.load(FLAGS.resume)
for e in range(FLAGS.epoch): for e in range(FLAGS.epoch):
train_loss = 0.0 train_loss = 0.0
train_acc = 0.0
val_loss = 0.0 val_loss = 0.0
val_acc = 0.0
print("======== train epoch {} ========".format(e)) print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()): for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(batch[0], batch[1], device='gpu', losses, metrics = model.train(batch[0], batch[1])
device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses) train_loss += np.sum(losses)
train_acc += acc
if idx % 10 == 0: if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format( print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, train_loss / (idx + 1), train_acc / (idx + 1))) idx, train_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()
print("======== eval epoch {} ========".format(e)) print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()): for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(batch[0], batch[1], device='gpu', losses, metrics = model.eval(batch[0], batch[1])
device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses) val_loss += np.sum(losses)
val_acc += acc
if idx % 10 == 0: if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format( print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, val_loss / (idx + 1), val_acc / (idx + 1))) idx, val_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()
model.save('mnist_checkpoints/{:02d}'.format(e)) model.save('mnist_checkpoints/{:02d}'.format(e))
...@@ -185,14 +188,21 @@ if __name__ == '__main__': ...@@ -185,14 +188,21 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"-e", "--epoch", default=100, type=int, help="number of epoch") "-e", "--epoch", default=100, type=int, help="number of epoch")
parser.add_argument( parser.add_argument(
'--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', '--lr',
'--learning-rate',
default=1e-3,
type=float,
metavar='LR',
help='initial learning rate') help='initial learning rate')
parser.add_argument( parser.add_argument(
"-b", "--batch_size", default=128, type=int, help="batch size") "-b", "--batch_size", default=128, type=int, help="batch size")
parser.add_argument( parser.add_argument(
"-n", "--num_devices", default=4, type=int, help="number of devices") "-n", "--num_devices", default=1, type=int, help="number of devices")
parser.add_argument( parser.add_argument(
"-r", "--resume", default=None, type=str, "-r",
"--resume",
default=None,
type=str,
help="checkpoint path to resume") help="checkpoint path to resume")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
...@@ -25,15 +25,20 @@ from paddle.fluid.framework import in_dygraph_mode, Variable ...@@ -25,15 +25,20 @@ from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fluid.incubate.fleet.base.role_maker as role_maker
import distributed import distributed
__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy'] from metrics import Metric
__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
def to_list(value): def to_list(value):
if value is None:
return value
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
return value return value
return [value] return [value]
...@@ -47,6 +52,26 @@ def to_numpy(var): ...@@ -47,6 +52,26 @@ def to_numpy(var):
return np.array(t) return np.array(t)
def flatten_list(l):
assert isinstance(l, list), "not a list"
outl = []
splits = []
for sl in l:
assert isinstance(sl, list), "sub content not a list"
splits.append(len(sl))
outl += sl
return outl, splits
def restore_flatten_list(l, splits):
outl = []
for split in splits:
assert len(l) >= split, "list length invalid"
sl, l = l[:split], l[split:]
outl.append(sl)
return outl
def extract_args(func): def extract_args(func):
if hasattr(inspect, 'getfullargspec'): if hasattr(inspect, 'getfullargspec'):
return inspect.getfullargspec(func)[0] return inspect.getfullargspec(func)[0]
...@@ -54,20 +79,15 @@ def extract_args(func): ...@@ -54,20 +79,15 @@ def extract_args(func):
return inspect.getargspec(func)[0] return inspect.getargspec(func)[0]
def shape_hints(**hints): class Input(fluid.dygraph.Layer):
assert hints, "hints can not be empty" def __init__(self, shape=None, dtype=None, name=None):
assert all(isinstance(h, (list, tuple)) for h in hints.values()), \ super(Input, self).__init__()
"shape hint must be a list or tuple" self.shape = shape
self.dtype = dtype
self.name = name
def wrapper(func): def forward(self):
args = extract_args(func) return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
invalid = set(hints.keys()) - set(args)
assert not invalid, \
"shape hint for arguments that are not present in forward method" \
+ ": ({})".format(", ".join(invalid))
func.shape_hints = hints
return func
return wrapper
class Loss(object): class Loss(object):
...@@ -75,12 +95,6 @@ class Loss(object): ...@@ -75,12 +95,6 @@ class Loss(object):
super(Loss, self).__init__() super(Loss, self).__init__()
self.average = average self.average = average
def infer_shape(self, outputs):
return [o.shape for o in outputs]
def infer_dtype(self, outputs):
return [o.dtype for o in outputs]
def forward(self, outputs, labels): def forward(self, outputs, labels):
raise NotImplementedError() raise NotImplementedError()
...@@ -89,24 +103,21 @@ class Loss(object): ...@@ -89,24 +103,21 @@ class Loss(object):
if in_dygraph_mode(): if in_dygraph_mode():
labels = [to_variable(l) for l in labels] labels = [to_variable(l) for l in labels]
losses = to_list(self.forward(to_list(outputs), labels)) losses = to_list(self.forward(to_list(outputs), labels))
if not self.average: if self.average:
return losses losses = [fluid.layers.reduce_mean(l) for l in losses]
return [fluid.layers.reduce_mean(l) for l in losses] else:
losses = [fluid.layers.reduce_sum(l) for l in losses]
return losses
class CrossEntropy(Loss): class CrossEntropy(Loss):
def __init__(self): def __init__(self, average=True):
super(CrossEntropy, self).__init__() super(CrossEntropy, self).__init__()
def infer_shape(self, outputs):
return [o.shape[:-1] + (1, ) for o in outputs]
def infer_dtype(self, outputs):
return ['int64' for _ in outputs]
def forward(self, outputs, labels): def forward(self, outputs, labels):
return [fluid.layers.cross_entropy(o, l) for o, l in zip( return [
outputs, labels)] fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels)
]
class StaticGraphAdapter(object): class StaticGraphAdapter(object):
...@@ -119,24 +130,13 @@ class StaticGraphAdapter(object): ...@@ -119,24 +130,13 @@ class StaticGraphAdapter(object):
self._orig_prog = fluid.default_main_program() self._orig_prog = fluid.default_main_program()
self._label_vars = {} # label variables self._label_vars = {} # label variables
self._input_vars = {} # label variables
self._endpoints = {} self._endpoints = {}
self._loss_endpoint = None self._loss_endpoint = None
self._executor = None self._executor = None
self._progs = {} self._progs = {}
self._compiled_progs = {} self._compiled_progs = {}
self._lazy_load_optimizer = None
self._nranks = distributed.Env().nranks
self._local_rank = distributed.Env().local_rank
# parse shape hints
self._input_desc = OrderedDict([
(n, None) for n in extract_args(self.model.forward) if n != 'self'
])
if hasattr(self.model.forward, 'shape_hints'):
self._input_desc.update(self.model.forward.shape_hints)
@property @property
def mode(self): def mode(self):
return self.model.mode return self.model.mode
...@@ -145,21 +145,19 @@ class StaticGraphAdapter(object): ...@@ -145,21 +145,19 @@ class StaticGraphAdapter(object):
def mode(self, value): def mode(self, value):
self.model.mode = value self.model.mode = value
def train(self, inputs, labels, device='CPU', device_ids=None): def train(self, inputs, labels=None):
assert self.model._optimizer and self.model._loss_function, \ assert self.model._optimizer, \
"model not ready, please call `model.prepare()` first" "model not ready, please call `model.prepare()` first"
self.mode = 'train' self.mode = 'train'
return self._run(inputs, labels, device, device_ids) return self._run(inputs, labels)
def eval(self, inputs, labels, device='CPU', device_ids=None): def eval(self, inputs, labels=None):
assert self.model._loss_function, \
"model not ready, please call `model.prepare()` first"
self.mode = 'eval' self.mode = 'eval'
return self._run(inputs, labels, device, device_ids) return self._run(inputs, labels)
def test(self, inputs, device='CPU', device_ids=None): def test(self, inputs):
self.mode = 'test' self.mode = 'test'
return self._run(inputs, None, device, device_ids) return self._run(inputs, None)
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
return None return None
...@@ -168,13 +166,18 @@ class StaticGraphAdapter(object): ...@@ -168,13 +166,18 @@ class StaticGraphAdapter(object):
def _save(state, path): def _save(state, path):
if not state: if not state:
return return
state = {k: to_numpy(v) if isinstance(v, Variable) else v state = {
for k, v in state.items()} k: to_numpy(v) if isinstance(v, Variable) else v
for k, v in state.items()
}
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(state, f) pickle.dump(state, f)
base = os.path.basename(path) base = os.path.basename(path)
assert base != "", "path should be of 'dirname/filename' format" assert base != "", "path should be of 'dirname/filename' format"
dir_name = os.path.dirname(path)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name)
param_path = path + ".pdparams" param_path = path + ".pdparams"
_save(self.model.state_dict(), param_path) _save(self.model.state_dict(), param_path)
prog = self._progs.get('train', None) prog = self._progs.get('train', None)
...@@ -182,13 +185,13 @@ class StaticGraphAdapter(object): ...@@ -182,13 +185,13 @@ class StaticGraphAdapter(object):
return return
# XXX `optimizer.state_dict()` only work in dygraph mode # XXX `optimizer.state_dict()` only work in dygraph mode
optim_path = path + ".pdopt" optim_path = path + ".pdopt"
optim = {p.name: p for p in filter( optim = {
is_belong_to_optimizer, prog.list_vars())} p.name: p
for p in filter(is_belong_to_optimizer, prog.list_vars())
}
if not optim: if not optim:
return return
# HACK this is contrived, optimizer state is not the same for
# static/dynamic graph mode
optim['__static_graph_only__'] = True
_save(optim, optim_path) _save(optim, optim_path)
def load(self, path): def load(self, path):
...@@ -223,27 +226,77 @@ class StaticGraphAdapter(object): ...@@ -223,27 +226,77 @@ class StaticGraphAdapter(object):
optim_state = _load(optim_path) optim_state = _load(optim_path)
if optim_state is None: if optim_state is None:
return return
assert '__static_graph_only__' in optim_state, \
"optimizer saved in dygraph mode is not usable in static graph"
if self._executor is not None: self._load_optimizer(optim_state, executor)
self._load_optimizer(optim_state)
else:
self._lazy_load_optimizer = optim_state
def _load_optimizer(self, state): def _load_optimizer(self, state, executor):
prog = self._progs.get('train', None) prog = self._progs.get('train', None)
optim = list(filter(is_belong_to_optimizer, prog.list_vars())) optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
if not optim: if not optim:
return return
fluid.core._create_loaded_parameter( fluid.core._create_loaded_parameter(optim, global_scope(), executor)
optim, global_scope(), self._executor._default_executor)
converted_state = dict(state)
for var in optim: for var in optim:
assert var.name in state, \ if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
# When using learning rate scheduler, dygraph would name the
# global step var as "global_step" to save, while static-graph
# would has a state var named as "@LR_DECAY_COUNTER@".
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
# different.
state_val = (
np.array(converted_state.pop("global_step")) - 1
) if "global_step" in converted_state else converted_state.pop(
"@LR_DECAY_COUNTER@", None)
if state_val is not None:
converted_state[var.name] = state_val
elif var.name.startswith("learning_rate_"):
# When using static learning rate, static-graph would make it
# a persistable var named 'unique_name.generate("learning_rate")',
# However, dygraph wouldn't save it.
if var.name not in state: continue
else:
# moment and other accumulators
if var.name not in converted_state:
# try to convert from dygraph name
opt_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_unq_name = None
for name in self.model._optimizer._accumulators.keys():
accum_name = name if opt_name is None else name[len(
opt_name) + 1:]
for param_name, state_var in self.model._optimizer._accumulators[
name].items():
if opt_unq_name is None:
# can not infer out the exact unique(opt_name),
# thus try to extract rather than generate
for state_key in sorted(
state.keys(),
key=lambda x: len(x),
reverse=True):
prefix = param_name + "_" + (
opt_cls_name if opt_name is None else
opt_name) + "_"
if state_key.startswith(prefix):
prefix_offset = state_key[len(
prefix):].find("_") + len(prefix)
opt_unq_name = state_key[len(
param_name + "_"):prefix_offset]
# TODO: assert
# assert opt_unq_name is None
# gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
# always end with "_0" since the unique optimizer._name
dy_state_name = (param_name + "_" + opt_unq_name +
"_" + accum_name + "_0")
converted_state[
state_var.name] = converted_state.pop(
dy_state_name)
assert var.name in converted_state, \
"variable [{}] is not in optimizer state file".format(var.name) "variable [{}] is not in optimizer state file".format(var.name)
self._set_var(var, state[var.name]) self._set_var(var, converted_state[var.name])
def _set_var(self, var, ndarray): def _set_var(self, var, ndarray):
t = global_scope().find_var(var.name).get_tensor() t = global_scope().find_var(var.name).get_tensor()
...@@ -259,21 +312,20 @@ class StaticGraphAdapter(object): ...@@ -259,21 +312,20 @@ class StaticGraphAdapter(object):
t.set(ndarray, place) t.set(ndarray, place)
def _run(self, inputs, labels=None, device='CPU', device_ids=None): def _run(self, inputs, labels=None):
compiled_prog = self._compiled_progs.get(self.mode, None)
assert compiled_prog, \
"Model is not ready, please call `model.prepare()` first"
inputs = to_list(inputs) inputs = to_list(inputs)
if labels is not None: if labels is not None:
labels = to_list(labels) labels = to_list(labels)
assert len(inputs) == len(self._input_desc), "number of inputs" \ assert len(inputs) == len(self._input_vars[self.mode]), \
"number of inputs" \
+ " does not match number of arguments of `forward` method" + " does not match number of arguments of `forward` method"
if self._progs.get(self.mode, None) is None:
self._make_program(self._infer_input_vars(inputs))
compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids)
feed = {} feed = {}
input_names = [name for name in self._input_desc.keys()] input_names = [v.name for v in self._input_vars[self.mode]]
for idx, n in enumerate(input_names): for idx, n in enumerate(input_names):
# train and test may take different arguments # train and test may take different arguments
if inputs[idx] is not None: if inputs[idx] is not None:
...@@ -283,33 +335,71 @@ class StaticGraphAdapter(object): ...@@ -283,33 +335,71 @@ class StaticGraphAdapter(object):
feed[v.name] = labels[idx] feed[v.name] = labels[idx]
endpoints = self._endpoints[self.mode] endpoints = self._endpoints[self.mode]
fetch_list = endpoints['output'] + endpoints['loss']
num_output = len(endpoints['output'])
if self.mode != 'test':
fetch_list += endpoints['label']
out = self._executor.run(
compiled_prog, feed=feed,
fetch_list=fetch_list)
if self.mode == 'test': if self.mode == 'test':
return out[:num_output] fetch_list = endpoints['output']
else: else:
return out[:num_output], out[num_output:-1], out[-1:] metric_list, metric_splits = flatten_list(endpoints['metric'])
fetch_list = endpoints['loss'] + metric_list
num_loss = len(endpoints['loss'])
rets = self._executor.run(
compiled_prog, feed=feed,
fetch_list=fetch_list,
return_numpy=False)
# LoDTensor cannot be fetch as numpy directly
rets = [np.array(v) for v in rets]
if self.mode == 'test':
return rets[:]
losses = rets[:num_loss]
metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
metrics = []
for metric, state in zip(self.model._metrics, metric_states):
metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses
def prepare(self):
modes = ['train', 'eval', 'test']
for mode in modes:
self._make_program(mode)
self._compile_and_initialize(self._progs[mode], mode)
def _make_program(self, mode):
prog = self._progs.get(mode, None)
if prog is not None:
return
def _make_program(self, inputs):
prog = self._orig_prog.clone() prog = self._orig_prog.clone()
if self.mode == 'train' and self.model._optimizer._learning_rate_map: # NOTE: When defining learning rate scheduling in static-graph, ops to
# increase the global step var and calculate learning rate would be
# prepended into _orig_prog. test program maked by `_orig_prog.clone`
# also would include these ops. Thus must prune these ops in test
# program, otherwise the global step would be changed in test.
if mode != 'train':
for op in list(prog.global_block().ops):
prog.global_block()._remove_op(0)
if mode == 'train' and self.model._optimizer \
and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue # HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog] lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var self.model._optimizer._learning_rate_map[prog] = lr_var
losses = [] losses = []
metrics = []
with fluid.program_guard(prog, self._startup_prog): with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n] \
for n in extract_args(self.model.forward) if n != 'self']
else:
ins = self.model._inputs
lbls = self.model._labels if self.model._labels else []
inputs = [k.forward() for k in to_list(ins)]
labels = [k.forward() for k in to_list(lbls)]
outputs = to_list(self.model.forward(*inputs)) outputs = to_list(self.model.forward(*inputs))
if self.mode != 'test': if mode != 'test':
label_vars = self._infer_label_vars(outputs) if self.model._loss_function:
self._label_vars[self.mode] = label_vars losses = self.model._loss_function(outputs, labels)
losses = self.model._loss_function(outputs, label_vars) for metric in self.model._metrics:
if self.mode == 'train': metrics.append(to_list(metric.add_metric_op(outputs, labels)))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses) self._loss_endpoint = fluid.layers.sum(losses)
if self._nranks > 1: if self._nranks > 1:
role = role_maker.PaddleCloudRoleMaker(is_collective=True) role = role_maker.PaddleCloudRoleMaker(is_collective=True)
...@@ -325,45 +415,26 @@ class StaticGraphAdapter(object): ...@@ -325,45 +415,26 @@ class StaticGraphAdapter(object):
if self.mode != 'test': if self.mode != 'test':
label_vars = [distributed._all_gather(l, self._nranks) for l in label_vars] label_vars = [distributed._all_gather(l, self._nranks) for l in label_vars]
if self.mode != 'train': # clone again to put it in test mode if mode != 'train': # clone again to put it in test mode
prog = prog.clone(for_test=True) prog = prog.clone(for_test=True)
self._progs[self.mode] = prog
self._endpoints[self.mode] = {
"output": outputs,
"loss": losses,
"label": label_vars
}
def _infer_input_vars(self, inputs): self._input_vars[mode] = inputs
input_vars = [] self._label_vars[mode] = labels
for idx, i in enumerate(inputs): self._progs[mode] = prog
if i is None: # train and test may take different arguments self._endpoints[mode] = {"output": outputs, "loss": losses, "metric": metrics}
input_vars.append(None)
continue def _compile_and_initialize(self, prog, mode):
ndarray = np.array(i) compiled_prog = self._compiled_progs.get(mode, None)
name = list(self._input_desc.keys())[idx]
shape = list(self._input_desc.values())[idx]
if shape is None:
shape = (None, ) + ndarray.shape[1:]
input_vars.append(fluid.data(name, shape, ndarray.dtype))
return input_vars
def _infer_label_vars(self, outputs):
shapes = self.model._loss_function.infer_shape(outputs)
dtypes = self.model._loss_function.infer_dtype(outputs)
label_vars = []
for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)):
name = '__label{}'.format(idx)
label_vars.append(fluid.data(name, shape, dtype))
return label_vars
def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
compiled_prog = self._compiled_progs.get(self.mode, None)
if compiled_prog is not None: if compiled_prog is not None:
return compiled_prog return compiled_prog
places = [device.lower() == 'gpu' and fluid.CUDAPlace(i) device = self.model._device
or fluid.CPUPlace() for i in device_ids] device_ids = self.model._device_ids
if device.lower() == 'gpu':
places = fluid.cuda_places(device_ids)
else:
places = fluid.cpu_places(len(device_ids) if device_ids else None)
# XXX *ALL WEIGHTS* should be initialized upon model construction # XXX *ALL WEIGHTS* should be initialized upon model construction
# even if `forward()` may run different code path for different mode # even if `forward()` may run different code path for different mode
...@@ -394,26 +465,14 @@ class StaticGraphAdapter(object): ...@@ -394,26 +465,14 @@ class StaticGraphAdapter(object):
compiled_prog = fluid.CompiledProgram(prog) compiled_prog = fluid.CompiledProgram(prog)
else: else:
compiled_prog = prog#fleet.main_program compiled_prog = prog#fleet.main_program
if len(device_ids) > 1:
if len(places) > 1:
loss_name = None loss_name = None
if self.mode == 'train' and self._loss_endpoint is not None: if mode == 'train' and self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name loss_name = self._loss_endpoint.name
share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs:
share_vars_from = self._compiled_progs['train']
# HACK invalidate eval program if is compiled before train program
# quite hackish, OTOH, it is generally uncommon that the eval
# program will be run before the train program
if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval']
compiled_prog = compiled_prog.with_data_parallel( compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places, loss_name=loss_name, places=places)
share_vars_from=share_vars_from) self._compiled_progs[mode] = compiled_prog
self._compiled_progs[self.mode] = compiled_prog
return compiled_prog
class DynamicGraphAdapter(object): class DynamicGraphAdapter(object):
...@@ -435,13 +494,14 @@ class DynamicGraphAdapter(object): ...@@ -435,13 +494,14 @@ class DynamicGraphAdapter(object):
self.model.mode = value self.model.mode = value
# TODO multi device in dygraph mode not implemented at present time # TODO multi device in dygraph mode not implemented at present time
def train(self, inputs, labels, device='CPU', device_ids=None): def train(self, inputs, labels=None):
assert self.model._optimizer and self.model._loss_function, \ assert self.model._optimizer, \
"model not ready, please call `model.prepare()` first" "model not ready, please call `model.prepare()` first"
super(Model, self.model).train() super(Model, self.model).train()
self.mode = 'train' self.mode = 'train'
inputs = to_list(inputs) inputs = to_list(inputs)
labels = to_list(labels) if labels is not None:
labels = [to_variable(l) for l in to_list(labels)]
if self._nranks > 1: if self._nranks > 1:
outputs = self.ddp_model.forward(*[to_variable(x) for x in inputs]) outputs = self.ddp_model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels) losses = self.model._loss_function(outputs, labels)
...@@ -456,8 +516,13 @@ class DynamicGraphAdapter(object): ...@@ -456,8 +516,13 @@ class DynamicGraphAdapter(object):
final_loss.backward() final_loss.backward()
self.model._optimizer.minimize(final_loss) self.model._optimizer.minimize(final_loss)
self.model.clear_gradients() self.model.clear_gradients()
return [to_numpy(o) for o in to_list(outputs)], \ metrics = []
[to_numpy(l) for l in losses], [l for l in labels] for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def eval(self, inputs, labels, device='CPU', device_ids=None): def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_function, \ assert self.model._loss_function, \
...@@ -465,17 +530,25 @@ class DynamicGraphAdapter(object): ...@@ -465,17 +530,25 @@ class DynamicGraphAdapter(object):
super(Model, self.model).eval() super(Model, self.model).eval()
self.mode = 'eval' self.mode = 'eval'
inputs = to_list(inputs) inputs = to_list(inputs)
labels = to_list(labels) if labels is not None:
labels = [to_variable(l) for l in labels] labels = [to_variable(l) for l in to_list(labels)]
outputs = self.model.forward(*[to_variable(x) for x in inputs]) outputs = self.model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels) losses = self.model._loss_function(outputs, labels)
if self._nranks > 1: if self._nranks > 1:
outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)] outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)]
labels = [distributed._all_gather(l, self._nranks) for l in labels] labels = [distributed._all_gather(l, self._nranks) for l in labels]
return [to_numpy(o) for o in to_list(outputs)], \ metrics = []
[to_numpy(l) for l in losses], [to_numpy(l) for l in labels] for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, labels)
def test(self, inputs, device='CPU', device_ids=None): m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
# To be consistent with static graph
# return empty loss if loss_function is None
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]
def test(self, inputs):
super(Model, self.model).eval() super(Model, self.model).eval()
self.mode = 'test' self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)] inputs = [to_variable(x) for x in to_list(inputs)]
...@@ -501,15 +574,68 @@ class DynamicGraphAdapter(object): ...@@ -501,15 +574,68 @@ class DynamicGraphAdapter(object):
self.model.set_dict(params) self.model.set_dict(params)
if self.model._optimizer is None or optim is None: if self.model._optimizer is None or optim is None:
return return
self.model._optimizer.set_dict(optim)
# If optimizer performs set_dict when state vars haven't been created,
# which would happen when set_dict before minimize, the state would be
# stored in optimizer._accumulators_holder and loaded lazily.
# To contrive this when loading from static-graph saved states, extend
# state dict to include keys named accoring to dygraph naming rules.
# TODO: if len(self.model._optimizer._accumulators) > 0
converted_state = dict(optim)
opt_unq_name = self.model._optimizer._name
opt_cls_name = self.model._optimizer.__class__.__name__
opt_name = opt_unq_name[:opt_unq_name.rfind("_")] # remove suffix idx
param_names = [param.name for param in self.model.parameters()]
for var_name, state_var in sorted(
optim.items(), key=lambda x: len(x[0]), reverse=True):
if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
# NOTE: dygraph saved global_step is 1 larger than that in
# static-graph, since the time of global_step to increase is
# different.
if var_name == "@LR_DECAY_COUNTER@":
converted_state["global_step"] = np.array(
converted_state.pop("@LR_DECAY_COUNTER@")) + 1
else:
# moment and other accumulators
# extend state dict to include promising dygraph names
for param_name in param_names:
if var_name.startswith(param_name + "_" + opt_name):
# when init optimizer with name
accum_name = var_name[len(param_name + "_" + opt_name +
"_"):]
elif var_name.startswith(param_name +
"_") and opt_name == opt_cls_name:
# when init optimizer without name
accum_name = var_name[len(param_name + "_"):]
else:
continue
# remove suffix idx
accum_name = accum_name[:accum_name.rfind("_")]
# state names always end with "_0" in dygraph because of the
# unique optimizer._name
dy_state_name = (param_name + "_" + opt_unq_name + "_" +
accum_name + "_0")
converted_state[dy_state_name] = state_var
self.model._optimizer.set_dict(converted_state)
class Model(fluid.dygraph.Layer): class Model(fluid.dygraph.Layer):
"""
FIXME: add more comments and usage
"""
def __init__(self): def __init__(self):
super(Model, self).__init__(self.__class__.__name__) super(Model, self).__init__(self.__class__.__name__)
self.mode = 'train' self.mode = 'train'
self._inputs = None
self._labels = None
self._loss_function = None self._loss_function = None
self._loss_weights = None self._loss_weights = None
self._loss = None
self._optimizer = None
self._device = None
self._device_ids = None
self._optimizer = None self._optimizer = None
if in_dygraph_mode(): if in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self) self._adapter = DynamicGraphAdapter(self)
...@@ -532,11 +658,75 @@ class Model(fluid.dygraph.Layer): ...@@ -532,11 +658,75 @@ class Model(fluid.dygraph.Layer):
def load(self, *args, **kwargs): def load(self, *args, **kwargs):
return self._adapter.load(*args, **kwargs) return self._adapter.load(*args, **kwargs)
def prepare(self, optimizer, loss_function): def prepare(self,
optimizer=None,
loss_function=None,
metrics=None,
inputs=None,
labels=None,
device=None,
device_ids=None):
"""
FIXME: add comments
Args:
optimizer (Optimizer|None): optimizer must be set in training
and should be a Optimizer instance. It can be None in eval
and test mode.
loss_function (Loss|None): loss function must be set in training
and should be a Loss instance. It can be None when there is
no loss.
metrics (Metric|list of Metric|None): if metrics is set, all
metric will be calculate and output in train/eval mode.
inputs (Input|list|dict|None): inputs, entry points of network,
could be a Input layer, or lits of Input layers,
or dict (name: Input), or None. For static graph,
inputs must be set. For dynamic graph, it could be None.
labels (Input|list|None): labels, entry points of network,
could be a Input layer or lits of Input layers, or None.
For static graph, if set loss_function in Model.prepare(), it
must be set. Otherwise, it could be None.
device (str|None): specify device type, 'CPU' or 'GPU'.
If None, automatically select device according to
installation package version.
device_ids (list[int]|None): specify device index. If None,
the available device will be obtained from the environment
variable when the model is executed: If the GPU is used, the
currently available device ID is obtained from the environment
variable FLAGS_selected_gpus or CUDA_VISIBLE_DEVICES when the
model is executed; CPU, when the model is executed,
the currently available CPU number is obtained from the
environment variable CPU_NUM. For example, export CPU_NUM=4,
if the environment variable is not set, the executor will add
the variable to the environment variable and set its value to 1.
The default is None.
"""
self._optimizer = optimizer self._optimizer = optimizer
assert isinstance(loss_function, Loss), \ if loss_function:
"'loss_function' must be sub classes of 'Loss'" if not isinstance(loss_function, Loss):
raise TypeError(
"'loss_function' must be sub classes of 'Loss'")
self._loss_function = loss_function self._loss_function = loss_function
if not in_dygraph_mode():
if not isinstance(inputs, (list, dict, Input)):
raise TypeError(
"'inputs' must be list or dict in static graph mode")
if loss_function and not isinstance(labels, (list, Input)):
raise TypeError("'labels' must be list in static graph mode")
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(metric.__class__.__name__)
self._metrics = to_list(metrics)
self._inputs = inputs
self._labels = labels
self._device = device
if device is None:
self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
self._device_ids = device_ids
if not in_dygraph_mode():
self._adapter.prepare()
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
return self._adapter.parameters(*args, **kwargs) return self._adapter.parameters(*args, **kwargs)
...@@ -33,9 +33,14 @@ from paddle.fluid.dygraph.nn import Conv2D ...@@ -33,9 +33,14 @@ from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from model import Model, Loss, shape_hints from model import Model, Loss, Input
from resnet import ResNet, ConvBNLayer from resnet import ResNet, ConvBNLayer
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
# XXX transfer learning # XXX transfer learning
class ResNetBackBone(ResNet): class ResNetBackBone(ResNet):
...@@ -102,13 +107,14 @@ class YoloDetectionBlock(fluid.dygraph.Layer): ...@@ -102,13 +107,14 @@ class YoloDetectionBlock(fluid.dygraph.Layer):
class YOLOv3(Model): class YOLOv3(Model):
def __init__(self): def __init__(self, num_classes=80):
super(YOLOv3, self).__init__() super(YOLOv3, self).__init__()
self.num_classes = 80 self.num_classes = num_classes
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326] 59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
self.valid_thresh = 0.005 self.valid_thresh = 0.005
self.nms_thresh = 0.45
self.nms_topk = 400 self.nms_topk = 400
self.nms_posk = 100 self.nms_posk = 100
self.draw_thresh = 0.5 self.draw_thresh = 0.5
...@@ -146,8 +152,7 @@ class YOLOv3(Model): ...@@ -146,8 +152,7 @@ class YOLOv3(Model):
act='leaky_relu')) act='leaky_relu'))
self.route_blocks.append(route) self.route_blocks.append(route)
@shape_hints(inputs=[None, 3, None, None], im_shape=[None, 2]) def forward(self, inputs, img_info):
def forward(self, inputs, im_shape):
outputs = [] outputs = []
boxes = [] boxes = []
scores = [] scores = []
...@@ -161,48 +166,50 @@ class YOLOv3(Model): ...@@ -161,48 +166,50 @@ class YOLOv3(Model):
feat = fluid.layers.concat(input=[route, feat], axis=1) feat = fluid.layers.concat(input=[route, feat], axis=1)
route, tip = self.yolo_blocks[idx](feat) route, tip = self.yolo_blocks[idx](feat)
block_out = self.block_outputs[idx](tip) block_out = self.block_outputs[idx](tip)
outputs.append(block_out)
if idx < 2: if idx < 2:
route = self.route_blocks[idx](route) route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2) route = fluid.layers.resize_nearest(route, scale=2)
anchor_mask = self.anchor_masks[idx] if self.mode == 'test':
mask_anchors = [] anchor_mask = self.anchor_masks[idx]
for m in anchor_mask: mask_anchors = []
mask_anchors.append(self.anchors[2 * m]) for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m + 1]) mask_anchors.append(self.anchors[2 * m])
b, s = fluid.layers.yolo_box( mask_anchors.append(self.anchors[2 * m + 1])
x=block_out, img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_size=im_shape, img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
anchors=mask_anchors, b, s = fluid.layers.yolo_box(
class_num=self.num_classes, x=block_out,
conf_thresh=self.valid_thresh, img_size=img_shape,
downsample_ratio=downsample) anchors=mask_anchors,
class_num=self.num_classes,
outputs.append(block_out) conf_thresh=self.valid_thresh,
boxes.append(b) downsample_ratio=downsample)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
downsample //= 2 downsample //= 2
if self.mode != 'test': if self.mode != 'test':
return outputs return outputs
return fluid.layers.multiclass_nms( return [img_id, fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1), bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2), scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh, score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk, nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk, keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh, nms_threshold=self.nms_thresh,
background_label=-1) background_label=-1)]
class YoloLoss(Loss): class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50): def __init__(self, num_classes=80):
super(YoloLoss, self).__init__() super(YoloLoss, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.num_max_boxes = num_max_boxes
self.ignore_thresh = 0.7 self.ignore_thresh = 0.7
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326] 59, 119, 116, 90, 156, 198, 373, 326]
...@@ -226,20 +233,11 @@ class YoloLoss(Loss): ...@@ -226,20 +233,11 @@ class YoloLoss(Loss):
class_num=self.num_classes, class_num=self.num_classes,
ignore_thresh=self.ignore_thresh, ignore_thresh=self.ignore_thresh,
use_label_smooth=True) use_label_smooth=True)
loss = fluid.layers.reduce_mean(loss)
losses.append(loss) losses.append(loss)
downsample //= 2 downsample //= 2
return losses return losses
def infer_shape(self, _):
return [
[None, self.num_max_boxes, 4],
[None, self.num_max_boxes],
[None, self.num_max_boxes]
]
def infer_dtype(self, _):
return ['float32', 'int32', 'float32']
def make_optimizer(parameter_list=None): def make_optimizer(parameter_list=None):
base_lr = FLAGS.lr base_lr = FLAGS.lr
...@@ -293,7 +291,7 @@ def random_crop(inputs): ...@@ -293,7 +291,7 @@ def random_crop(inputs):
thresholds = [.0, .1, .3, .5, .7, .9] thresholds = [.0, .1, .3, .5, .7, .9]
scaling = [.3, 1.] scaling = [.3, 1.]
img, gt_box, gt_label = inputs img, img_ids, gt_box, gt_label = inputs
h, w = img.shape[:2] h, w = img.shape[:2]
if len(gt_box) == 0: if len(gt_box) == 0:
...@@ -327,7 +325,7 @@ def random_crop(inputs): ...@@ -327,7 +325,7 @@ def random_crop(inputs):
img = img[y1:y2, x1:x2, :] img = img[y1:y2, x1:x2, :]
gt_box = np.take(cropped_box, valid_ids, axis=0) gt_box = np.take(cropped_box, valid_ids, axis=0)
gt_label = np.take(gt_label, valid_ids, axis=0) gt_label = np.take(gt_label, valid_ids, axis=0)
return img, gt_box, gt_label return img, img_ids, gt_box, gt_label
return inputs return inputs
...@@ -335,9 +333,9 @@ def random_crop(inputs): ...@@ -335,9 +333,9 @@ def random_crop(inputs):
# XXX mix up, color distort and random expand are skipped for simplicity # XXX mix up, color distort and random expand are skipped for simplicity
def sample_transform(inputs, mode='train', num_max_boxes=50): def sample_transform(inputs, mode='train', num_max_boxes=50):
if mode == 'train': if mode == 'train':
img, gt_box, gt_label = random_crop(inputs) img, img_id, gt_box, gt_label = random_crop(inputs)
else: else:
img, gt_box, gt_label = inputs img, img_id, gt_box, gt_label = inputs
h, w = img.shape[:2] h, w = img.shape[:2]
# random flip # random flip
...@@ -350,7 +348,7 @@ def sample_transform(inputs, mode='train', num_max_boxes=50): ...@@ -350,7 +348,7 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
if len(gt_label) == 0: if len(gt_label) == 0:
gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32) gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
gt_label = np.zeros([num_max_boxes, 1], dtype=np.int32) gt_label = np.zeros([num_max_boxes], dtype=np.int32)
return img, gt_box, gt_label return img, gt_box, gt_label
gt_box = gt_box[:num_max_boxes, :] gt_box = gt_box[:num_max_boxes, :]
...@@ -362,9 +360,9 @@ def sample_transform(inputs, mode='train', num_max_boxes=50): ...@@ -362,9 +360,9 @@ def sample_transform(inputs, mode='train', num_max_boxes=50):
pad = num_max_boxes - gt_label.size pad = num_max_boxes - gt_label.size
gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant') gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
gt_label = np.pad(gt_label, [(0, pad)], mode='constant') gt_label = np.pad(gt_label, ((0, pad)), mode='constant')
return img, gt_box, gt_label return img, img_id, gt_box, gt_label
def batch_transform(batch, mode='train'): def batch_transform(batch, mode='train'):
...@@ -376,7 +374,8 @@ def batch_transform(batch, mode='train'): ...@@ -376,7 +374,8 @@ def batch_transform(batch, mode='train'):
d = 608 d = 608
interp = cv2.INTER_CUBIC interp = cv2.INTER_CUBIC
# transpose batch # transpose batch
imgs, gt_boxes, gt_labels = list(zip(*batch)) imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch))
img_shapes = np.array([[im.shape[0], im.shape[1]] for im in imgs]).astype('int32')
imgs = np.array([cv2.resize( imgs = np.array([cv2.resize(
img, (d, d), interpolation=interp) for img in imgs]) img, (d, d), interpolation=interp) for img in imgs])
...@@ -389,12 +388,13 @@ def batch_transform(batch, mode='train'): ...@@ -389,12 +388,13 @@ def batch_transform(batch, mode='train'):
imgs *= invstd imgs *= invstd
imgs = imgs.transpose((0, 3, 1, 2)) imgs = imgs.transpose((0, 3, 1, 2))
im_shapes = np.full([len(imgs), 2], d, dtype=np.int32) img_ids = np.array(img_ids)
img_info = np.concatenate([img_ids, img_shapes], axis=1)
gt_boxes = np.array(gt_boxes) gt_boxes = np.array(gt_boxes)
gt_labels = np.array(gt_labels) gt_labels = np.array(gt_labels)
# XXX since mix up is not used, scores are all ones # XXX since mix up is not used, scores are all ones
gt_scores = np.ones_like(gt_labels, dtype=np.float32) gt_scores = np.ones_like(gt_labels, dtype=np.float32)
return [imgs, im_shapes], [gt_boxes, gt_labels, gt_scores] return [imgs, img_info], [gt_boxes, gt_labels, gt_scores]
def coco2017(root_dir, mode='train'): def coco2017(root_dir, mode='train'):
...@@ -434,17 +434,18 @@ def coco2017(root_dir, mode='train'): ...@@ -434,17 +434,18 @@ def coco2017(root_dir, mode='train'):
gt_box = np.array(gt_box, dtype=np.float32) gt_box = np.array(gt_box, dtype=np.float32)
gt_label = np.array([class_map[cls] for cls in gt_label], gt_label = np.array([class_map[cls] for cls in gt_label],
dtype=np.int32)[:, np.newaxis] dtype=np.int32)[:, np.newaxis]
im_id = np.array([img['id']], dtype=np.int32)
if gt_label.size == 0 and not mode == 'train': if gt_label.size == 0 and not mode == 'train':
continue continue
samples.append((file_path, gt_box.copy(), gt_label.copy())) samples.append((file_path, im_id.copy(), gt_box.copy(), gt_label.copy()))
def iterator(): def iterator():
if mode == 'train': if mode == 'train':
random.shuffle(samples) np.random.shuffle(samples)
for file_path, gt_box, gt_label in samples: for file_path, im_id, gt_box, gt_label in samples:
img = cv2.imread(file_path) img = cv2.imread(file_path)
yield img, gt_box, gt_label yield img, im_id, gt_box, gt_label
return iterator return iterator
...@@ -457,14 +458,13 @@ def run(model, loader, mode='train'): ...@@ -457,14 +458,13 @@ def run(model, loader, mode='train'):
start = time.time() start = time.time()
for idx, batch in enumerate(loader()): for idx, batch in enumerate(loader()):
outputs, losses = getattr(model, mode)( losses = getattr(model, mode)(batch[0], batch[1])
batch[0], batch[1], device='gpu', device_ids=device_ids)
total_loss += np.sum(losses) total_loss += np.sum(losses)
if idx > 1: # skip first two steps if idx > 1: # skip first two steps
total_time += time.time() - start total_time += time.time() - start
if idx % 10 == 0: if idx % 10 == 0:
print("{:04d}: loss {:0.3f} time: {:0.3f}".format( logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
idx, total_loss / (idx + 1), total_time / max(1, (idx - 1)))) idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
start = time.time() start = time.time()
...@@ -501,26 +501,46 @@ def main(): ...@@ -501,26 +501,46 @@ def main():
coco2017(FLAGS.data, 'val'), coco2017(FLAGS.data, 'val'),
process_num=8, process_num=8,
buffer_size=4 * batch_size), buffer_size=4 * batch_size),
batch_size=batch_size), batch_size=1),
process_num=2, buffer_size=4) process_num=2, buffer_size=4)
if not os.path.exists('yolo_checkpoints'): if not os.path.exists('yolo_checkpoints'):
os.mkdir('yolo_checkpoints') os.mkdir('yolo_checkpoints')
with guard: with guard:
model = YOLOv3() NUM_CLASSES = 7
NUM_MAX_BOXES = 50
model = YOLOv3(num_classes=NUM_CLASSES)
# XXX transfer learning # XXX transfer learning
if FLAGS.pretrain_weights is not None:
model.backbone.load(FLAGS.pretrain_weights)
if FLAGS.weights is not None: if FLAGS.weights is not None:
model.backbone.load(FLAGS.weights) model.load(FLAGS.weights)
optim = make_optimizer(parameter_list=model.parameters()) optim = make_optimizer(parameter_list=model.parameters())
model.prepare(optim, YoloLoss()) anno_path = os.path.join(FLAGS.data, 'annotations', 'instances_val2017.json')
inputs = [Input([None, 3, None, None], 'float32', name='image'),
Input([None, 3], 'int32', name='img_info')]
labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
model.prepare(optim,
YoloLoss(num_classes=NUM_CLASSES),
# For YOLOv3, output variable in train/eval is different,
# which is not supported by metric, add by callback later?
# metrics=COCOMetric(anno_path, with_background=False)
inputs=inputs,
labels = labels)
for e in range(epoch): for e in range(epoch):
print("======== train epoch {} ========".format(e)) logger.info("======== train epoch {} ========".format(e))
run(model, train_loader) run(model, train_loader)
model.save('yolo_checkpoints/{:02d}'.format(e)) model.save('yolo_checkpoints/{:02d}'.format(e))
print("======== eval epoch {} ========".format(e)) logger.info("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval') run(model, val_loader, mode='eval')
# should be called in fit()
for metric in model._metrics:
metric.accumulate()
metric.reset()
if __name__ == '__main__': if __name__ == '__main__':
...@@ -538,8 +558,11 @@ if __name__ == '__main__': ...@@ -538,8 +558,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"-n", "--num_devices", default=8, type=int, help="number of devices") "-n", "--num_devices", default=8, type=int, help="number of devices")
parser.add_argument( parser.add_argument(
"-w", "--weights", default=None, type=str, "-p", "--pretrain_weights", default=None, type=str,
help="path to pretrained weights") help="path to pretrained weights")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to model weights")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path" assert FLAGS.data, "error: must provide data path"
main() main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
from metrics import Metric
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
__all__ = ['COCOMetric']
OUTFILE = './bbox.json'
# considered to change to a callback later
class COCOMetric(Metric):
"""
Metrci for MS-COCO dataset, only support update with batch
size as 1.
Args:
anno_path(str): path to COCO annotation json file
with_background(bool): whether load category id with
background as 0, default True
"""
def __init__(self, anno_path, with_background=True, **kwargs):
super(COCOMetric, self).__init__(**kwargs)
self.anno_path = anno_path
self.with_background = with_background
self.bbox_results = []
self.coco_gt = COCO(anno_path)
cat_ids = self.coco_gt.getCatIds()
self.clsid2catid = dict(
{i + int(with_background): catid
for i, catid in enumerate(cat_ids)})
def update(self, preds, *args, **kwargs):
im_ids, bboxes = preds
assert im_ids.shape[0] == 1, \
"COCOMetric can only update with batch size = 1"
if bboxes.shape[1] != 6:
# no bbox detected in this batch
return
im_id = int(im_ids)
for i in range(bboxes.shape[0]):
dt = bboxes[i, :]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
catid = (self.clsid2catid[int(clsid)])
w = xmax - xmin + 1
h = ymax - ymin + 1
bbox = [xmin, ymin, w, h]
coco_res = {
'image_id': im_id,
'category_id': catid,
'bbox': bbox,
'score': score
}
self.bbox_results.append(coco_res)
def reset(self):
self.bbox_results = []
def accumulate(self):
if len(self.bbox_results) == 0:
logger.warning("The number of valid bbox detected is zero.\n \
Please use reasonable model and check input data.\n \
stop COCOMetric accumulate!")
return [0.0]
with open(OUTFILE, 'w') as f:
json.dump(self.bbox_results, f)
map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
# flush coco evaluation result
sys.stdout.flush()
self.result = map_stats[0]
return self.result
def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
assert coco_gt != None or anno_file != None
if coco_gt == None:
coco_gt = COCO(anno_file)
logger.info("Start evaluate...")
coco_dt = coco_gt.loadRes(jsonfile)
coco_eval = COCOeval(coco_gt, coco_dt, style)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
return coco_eval.stats
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册