Unverified commit 71075698, authored by XiaoguangHu, committed by GitHub

Merge pull request #1 from willthefrog/master

Add model API and demo
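
The `Model` class introduced here wraps a network so that the same user code can drive it in either static graph or dygraph mode. A minimal usage sketch (hypothetical driver code inferred from the demos below; the `MNIST` network is the one defined in the first demo, and the random batch is illustrative only):

    import numpy as np
    from paddle import fluid
    from model import CrossEntropy

    with fluid.dygraph.guard():  # in static graph mode, drop the guard and
                                 # pass device/device_ids to train() instead
        model = MNIST()  # any subclass of model.Model; defined in the demo below
        optim = fluid.optimizer.Momentum(
            learning_rate=1e-3, momentum=0.9,
            parameter_list=model.parameters())
        model.prepare(optim, CrossEntropy())
        images = np.random.rand(4, 1, 28, 28).astype('float32')
        labels = np.random.randint(0, 10, (4, 1)).astype('int64')
        outputs, losses = model.train(images, labels)  # one optimization step
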
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.optimizer import Momentum
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from model import Model, CrossEntropy
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=None,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__('SimpleConv')
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            act=act,
            use_cudnn=use_cudnn)
self._pool2d = Pool2D(
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(Model):
def __init__(self):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConvPool(
1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
20, 50, 5, 2, 2, act="relu")
        pool_2_shape = 50 * 4 * 4
        SIZE = 10
        scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
        self._fc = Linear(pool_2_shape,
                          SIZE,
                          param_attr=fluid.param_attr.ParamAttr(
                              initializer=fluid.initializer.NormalInitializer(
                                  loc=0.0, scale=scale)),
                          act="softmax")
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.flatten(x, axis=1)
x = self._fc(x)
return x
def accuracy(pred, label, topk=(1, )):
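    # top-k accuracy: sort the class scores descending, keep the top `maxk`
    # indices, and compare them with the broadcast integer labels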
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(100.0 * correct_k / batch_size)
return res
def main():
@contextlib.contextmanager
def null_guard():
yield
guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
if not os.path.exists('mnist_checkpoints'):
os.mkdir('mnist_checkpoints')
train_loader = fluid.io.xmap_readers(
lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
np.array([x[1] for x in b]).reshape(-1, 1)],
        paddle.batch(fluid.io.shuffle(paddle.dataset.mnist.train(), 60000),
                     batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
val_loader = fluid.io.xmap_readers(
lambda b: [np.array([x[0] for x in b]).reshape(-1, 1, 28, 28),
np.array([x[1] for x in b]).reshape(-1, 1)],
paddle.batch(paddle.dataset.mnist.test(),
batch_size=FLAGS.batch_size, drop_last=True), 1, 1)
device_ids = list(range(FLAGS.num_devices))
with guard:
model = MNIST()
optim = Momentum(learning_rate=FLAGS.lr, momentum=.9,
parameter_list=model.parameters())
model.prepare(optim, CrossEntropy())
if FLAGS.resume is not None:
model.load(FLAGS.resume)
for e in range(FLAGS.epoch):
train_loss = 0.0
train_acc = 0.0
val_loss = 0.0
val_acc = 0.0
print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(batch[0], batch[1], device='gpu',
device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses)
train_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, train_loss / (idx + 1), train_acc / (idx + 1)))
print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(batch[0], batch[1], device='gpu',
device_ids=device_ids)
acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses)
val_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, val_loss / (idx + 1), val_acc / (idx + 1)))
model.save('mnist_checkpoints/{:02d}'.format(e))
if __name__ == '__main__':
parser = argparse.ArgumentParser("CNN training on MNIST")
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
    parser.add_argument(
        "-e", "--epoch", default=100, type=int, help="number of epochs")
parser.add_argument(
'--lr', '--learning-rate', default=1e-3, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=128, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=4, type=int, help="number of devices")
parser.add_argument(
"-r", "--resume", default=None, type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
main()
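# ==== model API: the `model` module imported by the demos above and below ====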
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import inspect
import os
import pickle
from collections import OrderedDict
import numpy as np
from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy']
def to_list(value):
if isinstance(value, (list, tuple)):
return value
return [value]
def to_numpy(var):
assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
if isinstance(var, fluid.core.VarBase):
return var.numpy()
t = global_scope().find_var(var.name).get_tensor()
return np.array(t)
def extract_args(func):
if hasattr(inspect, 'getfullargspec'):
return inspect.getfullargspec(func)[0]
else:
return inspect.getargspec(func)[0]
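# Decorator that records expected input shapes on a `forward` method; the
# static graph adapter reads these to build `fluid.data` placeholders whose
# unknown dimensions (e.g. batch size) are left as None.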
def shape_hints(**hints):
    assert hints, "hints cannot be empty"
    assert all(isinstance(h, (list, tuple)) for h in hints.values()), \
        "shape hints must be lists or tuples"
def wrapper(func):
args = extract_args(func)
invalid = set(hints.keys()) - set(args)
        assert not invalid, \
            "shape hints for arguments not present in the forward " \
            "method: ({})".format(", ".join(invalid))
func.shape_hints = hints
return func
return wrapper
class Loss(object):
def __init__(self, average=True):
super(Loss, self).__init__()
self.average = average
def infer_shape(self, outputs):
return [o.shape for o in outputs]
def infer_dtype(self, outputs):
return [o.dtype for o in outputs]
def forward(self, outputs, labels):
raise NotImplementedError()
def __call__(self, outputs, labels):
labels = to_list(labels)
if in_dygraph_mode():
labels = [to_variable(l) for l in labels]
losses = to_list(self.forward(to_list(outputs), labels))
if not self.average:
return losses
return [fluid.layers.reduce_mean(l) for l in losses]
class CrossEntropy(Loss):
def __init__(self):
super(CrossEntropy, self).__init__()
def infer_shape(self, outputs):
return [o.shape[:-1] + (1, ) for o in outputs]
def infer_dtype(self, outputs):
return ['int64' for _ in outputs]
def forward(self, outputs, labels):
return [fluid.layers.cross_entropy(o, l) for o, l in zip(
outputs, labels)]
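# Runs the model declaratively: `forward` is traced once per mode ('train',
# 'eval', 'test') into a cloned Program, inputs are fed as numpy arrays
# through an Executor, and outputs/losses are fetched back as numpy arrays.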
class StaticGraphAdapter(object):
def __init__(self, model):
super(StaticGraphAdapter, self).__init__()
self.model = model
# with `_build_once` gone, parameters are now created in `__init__`
# so we need to keep track of the parameters already created
self._startup_prog = fluid.default_startup_program()
self._orig_prog = fluid.default_main_program()
self._label_vars = {} # label variables
self._endpoints = {}
self._loss_endpoint = None
self._executor = None
self._progs = {}
self._compiled_progs = {}
self._lazy_load_optimizer = None
# parse shape hints
self._input_desc = OrderedDict([
(n, None) for n in extract_args(self.model.forward) if n != 'self'
])
if hasattr(self.model.forward, 'shape_hints'):
self._input_desc.update(self.model.forward.shape_hints)
@property
def mode(self):
return self.model.mode
@mode.setter
def mode(self, value):
self.model.mode = value
def train(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._optimizer and self.model._loss_function, \
"model not ready, please call `model.prepare()` first"
self.mode = 'train'
return self._run(inputs, labels, device, device_ids)
def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_function, \
"model not ready, please call `model.prepare()` first"
self.mode = 'eval'
return self._run(inputs, labels, device, device_ids)
def test(self, inputs, device='CPU', device_ids=None):
self.mode = 'test'
return self._run(inputs, None, device, device_ids)
def parameters(self, *args, **kwargs):
return None
def save(self, path):
def _save(state, path):
if not state:
return
state = {k: to_numpy(v) if isinstance(v, Variable) else v
for k, v in state.items()}
with open(path, 'wb') as f:
pickle.dump(state, f)
base = os.path.basename(path)
assert base != "", "path should be of 'dirname/filename' format"
param_path = path + ".pdparams"
_save(self.model.state_dict(), param_path)
prog = self._progs.get('train', None)
if prog is None or self.model._optimizer is None:
return
        # XXX `optimizer.state_dict()` only works in dygraph mode
optim_path = path + ".pdopt"
optim = {p.name: p for p in filter(
is_belong_to_optimizer, prog.list_vars())}
if not optim:
return
# HACK this is contrived, optimizer state is not the same for
# static/dynamic graph mode
optim['__static_graph_only__'] = True
_save(optim, optim_path)
def load(self, path):
def _load(path):
if not os.path.exists(path):
return
with open(path, 'rb') as f:
return pickle.load(f)
param_path = path + ".pdparams"
param_state = _load(param_path)
assert param_state, "failed to load parameters, please check path"
if self._executor is None:
executor = fluid.Executor(fluid.CPUPlace())._default_executor
else:
executor = self._executor._default_executor
fluid.core._create_loaded_parameter(
list(self.model.state_dict().values()), global_scope(), executor)
for key, var in self.model.state_dict().items():
assert key in param_state, \
"parameter [{}] is not found in model file [{}]".format(
key, param_path)
self._set_var(var, param_state[key])
# FIXME what if a different optimizer is used?
if not self.model._optimizer:
return
optim_path = path + ".pdopt"
optim_state = _load(optim_path)
if optim_state is None:
return
assert '__static_graph_only__' in optim_state, \
"optimizer saved in dygraph mode is not usable in static graph"
if self._executor is not None:
self._load_optimizer(optim_state)
else:
self._lazy_load_optimizer = optim_state
def _load_optimizer(self, state):
prog = self._progs.get('train', None)
optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
if not optim:
return
fluid.core._create_loaded_parameter(
optim, global_scope(), self._executor._default_executor)
for var in optim:
assert var.name in state, \
"variable [{}] is not in optimizer state file".format(var.name)
self._set_var(var, state[var.name])
def _set_var(self, var, ndarray):
t = global_scope().find_var(var.name).get_tensor()
p = t._place()
if p.is_cpu_place():
place = fluid.CPUPlace()
elif p.is_cuda_pinned_place():
place = fluid.CUDAPinnedPlace()
else:
p = fluid.core.Place()
p.set_place(t._place())
place = fluid.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)
def _run(self, inputs, labels=None, device='CPU', device_ids=None):
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
assert len(inputs) == len(self._input_desc), "number of inputs" \
+ " does not match number of arguments of `forward` method"
if self._progs.get(self.mode, None) is None:
self._make_program(self._infer_input_vars(inputs))
compiled_prog = self._compile_and_initialize(
self._progs[self.mode], device, device_ids)
feed = {}
input_names = [name for name in self._input_desc.keys()]
for idx, n in enumerate(input_names):
# train and test may take different arguments
if inputs[idx] is not None:
feed[n] = inputs[idx]
if labels is not None:
for idx, v in enumerate(self._label_vars[self.mode]):
feed[v.name] = labels[idx]
endpoints = self._endpoints[self.mode]
fetch_list = endpoints['output'] + endpoints['loss']
num_output = len(endpoints['output'])
out = self._executor.run(
compiled_prog, feed=feed,
fetch_list=fetch_list)
if self.mode == 'test':
return out[:num_output]
else:
return out[:num_output], out[num_output:]
def _make_program(self, inputs):
prog = self._orig_prog.clone()
if self.mode == 'train' and self.model._optimizer._learning_rate_map:
# HACK workaround learning rate map issue
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
losses = []
with fluid.program_guard(prog, self._startup_prog):
outputs = to_list(self.model.forward(*inputs))
if self.mode != 'test':
label_vars = self._infer_label_vars(outputs)
self._label_vars[self.mode] = label_vars
losses = self.model._loss_function(outputs, label_vars)
if self.mode == 'train':
self._loss_endpoint = fluid.layers.sum(losses)
self.model._optimizer.minimize(self._loss_endpoint)
if self.mode != 'train': # clone again to put it in test mode
prog = prog.clone(for_test=True)
self._progs[self.mode] = prog
self._endpoints[self.mode] = {
"output": outputs,
"loss": losses
}
def _infer_input_vars(self, inputs):
input_vars = []
for idx, i in enumerate(inputs):
if i is None: # train and test may take different arguments
input_vars.append(None)
continue
ndarray = np.array(i)
name = list(self._input_desc.keys())[idx]
shape = list(self._input_desc.values())[idx]
if shape is None:
shape = (None, ) + ndarray.shape[1:]
input_vars.append(fluid.data(name, shape, ndarray.dtype))
return input_vars
def _infer_label_vars(self, outputs):
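        # the loss function knows the label shape/dtype expected for each
        # model output, so label placeholders are derived from it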
shapes = self.model._loss_function.infer_shape(outputs)
dtypes = self.model._loss_function.infer_dtype(outputs)
label_vars = []
for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)):
name = '__label{}'.format(idx)
label_vars.append(fluid.data(name, shape, dtype))
return label_vars
def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
compiled_prog = self._compiled_progs.get(self.mode, None)
if compiled_prog is not None:
return compiled_prog
        places = [fluid.CUDAPlace(i) if device.lower() == 'gpu'
                  else fluid.CPUPlace() for i in device_ids]
# XXX *ALL WEIGHTS* should be initialized upon model construction
# even if `forward()` may run different code path for different mode
# therefore startup program only needs to run once
if self._executor is None:
self._executor = fluid.Executor(places[0])
# XXX incremental initialization
uninitialized = []
for var_py in self._startup_prog.list_vars():
var = fluid.global_scope().find_var(var_py.name)
if var and var.get_tensor()._is_initialized():
continue
uninitialized.append(var_py)
if uninitialized:
startup_prog = self._startup_prog._prune(uninitialized)
self._executor.run(startup_prog)
if self.mode == 'train' and self._lazy_load_optimizer:
self._load_optimizer(self._lazy_load_optimizer)
self._lazy_load_optimizer = None
compiled_prog = fluid.CompiledProgram(prog)
if len(device_ids) > 1:
loss_name = None
if self.mode == 'train' and self._loss_endpoint is not None:
loss_name = self._loss_endpoint.name
share_vars_from = None
if self.mode == 'eval' and 'train' in self._compiled_progs:
share_vars_from = self._compiled_progs['train']
# HACK invalidate eval program if is compiled before train program
# quite hackish, OTOH, it is generally uncommon that the eval
# program will be run before the train program
if self.mode == 'train' and 'eval' in self._compiled_progs:
del self._compiled_progs['eval']
compiled_prog = compiled_prog.with_data_parallel(
loss_name=loss_name, places=places,
share_vars_from=share_vars_from)
self._compiled_progs[self.mode] = compiled_prog
return compiled_prog
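# Runs the model eagerly: each call to train/eval/test does an immediate
# forward pass (and, for train, a backward pass plus one optimizer step).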
class DynamicGraphAdapter(object):
def __init__(self, model):
super(DynamicGraphAdapter, self).__init__()
self.model = model
@property
def mode(self):
return self.model.mode
@mode.setter
def mode(self, value):
self.model.mode = value
    # TODO: multi-device support in dygraph mode is not yet implemented
def train(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._optimizer and self.model._loss_function, \
"model not ready, please call `model.prepare()` first"
super(Model, self.model).train()
self.mode = 'train'
inputs = to_list(inputs)
labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
self.model._optimizer.minimize(final_loss)
self.model.clear_gradients()
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_function, \
"model not ready, please call `model.prepare()` first"
super(Model, self.model).eval()
self.mode = 'eval'
inputs = to_list(inputs)
labels = to_list(labels)
outputs = self.model.forward(*[to_variable(x) for x in inputs])
losses = self.model._loss_function(outputs, labels)
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
def test(self, inputs, device='CPU', device_ids=None):
super(Model, self.model).eval()
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
outputs = self.model.forward(*inputs)
return [to_numpy(o) for o in to_list(outputs)]
def parameters(self, *args, **kwargs):
return super(Model, self.model).parameters(*args, **kwargs)
def save(self, path):
params = self.model.state_dict()
fluid.save_dygraph(params, path)
if self.model._optimizer is None:
return
if self.model._optimizer.state_dict():
optim = self.model._optimizer.state_dict()
fluid.save_dygraph(optim, path)
def load(self, path):
params, optim = fluid.load_dygraph(path)
self.model.set_dict(params)
if self.model._optimizer is None or optim is None:
return
self.model._optimizer.set_dict(optim)
class Model(fluid.dygraph.Layer):
def __init__(self):
super(Model, self).__init__(self.__class__.__name__)
self.mode = 'train'
self._loss_function = None
self._loss_weights = None
self._optimizer = None
if in_dygraph_mode():
self._adapter = DynamicGraphAdapter(self)
else:
self._adapter = StaticGraphAdapter(self)
def train(self, *args, **kwargs):
return self._adapter.train(*args, **kwargs)
def eval(self, *args, **kwargs):
return self._adapter.eval(*args, **kwargs)
def test(self, *args, **kwargs):
return self._adapter.test(*args, **kwargs)
def save(self, *args, **kwargs):
return self._adapter.save(*args, **kwargs)
def load(self, *args, **kwargs):
return self._adapter.load(*args, **kwargs)
def prepare(self, optimizer, loss_function):
self._optimizer = optimizer
        assert isinstance(loss_function, Loss), \
            "'loss_function' must be a subclass of 'Loss'"
self._loss_function = loss_function
def parameters(self, *args, **kwargs):
return self._adapter.parameters(*args, **kwargs)
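# ==== ResNet ImageNet demo; also the `resnet` module reused by YOLOv3 below ====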
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import math
import os
import random
import time
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.container import Sequential
from model import Model, CrossEntropy
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters, act=act)
def forward(self, inputs):
x = self._conv(inputs)
x = self._batch_norm(x)
return x
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
x = self.conv0(inputs)
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
x = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(x)
class ResNet(Model):
def __init__(self, depth=50, num_classes=1000):
super(ResNet, self).__init__()
layer_config = {
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
        assert depth in layer_config.keys(), \
            "supported depths are {} but the given depth is {}".format(
                list(layer_config.keys()), depth)
layers = layer_config[depth]
num_in = [64, 256, 512, 1024]
num_out = [64, 128, 256, 512]
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool = Pool2D(
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
self.layers = []
for idx, num_blocks in enumerate(layers):
blocks = []
shortcut = False
for b in range(num_blocks):
block = BottleneckBlock(
num_channels=num_in[idx] if b == 0 else num_out[idx] * 4,
num_filters=num_out[idx],
stride=2 if b == 0 and idx != 0 else 1,
shortcut=shortcut)
blocks.append(block)
shortcut = True
layer = self.add_sublayer(
"layer_{}".format(idx),
Sequential(*blocks))
self.layers.append(layer)
self.global_pool = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.fc_input_dim = num_out[-1] * 4 * 1 * 1
self.fc = Linear(self.fc_input_dim,
num_classes,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(
-stdv, stdv)))
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
for layer in self.layers:
x = layer(x)
x = self.global_pool(x)
x = fluid.layers.reshape(x, shape=[-1, self.fc_input_dim])
x = self.fc(x)
return x
def make_optimizer(parameter_list=None):
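    # standard ImageNet schedule: 0.1x step decay at epochs 30/60/80 plus
    # 5 epochs of linear warmup, both expressed in iterations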
total_images = 1281167
base_lr = FLAGS.lr
momentum = 0.9
weight_decay = 1e-4
step_per_epoch = int(math.floor(float(total_images) / FLAGS.batch_size))
boundaries = [step_per_epoch * e for e in [30, 60, 80]]
values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries, values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=5 * step_per_epoch,
start_lr=0.,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=momentum,
regularization=fluid.regularizer.L2Decay(weight_decay),
parameter_list=parameter_list)
return optimizer
def accuracy(pred, label, topk=(1, )):
maxk = max(topk)
pred = np.argsort(pred)[:, ::-1][:, :maxk]
correct = (pred == np.repeat(label, maxk, 1))
batch_size = label.shape[0]
res = []
for k in topk:
correct_k = correct[:, :k].sum()
res.append(100.0 * correct_k / batch_size)
return res
def center_crop_resize(img):
    h, w = img.shape[:2]
    c = int(224 / 256 * min(h, w))
    i = (h + 1 - c) // 2
    j = (w + 1 - c) // 2
    img = img[i: i + c, j: j + c, :]
    return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
def random_crop_resize(img):
height, width = img.shape[:2]
area = height * width
for attempt in range(10):
target_area = random.uniform(0.08, 1.) * area
log_ratio = (math.log(3 / 4), math.log(4 / 3))
aspect_ratio = math.exp(random.uniform(*log_ratio))
w = int(round(math.sqrt(target_area * aspect_ratio)))
h = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= width and h <= height:
i = random.randint(0, height - h)
j = random.randint(0, width - w)
img = img[i: i + h, j: j + w, :]
            return cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
return center_crop_resize(img)
def random_flip(img):
    # flip horizontally with probability 0.5
    if random.random() < 0.5:
        img = img[:, ::-1, :]
    return img
def normalize_permute(img):
    # HWC BGR uint8 -> CHW RGB float32, then normalize per channel
    img = img.astype(np.float32).transpose((2, 0, 1))[::-1, ...]
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
    invstd = 1. / std
    img -= mean[:, np.newaxis, np.newaxis]
    img *= invstd[:, np.newaxis, np.newaxis]
    return img
def compose(functions):
def process(sample):
img, label = sample
for fn in functions:
img = fn(img)
return img, label
return process
def image_folder(path, shuffle=False):
valid_ext = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.webp')
classes = [d for d in os.listdir(path) if
os.path.isdir(os.path.join(path, d))]
classes.sort()
class_map = {cls: idx for idx, cls in enumerate(classes)}
samples = []
    for cls_dir in sorted(class_map.keys()):
        d = os.path.join(path, cls_dir)
        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                p = os.path.join(root, fname)
                if os.path.splitext(p)[1].lower() in valid_ext:
                    samples.append((p, class_map[cls_dir]))
def iterator():
if shuffle:
random.shuffle(samples)
for s in samples:
yield s
return iterator
def run(model, loader, mode='train'):
total_loss = 0.
total_acc1 = 0.
total_acc5 = 0.
    total_time = 0.
    device_ids = list(range(FLAGS.num_devices))
    start = time.time()
for idx, batch in enumerate(loader()):
outputs, losses = getattr(model, mode)(
batch[0], batch[1], device='gpu', device_ids=device_ids)
top1, top5 = accuracy(outputs[0], batch[1], topk=(1, 5))
total_loss += np.sum(losses)
total_acc1 += top1
total_acc5 += top5
if idx > 1: # skip first two steps
total_time += time.time() - start
if idx % 10 == 0:
print(("{:04d} loss: {:0.3f} top1: {:0.3f}% top5: {:0.3f}% "
"time: {:0.3f}").format(
idx, total_loss / (idx + 1), total_acc1 / (idx + 1),
total_acc5 / (idx + 1), total_time / max(1, (idx - 1))))
start = time.time()
def main():
@contextlib.contextmanager
def null_guard():
yield
epoch = FLAGS.epoch
batch_size = FLAGS.batch_size
guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
train_dir = os.path.join(FLAGS.data, 'train')
val_dir = os.path.join(FLAGS.data, 'val')
train_loader = fluid.io.xmap_readers(
lambda batch: (np.array([b[0] for b in batch]),
np.array([b[1] for b in batch]).reshape(-1, 1)),
paddle.batch(
fluid.io.xmap_readers(
compose([cv2.imread, random_crop_resize, random_flip,
normalize_permute]),
image_folder(train_dir, shuffle=True),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size,
drop_last=True),
process_num=2, buffer_size=4)
val_loader = fluid.io.xmap_readers(
lambda batch: (np.array([b[0] for b in batch]),
np.array([b[1] for b in batch]).reshape(-1, 1)),
paddle.batch(
fluid.io.xmap_readers(
compose([cv2.imread, center_crop_resize, normalize_permute]),
image_folder(val_dir),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size),
process_num=2, buffer_size=4)
if not os.path.exists('resnet_checkpoints'):
os.mkdir('resnet_checkpoints')
with guard:
model = ResNet()
optim = make_optimizer(parameter_list=model.parameters())
model.prepare(optim, CrossEntropy())
if FLAGS.resume is not None:
model.load(FLAGS.resume)
for e in range(epoch):
print("======== train epoch {} ========".format(e))
run(model, train_loader)
model.save('resnet_checkpoints/{:02d}'.format(e))
print("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval')
if __name__ == '__main__':
parser = argparse.ArgumentParser("Resnet Training on ImageNet")
parser.add_argument('data', metavar='DIR', help='path to dataset '
'(should have subdirectories named "train" and "val"')
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
    parser.add_argument(
        "-e", "--epoch", default=90, type=int, help="number of epochs")
parser.add_argument(
'--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=256, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=4, type=int, help="number of devices")
parser.add_argument(
"-r", "--resume", default=None, type=str,
help="checkpoint path to resume")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()
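# ==== YOLOv3 COCO demo, built on the model API and the ResNet backbone above ====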
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import os
import random
import time
from functools import partial
import cv2
import numpy as np
from pycocotools.coco import COCO
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from model import Model, Loss, shape_hints
from resnet import ResNet, ConvBNLayer
# XXX transfer learning
class ResNetBackBone(ResNet):
def __init__(self, depth=50):
super(ResNetBackBone, self).__init__(depth=depth)
delattr(self, 'fc')
def forward(self, inputs):
x = self.conv(inputs)
x = self.pool(x)
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
class YoloDetectionBlock(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super(YoloDetectionBlock, self).__init__()
        assert num_filters % 2 == 0, \
            "num_filters {} is not divisible by 2".format(num_filters)
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='leaky_relu')
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 2,
filter_size=3,
act='leaky_relu')
self.conv2 = ConvBNLayer(
num_channels=num_filters * 2,
num_filters=num_filters,
filter_size=1,
act='leaky_relu')
self.conv3 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 2,
filter_size=3,
act='leaky_relu')
self.route = ConvBNLayer(
num_channels=num_filters * 2,
num_filters=num_filters,
filter_size=1,
act='leaky_relu')
self.tip = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 2,
filter_size=3,
act='leaky_relu')
def forward(self, inputs):
out = self.conv0(inputs)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
route = self.route(out)
tip = self.tip(route)
return route, tip
class YOLOv3(Model):
def __init__(self):
super(YOLOv3, self).__init__()
self.num_classes = 80
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        self.valid_thresh = 0.005
        self.nms_thresh = 0.45  # referenced in forward(); the common YOLOv3 default
        self.nms_topk = 400
        self.nms_posk = 100
        self.draw_thresh = 0.5
self.backbone = ResNetBackBone()
self.block_outputs = []
self.yolo_blocks = []
self.route_blocks = []
for idx, num_chan in enumerate([2048, 1280, 640]):
            yolo_block = self.add_sublayer(
                "detection_block_{}".format(idx),
                YoloDetectionBlock(num_chan, num_filters=512 // (2**idx)))
self.yolo_blocks.append(yolo_block)
num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
block_out = self.add_sublayer(
"block_out_{}".format(idx),
Conv2D(num_channels=1024 // (2**idx),
num_filters=num_filters,
filter_size=1,
param_attr=ParamAttr(
initializer=fluid.initializer.Normal(0., 0.02)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0),
regularizer=L2Decay(0.))))
self.block_outputs.append(block_out)
if idx < 2:
route = self.add_sublayer(
"route_{}".format(idx),
ConvBNLayer(num_channels=512 // (2**idx),
num_filters=256 // (2**idx),
filter_size=1,
act='leaky_relu'))
self.route_blocks.append(route)
@shape_hints(inputs=[None, 3, None, None], im_shape=[None, 2])
def forward(self, inputs, im_shape):
outputs = []
boxes = []
scores = []
downsample = 32
feats = self.backbone(inputs)
feats = feats[::-1][:len(self.anchor_masks)]
route = None
for idx, feat in enumerate(feats):
if idx > 0:
feat = fluid.layers.concat(input=[route, feat], axis=1)
route, tip = self.yolo_blocks[idx](feat)
block_out = self.block_outputs[idx](tip)
if idx < 2:
route = self.route_blocks[idx](route)
route = fluid.layers.resize_nearest(route, scale=2)
anchor_mask = self.anchor_masks[idx]
mask_anchors = []
for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1])
b, s = fluid.layers.yolo_box(
x=block_out,
img_size=im_shape,
anchors=mask_anchors,
class_num=self.num_classes,
conf_thresh=self.valid_thresh,
downsample_ratio=downsample)
outputs.append(block_out)
boxes.append(b)
scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
downsample //= 2
if self.mode != 'test':
return outputs
return fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2),
score_threshold=self.valid_thresh,
nms_top_k=self.nms_topk,
keep_top_k=self.nms_posk,
nms_threshold=self.nms_thresh,
background_label=-1)
class YoloLoss(Loss):
def __init__(self, num_classes=80, num_max_boxes=50):
super(YoloLoss, self).__init__()
self.num_classes = num_classes
self.num_max_boxes = num_max_boxes
self.ignore_thresh = 0.7
self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
59, 119, 116, 90, 156, 198, 373, 326]
self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
def forward(self, outputs, labels):
downsample = 32
gt_box, gt_label, gt_score = labels
losses = []
for idx, out in enumerate(outputs):
anchor_mask = self.anchor_masks[idx]
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchor_mask=anchor_mask,
downsample_ratio=downsample,
anchors=self.anchors,
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
use_label_smooth=True)
losses.append(loss)
downsample //= 2
return losses
def infer_shape(self, _):
return [
[None, self.num_max_boxes, 4],
[None, self.num_max_boxes],
[None, self.num_max_boxes]
]
def infer_dtype(self, _):
return ['float32', 'int32', 'float32']
def make_optimizer(parameter_list=None):
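    # iteration-based schedule: 0.1x decay at 400k/450k steps plus 4k steps
    # of linear warmup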
base_lr = FLAGS.lr
warm_up_iter = 4000
momentum = 0.9
weight_decay = 5e-4
boundaries = [400000, 450000]
values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)]
learning_rate = fluid.layers.piecewise_decay(
boundaries=boundaries,
values=values)
learning_rate = fluid.layers.linear_lr_warmup(
learning_rate=learning_rate,
warmup_steps=warm_up_iter,
start_lr=0.0,
end_lr=base_lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(weight_decay),
momentum=momentum,
parameter_list=parameter_list)
return optimizer
def _iou_matrix(a, b):
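    # pairwise IoU by broadcasting: `a` is (N, 4) and `b` is (M, 4) in
    # (x1, y1, x2, y2) format; returns an (N, M) matrix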
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
area_o = (area_a[:, np.newaxis] + area_b - area_i)
return area_i / (area_o + 1e-10)
def _crop_box_with_center_constraint(box, crop):
cropped_box = box.copy()
cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
cropped_box[:, :2] -= crop[:2]
cropped_box[:, 2:] -= crop[:2]
centers = (box[:, :2] + box[:, 2:]) / 2
valid = np.logical_and(
crop[:2] <= centers, centers < crop[2:]).all(axis=1)
valid = np.logical_and(
valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
return cropped_box, np.where(valid)[0]
def random_crop(inputs):
aspect_ratios = [.5, 2.]
thresholds = [.0, .1, .3, .5, .7, .9]
scaling = [.3, 1.]
img, gt_box, gt_label = inputs
h, w = img.shape[:2]
if len(gt_box) == 0:
return inputs
np.random.shuffle(thresholds)
for thresh in thresholds:
found = False
for i in range(50):
scale = np.random.uniform(*scaling)
min_ar, max_ar = aspect_ratios
ar = np.random.uniform(max(min_ar, scale**2),
min(max_ar, scale**-2))
crop_h = int(h * scale / np.sqrt(ar))
crop_w = int(w * scale * np.sqrt(ar))
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32))
if iou.max() < thresh:
continue
cropped_box, valid_ids = _crop_box_with_center_constraint(
gt_box, np.array(crop_box, dtype=np.float32))
if valid_ids.size > 0:
found = True
break
if found:
x1, y1, x2, y2 = crop_box
img = img[y1:y2, x1:x2, :]
gt_box = np.take(cropped_box, valid_ids, axis=0)
gt_label = np.take(gt_label, valid_ids, axis=0)
return img, gt_box, gt_label
return inputs
# XXX mix-up, color distortion and random expansion are skipped for simplicity
def sample_transform(inputs, mode='train', num_max_boxes=50):
if mode == 'train':
img, gt_box, gt_label = random_crop(inputs)
else:
img, gt_box, gt_label = inputs
h, w = img.shape[:2]
# random flip
if mode == 'train' and np.random.uniform(0., 1.) > .5:
img = img[:, ::-1, :]
if len(gt_box) > 0:
swap = gt_box.copy()
gt_box[:, 0] = w - swap[:, 2] - 1
gt_box[:, 2] = w - swap[:, 0] - 1
    if len(gt_label) == 0:
        # pad with zeros; shapes must match the non-empty branch below
        gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32)
        gt_label = np.zeros([num_max_boxes], dtype=np.int32)
        return img, gt_box, gt_label
gt_box = gt_box[:num_max_boxes, :]
gt_label = gt_label[:num_max_boxes, 0]
    # normalize boxes to [0, 1] and convert corners (x1, y1, x2, y2)
    # to center-size (cx, cy, w, h)
gt_box /= np.array([w, h] * 2, dtype=np.float32)
gt_box[:, 2:] = gt_box[:, 2:] - gt_box[:, :2]
gt_box[:, :2] = gt_box[:, :2] + gt_box[:, 2:] / 2.
pad = num_max_boxes - gt_label.size
gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant')
gt_label = np.pad(gt_label, [(0, pad)], mode='constant')
return img, gt_box, gt_label
def batch_transform(batch, mode='train'):
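    # multi-scale training: each train batch is resized to a randomly chosen
    # size (and interpolation); evaluation always uses 608x608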
if mode == 'train':
d = np.random.choice(
[320, 352, 384, 416, 448, 480, 512, 544, 576, 608])
interp = np.random.choice(range(5))
else:
d = 608
interp = cv2.INTER_CUBIC
# transpose batch
imgs, gt_boxes, gt_labels = list(zip(*batch))
imgs = np.array([cv2.resize(
img, (d, d), interpolation=interp) for img in imgs])
# transpose, permute and normalize
imgs = imgs.astype(np.float32)[..., ::-1]
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
std = np.array([58.395, 57.120, 57.375], dtype=np.float32)
invstd = 1. / std
imgs -= mean
imgs *= invstd
imgs = imgs.transpose((0, 3, 1, 2))
im_shapes = np.full([len(imgs), 2], d, dtype=np.int32)
gt_boxes = np.array(gt_boxes)
gt_labels = np.array(gt_labels)
# XXX since mix up is not used, scores are all ones
gt_scores = np.ones_like(gt_labels, dtype=np.float32)
return [imgs, im_shapes], [gt_boxes, gt_labels, gt_scores]
def coco2017(root_dir, mode='train'):
json_path = os.path.join(
root_dir, 'annotations/instances_{}2017.json'.format(mode))
coco = COCO(json_path)
img_ids = coco.getImgIds()
imgs = coco.loadImgs(img_ids)
class_map = {v: i + 1 for i, v in enumerate(coco.getCatIds())}
samples = []
for img in imgs:
img_path = os.path.join(
root_dir, '{}2017'.format(mode), img['file_name'])
file_path = img_path
width = img['width']
height = img['height']
ann_ids = coco.getAnnIds(imgIds=img['id'], iscrowd=False)
anns = coco.loadAnns(ann_ids)
gt_box = []
gt_label = []
for ann in anns:
x1, y1, w, h = ann['bbox']
x2 = x1 + w - 1
y2 = y1 + h - 1
x1 = np.clip(x1, 0, width - 1)
x2 = np.clip(x2, 0, width - 1)
y1 = np.clip(y1, 0, height - 1)
y2 = np.clip(y2, 0, height - 1)
if ann['area'] <= 0 or x2 < x1 or y2 < y1:
continue
gt_label.append(ann['category_id'])
gt_box.append([x1, y1, x2, y2])
gt_box = np.array(gt_box, dtype=np.float32)
gt_label = np.array([class_map[cls] for cls in gt_label],
dtype=np.int32)[:, np.newaxis]
        if gt_label.size == 0 and mode != 'train':
            continue
samples.append((file_path, gt_box.copy(), gt_label.copy()))
def iterator():
if mode == 'train':
random.shuffle(samples)
for file_path, gt_box, gt_label in samples:
img = cv2.imread(file_path)
yield img, gt_box, gt_label
return iterator
# XXX coco metrics not included for simplicity
def run(model, loader, mode='train'):
total_loss = 0.
total_time = 0.
device_ids = list(range(FLAGS.num_devices))
start = time.time()
for idx, batch in enumerate(loader()):
outputs, losses = getattr(model, mode)(
batch[0], batch[1], device='gpu', device_ids=device_ids)
total_loss += np.sum(losses)
if idx > 1: # skip first two steps
total_time += time.time() - start
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} time: {:0.3f}".format(
idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
start = time.time()
def main():
@contextlib.contextmanager
def null_guard():
yield
epoch = FLAGS.epoch
batch_size = FLAGS.batch_size
guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
train_loader = fluid.io.xmap_readers(
batch_transform,
paddle.batch(
fluid.io.xmap_readers(
sample_transform,
coco2017(FLAGS.data, 'train'),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size,
drop_last=True),
process_num=2, buffer_size=4)
val_sample_transform = partial(sample_transform, mode='val')
val_batch_transform = partial(batch_transform, mode='val')
val_loader = fluid.io.xmap_readers(
val_batch_transform,
paddle.batch(
fluid.io.xmap_readers(
val_sample_transform,
coco2017(FLAGS.data, 'val'),
process_num=8,
buffer_size=4 * batch_size),
batch_size=batch_size),
process_num=2, buffer_size=4)
if not os.path.exists('yolo_checkpoints'):
os.mkdir('yolo_checkpoints')
with guard:
model = YOLOv3()
# XXX transfer learning
if FLAGS.weights is not None:
model.backbone.load(FLAGS.weights)
optim = make_optimizer(parameter_list=model.parameters())
model.prepare(optim, YoloLoss())
for e in range(epoch):
print("======== train epoch {} ========".format(e))
run(model, train_loader)
model.save('yolo_checkpoints/{:02d}'.format(e))
print("======== eval epoch {} ========".format(e))
run(model, val_loader, mode='eval')
if __name__ == '__main__':
parser = argparse.ArgumentParser("Yolov3 Training on COCO")
parser.add_argument('data', metavar='DIR', help='path to COCO dataset')
parser.add_argument(
"-d", "--dynamic", action='store_true', help="enable dygraph mode")
    parser.add_argument(
        "-e", "--epoch", default=300, type=int, help="number of epochs")
parser.add_argument(
'--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
help='initial learning rate')
parser.add_argument(
"-b", "--batch_size", default=64, type=int, help="batch size")
parser.add_argument(
"-n", "--num_devices", default=8, type=int, help="number of devices")
parser.add_argument(
"-w", "--weights", default=None, type=str,
help="path to pretrained weights")
FLAGS = parser.parse_args()
assert FLAGS.data, "error: must provide data path"
main()