diff --git a/mnist.py b/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6d95c3fc54c40f859c9ad49fc7477b2ffa0b26 --- /dev/null +++ b/mnist.py @@ -0,0 +1,198 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import os + +import numpy as np + +import paddle +from paddle import fluid +from paddle.fluid.optimizer import Momentum +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear + +from model import Model, CrossEntropy + + +class SimpleImgConvPool(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + pool_size, + pool_stride, + pool_padding=0, + pool_type='max', + global_pooling=False, + conv_stride=1, + conv_padding=0, + conv_dilation=1, + conv_groups=None, + act=None, + use_cudnn=False, + param_attr=None, + bias_attr=None): + super(SimpleImgConvPool, self).__init__('SimpleConv') + + self._conv2d = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=conv_stride, + padding=conv_padding, + dilation=conv_dilation, + groups=conv_groups, + param_attr=None, + bias_attr=None, + use_cudnn=use_cudnn) + + self._pool2d = Pool2D( + pool_size=pool_size, + pool_type=pool_type, + pool_stride=pool_stride, + pool_padding=pool_padding, + global_pooling=global_pooling, + use_cudnn=use_cudnn) + + def forward(self, inputs): + x = self._conv2d(inputs) + x = self._pool2d(x) + return x + + +class MNIST(Model): + def __init__(self): + super(MNIST, self).__init__() + + self._simple_img_conv_pool_1 = SimpleImgConvPool( + 1, 20, 5, 2, 2, act="relu") + + self._simple_img_conv_pool_2 = SimpleImgConvPool( + 20, 50, 5, 2, 2, act="relu") + + pool_2_shape = 50 * 4 * 4 + SIZE = 10 + scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5 + self._fc = Linear(800, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") + + def forward(self, inputs): + x = self._simple_img_conv_pool_1(inputs) + x = self._simple_img_conv_pool_2(x) + x = fluid.layers.flatten(x, axis=1) + x = self._fc(x) + return x + + +def accuracy(pred, label, topk=(1, )): + maxk = max(topk) + pred = np.argsort(pred)[:, ::-1][:, :maxk] + correct = (pred == np.repeat(label, maxk, 1)) + + batch_size = label.shape[0] + res = [] + for k in topk: + correct_k = correct[:, :k].sum() + res.append(100.0 * correct_k / batch_size) + return res + + +def main(): + @contextlib.contextmanager + def null_guard(): + yield + + guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard() + + if not os.path.exists('mnist_checkpoints'): + os.mkdir('mnist_checkpoints') + + train_loader = fluid.io.xmap_readers( + lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28), + np.array([x[1] for x in b]).reshape(-1, 1)], + paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 6e4), + batch_size=FLAGS.batch_size, drop_last=True), 1, 1) + val_loader = fluid.io.xmap_readers( + lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28), + np.array([x[1] for x in b]).reshape(-1, 1)], + paddle.batch(paddle.dataset.mnist.test(), + batch_size=FLAGS.batch_size, drop_last=True), 1, 1) + + device_ids = list(range(FLAGS.num_devices)) + + with guard: + model = MNIST() + optim = Momentum(learning_rate=FLAGS.lr, momentum=.9, + parameter_list=model.parameters()) + model.prepare(optim, CrossEntropy()) + if FLAGS.resume is not None: + model.load(FLAGS.resume) + + for e in range(FLAGS.epoch): + train_loss = 0.0 + train_acc = 0.0 + val_loss = 0.0 + val_acc = 0.0 + print("======== train epoch {} ========".format(e)) + for idx, batch in enumerate(train_loader()): + outputs, losses = model.train(batch[0], batch[1], device='gpu', + device_ids=device_ids) + + acc = accuracy(outputs[0], batch[1])[0] + train_loss += np.sum(losses) + train_acc += acc + if idx % 10 == 0: + print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format( + idx, train_loss / (idx + 1), train_acc / (idx + 1))) + + print("======== eval epoch {} ========".format(e)) + for idx, batch in enumerate(val_loader()): + outputs, losses = model.eval(batch[0], batch[1], device='gpu', + device_ids=device_ids) + + acc = accuracy(outputs[0], batch[1])[0] + val_loss += np.sum(losses) + val_acc += acc + if idx % 10 == 0: + print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format( + idx, val_loss / (idx + 1), val_acc / (idx + 1))) + model.save('mnist_checkpoints/{:02d}'.format(e)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("CNN training on MNIST") + parser.add_argument( + "-d", "--dynamic", action='store_true', help="enable dygraph mode") + parser.add_argument( + "-e", "--epoch", default=100, type=int, help="number of epoch") + parser.add_argument( + '--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', + help='initial learning rate') + parser.add_argument( + "-b", "--batch_size", default=128, type=int, help="batch size") + parser.add_argument( + "-n", "--num_devices", default=4, type=int, help="number of devices") + parser.add_argument( + "-r", "--resume", default=None, type=str, + help="checkpoint path to resume") + FLAGS = parser.parse_args() + main() diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ae5d05acde85efcd62b9e2dc0662ea9c318995 --- /dev/null +++ b/model.py @@ -0,0 +1,491 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import inspect +import os +import pickle +from collections import OrderedDict + +import numpy as np + +from paddle import fluid +from paddle.fluid.framework import in_dygraph_mode, Variable +from paddle.fluid.executor import global_scope +from paddle.fluid.io import is_belong_to_optimizer +from paddle.fluid.dygraph.base import to_variable + +__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy'] + + +def to_list(value): + if isinstance(value, (list, tuple)): + return value + return [value] + + +def to_numpy(var): + assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable" + if isinstance(var, fluid.core.VarBase): + return var.numpy() + t = global_scope().find_var(var.name).get_tensor() + return np.array(t) + + +def extract_args(func): + if hasattr(inspect, 'getfullargspec'): + return inspect.getfullargspec(func)[0] + else: + return inspect.getargspec(func)[0] + + +def shape_hints(**hints): + assert hints, "hints can not be empty" + assert all(isinstance(h, (list, tuple)) for h in hints.values()), \ + "shape hint must be a list or tuple" + + def wrapper(func): + args = extract_args(func) + invalid = set(hints.keys()) - set(args) + assert not invalid, \ + "shape hint for arguments that are not present in forward method" \ + + ": ({})".format(", ".join(invalid)) + func.shape_hints = hints + return func + return wrapper + + +class Loss(object): + def __init__(self, average=True): + super(Loss, self).__init__() + self.average = average + + def infer_shape(self, outputs): + return [o.shape for o in outputs] + + def infer_dtype(self, outputs): + return [o.dtype for o in outputs] + + def forward(self, outputs, labels): + raise NotImplementedError() + + def __call__(self, outputs, labels): + labels = to_list(labels) + if in_dygraph_mode(): + labels = [to_variable(l) for l in labels] + losses = to_list(self.forward(to_list(outputs), labels)) + if not self.average: + return losses + return [fluid.layers.reduce_mean(l) for l in losses] + + +class CrossEntropy(Loss): + def __init__(self): + super(CrossEntropy, self).__init__() + + def infer_shape(self, outputs): + return [o.shape[:-1] + (1, ) for o in outputs] + + def infer_dtype(self, outputs): + return ['int64' for _ in outputs] + + def forward(self, outputs, labels): + return [fluid.layers.cross_entropy(o, l) for o, l in zip( + outputs, labels)] + + +class StaticGraphAdapter(object): + def __init__(self, model): + super(StaticGraphAdapter, self).__init__() + self.model = model + # with `_build_once` gone, parameters are now created in `__init__` + # so we need to keep track of the parameters already created + self._startup_prog = fluid.default_startup_program() + self._orig_prog = fluid.default_main_program() + + self._label_vars = {} # label variables + self._endpoints = {} + self._loss_endpoint = None + self._executor = None + self._progs = {} + self._compiled_progs = {} + + self._lazy_load_optimizer = None + + # parse shape hints + self._input_desc = OrderedDict([ + (n, None) for n in extract_args(self.model.forward) if n != 'self' + ]) + if hasattr(self.model.forward, 'shape_hints'): + self._input_desc.update(self.model.forward.shape_hints) + + @property + def mode(self): + return self.model.mode + + @mode.setter + def mode(self, value): + self.model.mode = value + + def train(self, inputs, labels, device='CPU', device_ids=None): + assert self.model._optimizer and self.model._loss_function, \ + "model not ready, please call `model.prepare()` first" + self.mode = 'train' + return self._run(inputs, labels, device, device_ids) + + def eval(self, inputs, labels, device='CPU', device_ids=None): + assert self.model._loss_function, \ + "model not ready, please call `model.prepare()` first" + self.mode = 'eval' + return self._run(inputs, labels, device, device_ids) + + def test(self, inputs, device='CPU', device_ids=None): + self.mode = 'test' + return self._run(inputs, None, device, device_ids) + + def parameters(self, *args, **kwargs): + return None + + def save(self, path): + def _save(state, path): + if not state: + return + state = {k: to_numpy(v) if isinstance(v, Variable) else v + for k, v in state.items()} + with open(path, 'wb') as f: + pickle.dump(state, f) + + base = os.path.basename(path) + assert base != "", "path should be of 'dirname/filename' format" + param_path = path + ".pdparams" + _save(self.model.state_dict(), param_path) + prog = self._progs.get('train', None) + if prog is None or self.model._optimizer is None: + return + # XXX `optimizer.state_dict()` only work in dygraph mode + optim_path = path + ".pdopt" + optim = {p.name: p for p in filter( + is_belong_to_optimizer, prog.list_vars())} + if not optim: + return + # HACK this is contrived, optimizer state is not the same for + # static/dynamic graph mode + optim['__static_graph_only__'] = True + _save(optim, optim_path) + + def load(self, path): + def _load(path): + if not os.path.exists(path): + return + with open(path, 'rb') as f: + return pickle.load(f) + + param_path = path + ".pdparams" + param_state = _load(param_path) + assert param_state, "failed to load parameters, please check path" + + if self._executor is None: + executor = fluid.Executor(fluid.CPUPlace())._default_executor + else: + executor = self._executor._default_executor + + fluid.core._create_loaded_parameter( + list(self.model.state_dict().values()), global_scope(), executor) + + for key, var in self.model.state_dict().items(): + assert key in param_state, \ + "parameter [{}] is not found in model file [{}]".format( + key, param_path) + self._set_var(var, param_state[key]) + + # FIXME what if a different optimizer is used? + if not self.model._optimizer: + return + optim_path = path + ".pdopt" + optim_state = _load(optim_path) + if optim_state is None: + return + assert '__static_graph_only__' in optim_state, \ + "optimizer saved in dygraph mode is not usable in static graph" + + if self._executor is not None: + self._load_optimizer(optim_state) + else: + self._lazy_load_optimizer = optim_state + + def _load_optimizer(self, state): + prog = self._progs.get('train', None) + optim = list(filter(is_belong_to_optimizer, prog.list_vars())) + if not optim: + return + + fluid.core._create_loaded_parameter( + optim, global_scope(), self._executor._default_executor) + + for var in optim: + assert var.name in state, \ + "variable [{}] is not in optimizer state file".format(var.name) + self._set_var(var, state[var.name]) + + def _set_var(self, var, ndarray): + t = global_scope().find_var(var.name).get_tensor() + p = t._place() + if p.is_cpu_place(): + place = fluid.CPUPlace() + elif p.is_cuda_pinned_place(): + place = fluid.CUDAPinnedPlace() + else: + p = fluid.core.Place() + p.set_place(t._place()) + place = fluid.CUDAPlace(p.gpu_device_id()) + + t.set(ndarray, place) + + def _run(self, inputs, labels=None, device='CPU', device_ids=None): + inputs = to_list(inputs) + if labels is not None: + labels = to_list(labels) + assert len(inputs) == len(self._input_desc), "number of inputs" \ + + " does not match number of arguments of `forward` method" + + if self._progs.get(self.mode, None) is None: + self._make_program(self._infer_input_vars(inputs)) + + compiled_prog = self._compile_and_initialize( + self._progs[self.mode], device, device_ids) + + feed = {} + input_names = [name for name in self._input_desc.keys()] + for idx, n in enumerate(input_names): + # train and test may take different arguments + if inputs[idx] is not None: + feed[n] = inputs[idx] + if labels is not None: + for idx, v in enumerate(self._label_vars[self.mode]): + feed[v.name] = labels[idx] + + endpoints = self._endpoints[self.mode] + fetch_list = endpoints['output'] + endpoints['loss'] + num_output = len(endpoints['output']) + out = self._executor.run( + compiled_prog, feed=feed, + fetch_list=fetch_list) + if self.mode == 'test': + return out[:num_output] + else: + return out[:num_output], out[num_output:] + + def _make_program(self, inputs): + prog = self._orig_prog.clone() + if self.mode == 'train' and self.model._optimizer._learning_rate_map: + # HACK workaround learning rate map issue + lr_var = self.model._optimizer._learning_rate_map[self._orig_prog] + self.model._optimizer._learning_rate_map[prog] = lr_var + losses = [] + with fluid.program_guard(prog, self._startup_prog): + outputs = to_list(self.model.forward(*inputs)) + if self.mode != 'test': + label_vars = self._infer_label_vars(outputs) + self._label_vars[self.mode] = label_vars + losses = self.model._loss_function(outputs, label_vars) + if self.mode == 'train': + self._loss_endpoint = fluid.layers.sum(losses) + self.model._optimizer.minimize(self._loss_endpoint) + if self.mode != 'train': # clone again to put it in test mode + prog = prog.clone(for_test=True) + self._progs[self.mode] = prog + self._endpoints[self.mode] = { + "output": outputs, + "loss": losses + } + + def _infer_input_vars(self, inputs): + input_vars = [] + for idx, i in enumerate(inputs): + if i is None: # train and test may take different arguments + input_vars.append(None) + continue + ndarray = np.array(i) + name = list(self._input_desc.keys())[idx] + shape = list(self._input_desc.values())[idx] + if shape is None: + shape = (None, ) + ndarray.shape[1:] + input_vars.append(fluid.data(name, shape, ndarray.dtype)) + return input_vars + + def _infer_label_vars(self, outputs): + shapes = self.model._loss_function.infer_shape(outputs) + dtypes = self.model._loss_function.infer_dtype(outputs) + label_vars = [] + for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)): + name = '__label{}'.format(idx) + label_vars.append(fluid.data(name, shape, dtype)) + return label_vars + + def _compile_and_initialize(self, prog, device='CPU', device_ids=None): + compiled_prog = self._compiled_progs.get(self.mode, None) + if compiled_prog is not None: + return compiled_prog + + places = [device.lower() == 'gpu' and fluid.CUDAPlace(i) + or fluid.CPUPlace() for i in device_ids] + + # XXX *ALL WEIGHTS* should be initialized upon model construction + # even if `forward()` may run different code path for different mode + # therefore startup program only needs to run once + if self._executor is None: + self._executor = fluid.Executor(places[0]) + # XXX incremental initialization + uninitialized = [] + for var_py in self._startup_prog.list_vars(): + var = fluid.global_scope().find_var(var_py.name) + if var and var.get_tensor()._is_initialized(): + continue + uninitialized.append(var_py) + if uninitialized: + startup_prog = self._startup_prog._prune(uninitialized) + self._executor.run(startup_prog) + + if self.mode == 'train' and self._lazy_load_optimizer: + self._load_optimizer(self._lazy_load_optimizer) + self._lazy_load_optimizer = None + + compiled_prog = fluid.CompiledProgram(prog) + if len(device_ids) > 1: + loss_name = None + if self.mode == 'train' and self._loss_endpoint is not None: + loss_name = self._loss_endpoint.name + + share_vars_from = None + if self.mode == 'eval' and 'train' in self._compiled_progs: + share_vars_from = self._compiled_progs['train'] + # HACK invalidate eval program if is compiled before train program + # quite hackish, OTOH, it is generally uncommon that the eval + # program will be run before the train program + if self.mode == 'train' and 'eval' in self._compiled_progs: + del self._compiled_progs['eval'] + + compiled_prog = compiled_prog.with_data_parallel( + loss_name=loss_name, places=places, + share_vars_from=share_vars_from) + + self._compiled_progs[self.mode] = compiled_prog + return compiled_prog + + +class DynamicGraphAdapter(object): + def __init__(self, model): + super(DynamicGraphAdapter, self).__init__() + self.model = model + + @property + def mode(self): + return self.model.mode + + @mode.setter + def mode(self, value): + self.model.mode = value + + # TODO multi device in dygraph mode not implemented at present time + def train(self, inputs, labels, device='CPU', device_ids=None): + assert self.model._optimizer and self.model._loss_function, \ + "model not ready, please call `model.prepare()` first" + super(Model, self.model).train() + self.mode = 'train' + inputs = to_list(inputs) + labels = to_list(labels) + outputs = self.model.forward(*[to_variable(x) for x in inputs]) + losses = self.model._loss_function(outputs, labels) + final_loss = fluid.layers.sum(losses) + final_loss.backward() + self.model._optimizer.minimize(final_loss) + self.model.clear_gradients() + return [to_numpy(o) for o in to_list(outputs)], \ + [to_numpy(l) for l in losses] + + def eval(self, inputs, labels, device='CPU', device_ids=None): + assert self.model._loss_function, \ + "model not ready, please call `model.prepare()` first" + super(Model, self.model).eval() + self.mode = 'eval' + inputs = to_list(inputs) + labels = to_list(labels) + outputs = self.model.forward(*[to_variable(x) for x in inputs]) + losses = self.model._loss_function(outputs, labels) + return [to_numpy(o) for o in to_list(outputs)], \ + [to_numpy(l) for l in losses] + + def test(self, inputs, device='CPU', device_ids=None): + super(Model, self.model).eval() + self.mode = 'test' + inputs = [to_variable(x) for x in to_list(inputs)] + outputs = self.model.forward(*inputs) + return [to_numpy(o) for o in to_list(outputs)] + + def parameters(self, *args, **kwargs): + return super(Model, self.model).parameters(*args, **kwargs) + + def save(self, path): + params = self.model.state_dict() + fluid.save_dygraph(params, path) + if self.model._optimizer is None: + return + if self.model._optimizer.state_dict(): + optim = self.model._optimizer.state_dict() + fluid.save_dygraph(optim, path) + + def load(self, path): + params, optim = fluid.load_dygraph(path) + self.model.set_dict(params) + if self.model._optimizer is None or optim is None: + return + self.model._optimizer.set_dict(optim) + + +class Model(fluid.dygraph.Layer): + def __init__(self): + super(Model, self).__init__(self.__class__.__name__) + self.mode = 'train' + self._loss_function = None + self._loss_weights = None + self._optimizer = None + if in_dygraph_mode(): + self._adapter = DynamicGraphAdapter(self) + else: + self._adapter = StaticGraphAdapter(self) + + def train(self, *args, **kwargs): + return self._adapter.train(*args, **kwargs) + + def eval(self, *args, **kwargs): + return self._adapter.eval(*args, **kwargs) + + def test(self, *args, **kwargs): + return self._adapter.test(*args, **kwargs) + + def save(self, *args, **kwargs): + return self._adapter.save(*args, **kwargs) + + def load(self, *args, **kwargs): + return self._adapter.load(*args, **kwargs) + + def prepare(self, optimizer, loss_function): + self._optimizer = optimizer + assert isinstance(loss_function, Loss), \ + "'loss_function' must be sub classes of 'Loss'" + self._loss_function = loss_function + + def parameters(self, *args, **kwargs): + return self._adapter.parameters(*args, **kwargs) diff --git a/resnet.py b/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4752d76cca47fd3c1ed71c3516d3b262a62a9152 --- /dev/null +++ b/resnet.py @@ -0,0 +1,404 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import math +import os +import random +import time + +import cv2 +import numpy as np + +import paddle +import paddle.fluid as fluid +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.container import Sequential + +from model import Model, CrossEntropy + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._batch_norm = BatchNorm(num_filters, act=act) + + def forward(self, inputs): + x = self._conv(inputs) + x = self._batch_norm(x) + + return x + + +class BottleneckBlock(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + stride, + shortcut=True): + super(BottleneckBlock, self).__init__() + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + act='relu') + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu') + self.conv2 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * 4, + filter_size=1, + act=None) + + if not shortcut: + self.short = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters * 4, + filter_size=1, + stride=stride) + + self.shortcut = shortcut + + self._num_channels_out = num_filters * 4 + + def forward(self, inputs): + x = self.conv0(inputs) + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + + x = fluid.layers.elementwise_add(x=short, y=conv2) + + layer_helper = LayerHelper(self.full_name(), act='relu') + return layer_helper.append_activation(x) + + +class ResNet(Model): + def __init__(self, depth=50, num_classes=1000): + super(ResNet, self).__init__() + + layer_config = { + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + } + assert depth in layer_config.keys(), \ + "supported depth are {} but input layer is {}".format( + layer_config.keys(), depth) + + layers = layer_config[depth] + num_in = [64, 256, 512, 1024] + num_out = [64, 128, 256, 512] + + self.conv = ConvBNLayer( + num_channels=3, + num_filters=64, + filter_size=7, + stride=2, + act='relu') + self.pool = Pool2D( + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + self.layers = [] + for idx, num_blocks in enumerate(layers): + blocks = [] + shortcut = False + for b in range(num_blocks): + block = BottleneckBlock( + num_channels=num_in[idx] if b == 0 else num_out[idx] * 4, + num_filters=num_out[idx], + stride=2 if b == 0 and idx != 0 else 1, + shortcut=shortcut) + blocks.append(block) + shortcut = True + layer = self.add_sublayer( + "layer_{}".format(idx), + Sequential(*blocks)) + self.layers.append(layer) + + self.global_pool = Pool2D( + pool_size=7, pool_type='avg', global_pooling=True) + + stdv = 1.0 / math.sqrt(2048 * 1.0) + self.fc_input_dim = num_out[-1] * 4 * 1 * 1 + self.fc = Linear(self.fc_input_dim, + num_classes, + act='softmax', + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform( + -stdv, stdv))) + + def forward(self, inputs): + x = self.conv(inputs) + x = self.pool(x) + for layer in self.layers: + x = layer(x) + x = self.global_pool(x) + x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim]) + x = self.fc(x) + return x + + +def make_optimizer(parameter_list=None): + total_images = 1281167 + base_lr = FLAGS.lr + momentum = 0.9 + weight_decay = 1e-4 + step_per_epoch = int(math.floor(float(total_images) / FLAGS.batch_size)) + boundaries = [step_per_epoch * e for e in [30, 60, 80]] + values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)] + learning_rate = fluid.layers.piecewise_decay( + boundaries=boundaries, values=values) + learning_rate = fluid.layers.linear_lr_warmup( + learning_rate=learning_rate, + warmup_steps=5 * step_per_epoch, + start_lr=0., + end_lr=base_lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=momentum, + regularization=fluid.regularizer.L2Decay(weight_decay), + parameter_list=parameter_list) + return optimizer + + +def accuracy(pred, label, topk=(1, )): + maxk = max(topk) + pred = np.argsort(pred)[:, ::-1][:, :maxk] + correct = (pred == np.repeat(label, maxk, 1)) + + batch_size = label.shape[0] + res = [] + for k in topk: + correct_k = correct[:, :k].sum() + res.append(100.0 * correct_k / batch_size) + return res + + +def center_crop_resize(img): + h, w = img.shape[:2] + c = int(224 / 256 * min((h, w))) + i = (h + 1 - c) // 2 + j = (w + 1 - c) // 2 + img = img[i: i + c, j: j + c, :] + return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR) + + +def random_crop_resize(img): + height, width = img.shape[:2] + area = height * width + + for attempt in range(10): + target_area = random.uniform(0.08, 1.) * area + log_ratio = (math.log(3 / 4), math.log(4 / 3)) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= width and h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + img = img[i: i + h, j: j + w, :] + return cv2.resize(img, (224, 224), 0, 0, cv2.INTER_LINEAR) + + return center_crop_resize(img) + + +def random_flip(img): + return img[:, ::-1, :] + + +def normalize_permute(img): + # transpose and convert to RGB from BGR + img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...] + mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) + std = np.array([58.395, 57.120, 57.375], dtype=np.float32) + invstd = 1. / std + for v, m, s in zip(img, mean, invstd): + v.__isub__(m).__imul__(s) + return img + + +def compose(functions): + def process(sample): + img, label = sample + for fn in functions: + img = fn(img) + return img, label + return process + + +def image_folder(path, shuffle=False): + valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp') + classes = [d for d in os.listdir(path) if + os.path.isdir(os.path.join(path, d))] + classes.sort() + class_map = {cls: idx for idx, cls in enumerate(classes)} + samples = [] + for dir in sorted(class_map.keys()): + d = os.path.join(path, dir) + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + p = os.path.join(root, fname) + if os.path.splitext(p)[1].lower() in valid_ext: + samples.append((p, class_map[dir])) + + def iterator(): + if shuffle: + random.shuffle(samples) + for s in samples: + yield s + + return iterator + + +def run(model, loader, mode='train'): + total_loss = 0. + total_acc1 = 0. + total_acc5 = 0. + total_time = 0. + start = time.time() + device_ids = list(range(FLAGS.num_devices)) + start = time.time() + + for idx, batch in enumerate(loader()): + outputs, losses = getattr(model, mode)( + batch[0], batch[1], device='gpu', device_ids=device_ids) + top1, top5 = accuracy(outputs[0], batch[1], topk=(1, 5)) + + total_loss += np.sum(losses) + total_acc1 += top1 + total_acc5 += top5 + if idx > 1: # skip first two steps + total_time += time.time() - start + if idx % 10 == 0: + print(("{:04d} loss: {:0.3f} top1: {:0.3f}% top5: {:0.3f}% " + "time: {:0.3f}").format( + idx, total_loss / (idx + 1), total_acc1 / (idx + 1), + total_acc5 / (idx + 1), total_time / max(1, (idx - 1)))) + start = time.time() + + +def main(): + @contextlib.contextmanager + def null_guard(): + yield + + epoch = FLAGS.epoch + batch_size = FLAGS.batch_size + guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard() + + train_dir = os.path.join(FLAGS.data, 'train') + val_dir = os.path.join(FLAGS.data, 'val') + + train_loader = fluid.io.xmap_readers( + lambda batch: (np.array([b[0] for b in batch]), + np.array([b[1] for b in batch]).reshape(-1, 1)), + paddle.batch( + fluid.io.xmap_readers( + compose([cv2.imread, random_crop_resize, random_flip, + normalize_permute]), + image_folder(train_dir, shuffle=True), + process_num=8, + buffer_size=4 * batch_size), + batch_size=batch_size, + drop_last=True), + process_num=2, buffer_size=4) + + val_loader = fluid.io.xmap_readers( + lambda batch: (np.array([b[0] for b in batch]), + np.array([b[1] for b in batch]).reshape(-1, 1)), + paddle.batch( + fluid.io.xmap_readers( + compose([cv2.imread, center_crop_resize, normalize_permute]), + image_folder(val_dir), + process_num=8, + buffer_size=4 * batch_size), + batch_size=batch_size), + process_num=2, buffer_size=4) + + if not os.path.exists('resnet_checkpoints'): + os.mkdir('resnet_checkpoints') + + with guard: + model = ResNet() + optim = make_optimizer(parameter_list=model.parameters()) + model.prepare(optim, CrossEntropy()) + if FLAGS.resume is not None: + model.load(FLAGS.resume) + + for e in range(epoch): + print("======== train epoch {} ========".format(e)) + run(model, train_loader) + model.save('resnet_checkpoints/{:02d}'.format(e)) + print("======== eval epoch {} ========".format(e)) + run(model, val_loader, mode='eval') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Resnet Training on ImageNet") + parser.add_argument('data', metavar='DIR', help='path to dataset ' + '(should have subdirectories named "train" and "val"') + parser.add_argument( + "-d", "--dynamic", action='store_true', help="enable dygraph mode") + parser.add_argument( + "-e", "--epoch", default=90, type=int, help="number of epoch") + parser.add_argument( + '--lr', '--learning-rate', default=0.1, type=float, metavar='LR', + help='initial learning rate') + parser.add_argument( + "-b", "--batch_size", default=256, type=int, help="batch size") + parser.add_argument( + "-n", "--num_devices", default=4, type=int, help="number of devices") + parser.add_argument( + "-r", "--resume", default=None, type=str, + help="checkpoint path to resume") + FLAGS = parser.parse_args() + assert FLAGS.data, "error: must provide data path" + main() diff --git a/yolov3.py b/yolov3.py new file mode 100644 index 0000000000000000000000000000000000000000..61cc91c6de43e389a74ab9cd485ab310cbe55d48 --- /dev/null +++ b/yolov3.py @@ -0,0 +1,545 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import argparse +import contextlib +import os +import random +import time + +from functools import partial + +import cv2 +import numpy as np +from pycocotools.coco import COCO + +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Conv2D +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from model import Model, Loss, shape_hints +from resnet import ResNet, ConvBNLayer + + +# XXX transfer learning +class ResNetBackBone(ResNet): + def __init__(self, depth=50): + super(ResNetBackBone, self).__init__(depth=depth) + delattr(self, 'fc') + + def forward(self, inputs): + x = self.conv(inputs) + x = self.pool(x) + outputs = [] + for layer in self.layers: + x = layer(x) + outputs.append(x) + return outputs + + +class YoloDetectionBlock(fluid.dygraph.Layer): + def __init__(self, num_channels, num_filters): + super(YoloDetectionBlock, self).__init__() + + assert num_filters % 2 == 0, \ + "num_filters {} cannot be divided by 2".format(num_filters) + + self.conv0 = ConvBNLayer( + num_channels=num_channels, + num_filters=num_filters, + filter_size=1, + act='leaky_relu') + self.conv1 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * 2, + filter_size=3, + act='leaky_relu') + self.conv2 = ConvBNLayer( + num_channels=num_filters * 2, + num_filters=num_filters, + filter_size=1, + act='leaky_relu') + self.conv3 = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * 2, + filter_size=3, + act='leaky_relu') + self.route = ConvBNLayer( + num_channels=num_filters * 2, + num_filters=num_filters, + filter_size=1, + act='leaky_relu') + self.tip = ConvBNLayer( + num_channels=num_filters, + num_filters=num_filters * 2, + filter_size=3, + act='leaky_relu') + + def forward(self, inputs): + out = self.conv0(inputs) + out = self.conv1(out) + out = self.conv2(out) + out = self.conv3(out) + route = self.route(out) + tip = self.tip(route) + return route, tip + + +class YOLOv3(Model): + def __init__(self): + super(YOLOv3, self).__init__() + self.num_classes = 80 + self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, + 59, 119, 116, 90, 156, 198, 373, 326] + self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + self.valid_thresh = 0.005 + self.nms_topk = 400 + self.nms_posk = 100 + self.draw_thresh = 0.5 + + self.backbone = ResNetBackBone() + self.block_outputs = [] + self.yolo_blocks = [] + self.route_blocks = [] + + for idx, num_chan in enumerate([2048, 1280, 640]): + yolo_block = self.add_sublayer( + "detecton_block_{}".format(idx), + YoloDetectionBlock(num_chan, num_filters=512 // (2**idx))) + self.yolo_blocks.append(yolo_block) + + num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5) + + block_out = self.add_sublayer( + "block_out_{}".format(idx), + Conv2D(num_channels=1024 // (2**idx), + num_filters=num_filters, + filter_size=1, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02)), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.)))) + self.block_outputs.append(block_out) + if idx < 2: + route = self.add_sublayer( + "route_{}".format(idx), + ConvBNLayer(num_channels=512 // (2**idx), + num_filters=256 // (2**idx), + filter_size=1, + act='leaky_relu')) + self.route_blocks.append(route) + + @shape_hints(inputs=[None, 3, None, None], im_shape=[None, 2]) + def forward(self, inputs, im_shape): + outputs = [] + boxes = [] + scores = [] + downsample = 32 + + feats = self.backbone(inputs) + feats = feats[::-1][:len(self.anchor_masks)] + route = None + for idx, feat in enumerate(feats): + if idx > 0: + feat = fluid.layers.concat(input=[route, feat], axis=1) + route, tip = self.yolo_blocks[idx](feat) + block_out = self.block_outputs[idx](tip) + + if idx < 2: + route = self.route_blocks[idx](route) + route = fluid.layers.resize_nearest(route, scale=2) + + anchor_mask = self.anchor_masks[idx] + mask_anchors = [] + for m in anchor_mask: + mask_anchors.append(self.anchors[2 * m]) + mask_anchors.append(self.anchors[2 * m + 1]) + b, s = fluid.layers.yolo_box( + x=block_out, + img_size=im_shape, + anchors=mask_anchors, + class_num=self.num_classes, + conf_thresh=self.valid_thresh, + downsample_ratio=downsample) + + outputs.append(block_out) + boxes.append(b) + scores.append(fluid.layers.transpose(s, perm=[0, 2, 1])) + + downsample //= 2 + + if self.mode != 'test': + return outputs + + return fluid.layers.multiclass_nms( + bboxes=fluid.layers.concat(boxes, axis=1), + scores=fluid.layers.concat(scores, axis=2), + score_threshold=self.valid_thresh, + nms_top_k=self.nms_topk, + keep_top_k=self.nms_posk, + nms_threshold=self.nms_thresh, + background_label=-1) + + +class YoloLoss(Loss): + def __init__(self, num_classes=80, num_max_boxes=50): + super(YoloLoss, self).__init__() + self.num_classes = num_classes + self.num_max_boxes = num_max_boxes + self.ignore_thresh = 0.7 + self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, + 59, 119, 116, 90, 156, 198, 373, 326] + self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] + + def forward(self, outputs, labels): + downsample = 32 + gt_box, gt_label, gt_score = labels + losses = [] + + for idx, out in enumerate(outputs): + anchor_mask = self.anchor_masks[idx] + loss = fluid.layers.yolov3_loss( + x=out, + gt_box=gt_box, + gt_label=gt_label, + gt_score=gt_score, + anchor_mask=anchor_mask, + downsample_ratio=downsample, + anchors=self.anchors, + class_num=self.num_classes, + ignore_thresh=self.ignore_thresh, + use_label_smooth=True) + losses.append(loss) + downsample //= 2 + return losses + + def infer_shape(self, _): + return [ + [None, self.num_max_boxes, 4], + [None, self.num_max_boxes], + [None, self.num_max_boxes] + ] + + def infer_dtype(self, _): + return ['float32', 'int32', 'float32'] + + +def make_optimizer(parameter_list=None): + base_lr = FLAGS.lr + warm_up_iter = 4000 + momentum = 0.9 + weight_decay = 5e-4 + boundaries = [400000, 450000] + values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)] + learning_rate = fluid.layers.piecewise_decay( + boundaries=boundaries, + values=values) + learning_rate = fluid.layers.linear_lr_warmup( + learning_rate=learning_rate, + warmup_steps=warm_up_iter, + start_lr=0.0, + end_lr=base_lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + regularization=fluid.regularizer.L2Decay(weight_decay), + momentum=momentum, + parameter_list=parameter_list) + return optimizer + + +def _iou_matrix(a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + return area_i / (area_o + 1e-10) + + +def _crop_box_with_center_constraint(box, crop): + cropped_box = box.copy() + cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2]) + cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:]) + cropped_box[:, :2] -= crop[:2] + cropped_box[:, 2:] -= crop[:2] + centers = (box[:, :2] + box[:, 2:]) / 2 + valid = np.logical_and( + crop[:2] <= centers, centers < crop[2:]).all(axis=1) + valid = np.logical_and( + valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1)) + return cropped_box, np.where(valid)[0] + + +def random_crop(inputs): + aspect_ratios = [.5, 2.] + thresholds = [.0, .1, .3, .5, .7, .9] + scaling = [.3, 1.] + + img, gt_box, gt_label = inputs + h, w = img.shape[:2] + + if len(gt_box) == 0: + return inputs + + np.random.shuffle(thresholds) + for thresh in thresholds: + found = False + for i in range(50): + scale = np.random.uniform(*scaling) + min_ar, max_ar = aspect_ratios + ar = np.random.uniform(max(min_ar, scale**2), + min(max_ar, scale**-2)) + crop_h = int(h * scale / np.sqrt(ar)) + crop_w = int(w * scale * np.sqrt(ar)) + crop_y = np.random.randint(0, h - crop_h) + crop_x = np.random.randint(0, w - crop_w) + crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] + iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32)) + if iou.max() < thresh: + continue + + cropped_box, valid_ids = _crop_box_with_center_constraint( + gt_box, np.array(crop_box, dtype=np.float32)) + if valid_ids.size > 0: + found = True + break + + if found: + x1, y1, x2, y2 = crop_box + img = img[y1:y2, x1:x2, :] + gt_box = np.take(cropped_box, valid_ids, axis=0) + gt_label = np.take(gt_label, valid_ids, axis=0) + return img, gt_box, gt_label + + return inputs + + +# XXX mix up, color distort and random expand are skipped for simplicity +def sample_transform(inputs, mode='train', num_max_boxes=50): + if mode == 'train': + img, gt_box, gt_label = random_crop(inputs) + else: + img, gt_box, gt_label = inputs + + h, w = img.shape[:2] + # random flip + if mode == 'train' and np.random.uniform(0., 1.) > .5: + img = img[:, ::-1, :] + if len(gt_box) > 0: + swap = gt_box.copy() + gt_box[:, 0] = w - swap[:, 2] - 1 + gt_box[:, 2] = w - swap[:, 0] - 1 + + if len(gt_label) == 0: + gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32) + gt_label = np.zeros([num_max_boxes, 1], dtype=np.int32) + return img, gt_box, gt_label + + gt_box = gt_box[:num_max_boxes, :] + gt_label = gt_label[:num_max_boxes, 0] + # normalize boxes + gt_box /= np.array([w, h] * 2, dtype=np.float32) + gt_box[:, 2:] = gt_box[:, 2:] - gt_box[:, :2] + gt_box[:, :2] = gt_box[:, :2] + gt_box[:, 2:] / 2. + + pad = num_max_boxes - gt_label.size + gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant') + gt_label = np.pad(gt_label, [(0, pad)], mode='constant') + + return img, gt_box, gt_label + + +def batch_transform(batch, mode='train'): + if mode == 'train': + d = np.random.choice( + [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]) + interp = np.random.choice(range(5)) + else: + d = 608 + interp = cv2.INTER_CUBIC + # transpose batch + imgs, gt_boxes, gt_labels = list(zip(*batch)) + imgs = np.array([cv2.resize( + img, (d, d), interpolation=interp) for img in imgs]) + + # transpose, permute and normalize + imgs = imgs.astype(np.float32)[..., ::-1] + mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) + std = np.array([58.395, 57.120, 57.375], dtype=np.float32) + invstd = 1. / std + imgs -= mean + imgs *= invstd + imgs = imgs.transpose((0, 3, 1, 2)) + + im_shapes = np.full([len(imgs), 2], d, dtype=np.int32) + gt_boxes = np.array(gt_boxes) + gt_labels = np.array(gt_labels) + # XXX since mix up is not used, scores are all ones + gt_scores = np.ones_like(gt_labels, dtype=np.float32) + return [imgs, im_shapes], [gt_boxes, gt_labels, gt_scores] + + +def coco2017(root_dir, mode='train'): + json_path = os.path.join( + root_dir, 'annotations/instances_{}2017.json'.format(mode)) + coco = COCO(json_path) + img_ids = coco.getImgIds() + imgs = coco.loadImgs(img_ids) + class_map = {v: i + 1 for i, v in enumerate(coco.getCatIds())} + samples = [] + + for img in imgs: + img_path = os.path.join( + root_dir, '{}2017'.format(mode), img['file_name']) + file_path = img_path + width = img['width'] + height = img['height'] + ann_ids = coco.getAnnIds(imgIds=img['id'], iscrowd=False) + anns = coco.loadAnns(ann_ids) + + gt_box = [] + gt_label = [] + + for ann in anns: + x1, y1, w, h = ann['bbox'] + x2 = x1 + w - 1 + y2 = y1 + h - 1 + x1 = np.clip(x1, 0, width - 1) + x2 = np.clip(x2, 0, width - 1) + y1 = np.clip(y1, 0, height - 1) + y2 = np.clip(y2, 0, height - 1) + if ann['area'] <= 0 or x2 < x1 or y2 < y1: + continue + gt_label.append(ann['category_id']) + gt_box.append([x1, y1, x2, y2]) + + gt_box = np.array(gt_box, dtype=np.float32) + gt_label = np.array([class_map[cls] for cls in gt_label], + dtype=np.int32)[:, np.newaxis] + + if gt_label.size == 0 and not mode == 'train': + continue + samples.append((file_path, gt_box.copy(), gt_label.copy())) + + def iterator(): + if mode == 'train': + random.shuffle(samples) + for file_path, gt_box, gt_label in samples: + img = cv2.imread(file_path) + yield img, gt_box, gt_label + + return iterator + + +# XXX coco metrics not included for simplicity +def run(model, loader, mode='train'): + total_loss = 0. + total_time = 0. + device_ids = list(range(FLAGS.num_devices)) + start = time.time() + + for idx, batch in enumerate(loader()): + outputs, losses = getattr(model, mode)( + batch[0], batch[1], device='gpu', device_ids=device_ids) + + total_loss += np.sum(losses) + if idx > 1: # skip first two steps + total_time += time.time() - start + if idx % 10 == 0: + print("{:04d}: loss {:0.3f} time: {:0.3f}".format( + idx, total_loss / (idx + 1), total_time / max(1, (idx - 1)))) + start = time.time() + + +def main(): + @contextlib.contextmanager + def null_guard(): + yield + + epoch = FLAGS.epoch + batch_size = FLAGS.batch_size + guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard() + + train_loader = fluid.io.xmap_readers( + batch_transform, + paddle.batch( + fluid.io.xmap_readers( + sample_transform, + coco2017(FLAGS.data, 'train'), + process_num=8, + buffer_size=4 * batch_size), + batch_size=batch_size, + drop_last=True), + process_num=2, buffer_size=4) + + val_sample_transform = partial(sample_transform, mode='val') + val_batch_transform = partial(batch_transform, mode='val') + + val_loader = fluid.io.xmap_readers( + val_batch_transform, + paddle.batch( + fluid.io.xmap_readers( + val_sample_transform, + coco2017(FLAGS.data, 'val'), + process_num=8, + buffer_size=4 * batch_size), + batch_size=batch_size), + process_num=2, buffer_size=4) + + if not os.path.exists('yolo_checkpoints'): + os.mkdir('yolo_checkpoints') + + with guard: + model = YOLOv3() + # XXX transfer learning + if FLAGS.weights is not None: + model.backbone.load(FLAGS.weights) + optim = make_optimizer(parameter_list=model.parameters()) + model.prepare(optim, YoloLoss()) + + for e in range(epoch): + print("======== train epoch {} ========".format(e)) + run(model, train_loader) + model.save('yolo_checkpoints/{:02d}'.format(e)) + print("======== eval epoch {} ========".format(e)) + run(model, val_loader, mode='eval') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Yolov3 Training on COCO") + parser.add_argument('data', metavar='DIR', help='path to COCO dataset') + parser.add_argument( + "-d", "--dynamic", action='store_true', help="enable dygraph mode") + parser.add_argument( + "-e", "--epoch", default=300, type=int, help="number of epoch") + parser.add_argument( + '--lr', '--learning-rate', default=0.001, type=float, metavar='LR', + help='initial learning rate') + parser.add_argument( + "-b", "--batch_size", default=64, type=int, help="batch size") + parser.add_argument( + "-n", "--num_devices", default=8, type=int, help="number of devices") + parser.add_argument( + "-w", "--weights", default=None, type=str, + help="path to pretrained weights") + FLAGS = parser.parse_args() + assert FLAGS.data, "error: must provide data path" + main()