提交 686616ff 编写于 作者: D dengkaipeng

use Executor and CompiledProgram.

上级 fd16377f
...@@ -32,7 +32,7 @@ YOLOv3 structure ...@@ -32,7 +32,7 @@ YOLOv3 structure
YOLOv3 networks are composed of base feature extraction network, multi-scale feature fusion layers, and output layers. YOLOv3 networks are composed of base feature extraction network, multi-scale feature fusion layers, and output layers.
1. Feature extraction network: YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extracion,Darknet53 uses a full convolution structure, replacing the pooling layer with a convolution operation of step size 2, and adding Residual-block to avoid gradient dispersion when the number of network layers is too deep. 1. Feature extraction network: YOLOv3 uses [DarkNet53](https://arxiv.org/abs/1612.08242) for feature extracion. Darknet53 uses a full convolution structure, replacing the pooling layer with a convolution operation of step size 2, and adding Residual-block to avoid gradient dispersion when the number of network layers is too deep.
2. Feature fusion layer. In order to solve the problem that the previous YOLO version is not sensitive to small objects, YOLOv3 uses three different scale feature maps for target detection, which are 13\*13, 26\*26, 52\*52, respectively, for detecting large, medium and small objects. The feature fusion layer selects the three scale feature maps produced by DarkNet as input, and draws on the idea of FPN (feature pyramid networks) to fuse the feature maps of each scale through a series of convolutional layers and upsampling. 2. Feature fusion layer. In order to solve the problem that the previous YOLO version is not sensitive to small objects, YOLOv3 uses three different scale feature maps for target detection, which are 13\*13, 26\*26, 52\*52, respectively, for detecting large, medium and small objects. The feature fusion layer selects the three scale feature maps produced by DarkNet as input, and draws on the idea of FPN (feature pyramid networks) to fuse the feature maps of each scale through a series of convolutional layers and upsampling.
...@@ -71,9 +71,8 @@ Please make sure that pre-trained model is downloaded and loaded correctly, othe ...@@ -71,9 +71,8 @@ Please make sure that pre-trained model is downloaded and loaded correctly, othe
To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi: To train the model, [cocoapi](https://github.com/cocodataset/cocoapi) is needed. Install the cocoapi:
# COCOAPI=/path/to/clone/cocoapi git clone https://github.com/cocodataset/cocoapi.git
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI cd PythonAPI
cd $COCOAPI/PythonAPI
# if cython is not installed # if cython is not installed
pip install Cython pip install Cython
# Install into global site-packages # Install into global site-packages
......
...@@ -33,7 +33,7 @@ YOLOv3网络结构 ...@@ -33,7 +33,7 @@ YOLOv3网络结构
YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层和输出层组成。 YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层和输出层组成。
1. 特征提取网络。YOLOv3使用 [DarkNet53](https://arxiv.org/abs/1612.08242)作为特征提取网络DarkNet53 基本采用了全卷积网络,用步长为2的卷积操作替代了池化层,同时添加了 Residual 单元,避免在网络层数过深时发生梯度弥散。 1. 特征提取网络。YOLOv3使用 [DarkNet53](https://arxiv.org/abs/1612.08242)作为特征提取网络DarkNet53 基本采用了全卷积网络,用步长为2的卷积操作替代了池化层,同时添加了 Residual 单元,避免在网络层数过深时发生梯度弥散。
2. 特征融合层。为了解决之前YOLO版本对小目标不敏感的问题,YOLOv3采用了3个不同尺度的特征图来进行目标检测,分别为13\*13,26\*26,52\*52,用来检测大、中、小三种目标。特征融合层选取 DarkNet 产出的三种尺度特征图作为输入,借鉴了FPN(feature pyramid networks)的思想,通过一系列的卷积层和上采样对各尺度的特征图进行融合。 2. 特征融合层。为了解决之前YOLO版本对小目标不敏感的问题,YOLOv3采用了3个不同尺度的特征图来进行目标检测,分别为13\*13,26\*26,52\*52,用来检测大、中、小三种目标。特征融合层选取 DarkNet 产出的三种尺度特征图作为输入,借鉴了FPN(feature pyramid networks)的思想,通过一系列的卷积层和上采样对各尺度的特征图进行融合。
...@@ -73,9 +73,8 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层 ...@@ -73,9 +73,8 @@ YOLOv3 的网络结构由基础特征提取网络、multi-scale特征融合层
训练前需要首先下载[cocoapi](https://github.com/cocodataset/cocoapi) 训练前需要首先下载[cocoapi](https://github.com/cocodataset/cocoapi)
# COCOAPI=/path/to/clone/cocoapi git clone https://github.com/cocodataset/cocoapi.git
git clone https://github.com/cocodataset/cocoapi.git $COCOAPI cd PythonAPI
cd $COCOAPI/PythonAPI
# if cython is not installed # if cython is not installed
pip install Cython pip install Cython
# Install into global site-packages # Install into global site-packages
......
...@@ -39,7 +39,7 @@ _C.max_box_num = 50 ...@@ -39,7 +39,7 @@ _C.max_box_num = 50
# #
# valid score threshold to include boxes # valid score threshold to include boxes
_C.valid_thresh = 0.01 _C.valid_thresh = 0.005
# threshold vale for box non-max suppression # threshold vale for box non-max suppression
_C.nms_thresh = 0.45 _C.nms_thresh = 0.45
...@@ -69,7 +69,7 @@ _C.pixel_stds = [0.229, 0.224, 0.225] ...@@ -69,7 +69,7 @@ _C.pixel_stds = [0.229, 0.224, 0.225]
# #
# batch size # batch size
_C.batch_size = 64 _C.batch_size = 8
# derived learning rate the to get the final learning rate. # derived learning rate the to get the final learning rate.
_C.learning_rate = 0.001 _C.learning_rate = 0.001
...@@ -98,15 +98,9 @@ _C.momentum = 0.9 ...@@ -98,15 +98,9 @@ _C.momentum = 0.9
# support both CPU and GPU # support both CPU and GPU
_C.use_gpu = True _C.use_gpu = True
# Whether use parallel
_C.parallel = True
# Class number # Class number
_C.class_num = 80 _C.class_num = 80
# support pyreader
_C.use_pyreader = True
# dataset path # dataset path
_C.train_file_list = 'annotations/instances_train2017.json' _C.train_file_list = 'annotations/instances_train2017.json'
_C.train_data_dir = 'train2017' _C.train_data_dir = 'train2017'
......
...@@ -56,10 +56,8 @@ def upsample(input, scale=2,name=None): ...@@ -56,10 +56,8 @@ def upsample(input, scale=2,name=None):
class YOLOv3(object): class YOLOv3(object):
def __init__(self, def __init__(self,
is_train=True, is_train=True,
use_pyreader=True,
use_random=True): use_random=True):
self.is_train = is_train self.is_train = is_train
self.use_pyreader = use_pyreader
self.use_random = use_random self.use_random = use_random
self.outputs = [] self.outputs = []
self.losses = [] self.losses = []
...@@ -69,7 +67,7 @@ class YOLOv3(object): ...@@ -69,7 +67,7 @@ class YOLOv3(object):
def build_input(self): def build_input(self):
self.image_shape = [3, cfg.input_size, cfg.input_size] self.image_shape = [3, cfg.input_size, cfg.input_size]
if self.use_pyreader and self.is_train: if self.is_train:
self.py_reader = fluid.layers.py_reader( self.py_reader = fluid.layers.py_reader(
capacity=64, capacity=64,
shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]], shapes = [[-1] + self.image_shape, [-1, cfg.max_box_num, 4], [-1, cfg.max_box_num], [-1, cfg.max_box_num]],
...@@ -81,15 +79,6 @@ class YOLOv3(object): ...@@ -81,15 +79,6 @@ class YOLOv3(object):
self.image = fluid.layers.data( self.image = fluid.layers.data(
name='image', shape=self.image_shape, dtype='float32' name='image', shape=self.image_shape, dtype='float32'
) )
self.gtbox = fluid.layers.data(
name='gtbox', shape=[cfg.max_box_num, 4], dtype='float32'
)
self.gtlabel = fluid.layers.data(
name='gtlabel', shape=[cfg.max_box_num], dtype='int32'
)
self.gtscore = fluid.layers.data(
name='gtscore', shape=[cfg.max_box_num], dtype='float32'
)
self.im_shape = fluid.layers.data( self.im_shape = fluid.layers.data(
name="im_shape", shape=[2], dtype='int32') name="im_shape", shape=[2], dtype='int32')
self.im_id = fluid.layers.data( self.im_id = fluid.layers.data(
......
...@@ -42,7 +42,7 @@ def train(): ...@@ -42,7 +42,7 @@ def train():
if not os.path.exists(cfg.model_save_dir): if not os.path.exists(cfg.model_save_dir):
os.makedirs(cfg.model_save_dir) os.makedirs(cfg.model_save_dir)
model = YOLOv3(use_pyreader=cfg.use_pyreader) model = YOLOv3()
model.build_model() model.build_model()
input_size = cfg.input_size input_size = cfg.input_size
loss = model.loss() loss = model.loss()
...@@ -69,44 +69,37 @@ def train(): ...@@ -69,44 +69,37 @@ def train():
momentum=cfg.momentum) momentum=cfg.momentum)
optimizer.minimize(loss) optimizer.minimize(loss)
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
base_exe = fluid.Executor(place) exe = fluid.Executor(place)
base_exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
if cfg.pretrain: if cfg.pretrain:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(cfg.pretrain, var.name)) return os.path.exists(os.path.join(cfg.pretrain, var.name))
fluid.io.load_vars(base_exe, cfg.pretrain, predicate=if_exist) fluid.io.load_vars(exe, cfg.pretrain, predicate=if_exist)
if cfg.parallel: compile_program = fluid.compiler.CompiledProgram(
exe = fluid.ParallelExecutor( use_cuda=bool(cfg.use_gpu), loss_name=loss.name) fluid.default_main_program()).with_data_parallel(
else: loss_name=loss.name)
exe = base_exe
random_sizes = [cfg.input_size] random_sizes = [cfg.input_size]
if cfg.random_shape: if cfg.random_shape:
random_sizes = [32 * i for i in range(10, 20)] random_sizes = [32 * i for i in range(10, 20)]
mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter mixup_iter = cfg.max_iter - cfg.start_iter - cfg.no_mixup_iter
if cfg.use_pyreader: train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
train_reader = reader.train(input_size, batch_size=cfg.batch_size/devices_num, shuffle=True, mixup_iter=mixup_iter*devices_num, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess) py_reader = model.py_reader
py_reader = model.py_reader py_reader.decorate_paddle_reader(train_reader)
py_reader.decorate_paddle_reader(train_reader)
else:
train_reader = reader.train(input_size, batch_size=cfg.batch_size, shuffle=True, mixup_iter=mixup_iter, random_sizes=random_sizes, use_multiprocessing=cfg.use_multiprocess)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
def save_model(postfix): def save_model(postfix):
model_path = os.path.join(cfg.model_save_dir, postfix) model_path = os.path.join(cfg.model_save_dir, postfix)
if os.path.isdir(model_path): if os.path.isdir(model_path):
shutil.rmtree(model_path) shutil.rmtree(model_path)
fluid.io.save_persistables(base_exe, model_path) fluid.io.save_persistables(exe, model_path)
fetch_list = [loss] fetch_list = [loss]
def train_loop_pyreader(): def train_loop():
py_reader.start() py_reader.start()
smoothed_loss = SmoothedValue() smoothed_loss = SmoothedValue()
try: try:
...@@ -137,43 +130,7 @@ def train(): ...@@ -137,43 +130,7 @@ def train():
except fluid.core.EOFException: except fluid.core.EOFException:
py_reader.reset() py_reader.reset()
def train_loop(): train_loop()
start_time = time.time()
prev_start_time = start_time
start = start_time
smoothed_loss = SmoothedValue()
snapshot_loss = 0
snapshot_time = 0
for iter_id, data in enumerate(train_reader()):
iter_id += cfg.start_iter
prev_start_time = start_time
start_time = time.time()
losses = exe.run(fetch_list=[v.name for v in fetch_list],
feed=feeder.feed(data))
smoothed_loss.add_value(losses[0])
snapshot_loss += losses[0]
snapshot_time += start_time - prev_start_time
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr: {:.6f}, loss: {:.4f}, time {:.5f}".format(
iter_id, lr[0], smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
print("Snapshot {} saved, average loss: {}, average time: {}".format(
iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
snapshot_time / float(cfg.snapshot_iter)))
snapshot_loss = 0
snapshot_time = 0
if (iter_id + 1) == cfg.max_iter:
print("Finish iter {}".format(iter_id))
break
if cfg.use_pyreader:
train_loop_pyreader()
else:
train_loop()
save_model('model_final') save_model('model_final')
......
...@@ -94,7 +94,6 @@ def parse_args(): ...@@ -94,7 +94,6 @@ def parse_args():
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
# ENV # ENV
add_arg('parallel', bool, True, "Whether use parallel.")
add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('model_save_dir', str, 'checkpoints', "The path to save model.") add_arg('model_save_dir', str, 'checkpoints', "The path to save model.")
add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.") add_arg('pretrain', str, 'weights/darknet53', "The pretrain model path.")
...@@ -102,12 +101,10 @@ def parse_args(): ...@@ -102,12 +101,10 @@ def parse_args():
add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.") add_arg('dataset', str, 'coco2017', "Dataset: coco2014, coco2017.")
add_arg('class_num', int, 80, "Class number.") add_arg('class_num', int, 80, "Class number.")
add_arg('data_dir', str, 'dataset/coco', "The data root path.") add_arg('data_dir', str, 'dataset/coco', "The data root path.")
add_arg('use_pyreader', bool, True, "Use pyreader.")
add_arg('use_profile', bool, False, "Whether use profiler.")
add_arg('start_iter', int, 0, "Start iteration.") add_arg('start_iter', int, 0, "Start iteration.")
add_arg('use_multiprocess', bool, True, "add multiprocess.") add_arg('use_multiprocess', bool, True, "add multiprocess.")
#SOLVER #SOLVER
add_arg('batch_size', int, 64, "Learning rate.") add_arg('batch_size', int, 8, "Mini-batch size per device.")
add_arg('learning_rate', float, 0.001, "Learning rate.") add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Iter number.") add_arg('max_iter', int, 500200, "Iter number.")
add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.") add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.")
...@@ -122,8 +119,8 @@ def parse_args(): ...@@ -122,8 +119,8 @@ def parse_args():
add_arg('nms_posk', int, 100, "The number of boxes of NMS output.") add_arg('nms_posk', int, 100, "The number of boxes of NMS output.")
add_arg('debug', bool, False, "Debug mode") add_arg('debug', bool, False, "Debug mode")
# SINGLE EVAL AND DRAW # SINGLE EVAL AND DRAW
add_arg('image_path', str, 'image', "The image path used to inference and visualize.") add_arg('image_path', str, 'image', "The image path used to inference and visualize.")
add_arg('image_name', str, None, "The single image used to inference and visualize. None to inference all images in image_path") add_arg('image_name', str, None, "The single image used to inference and visualize. None to inference all images in image_path")
add_arg('draw_thresh', float, 0.5, "Confidence score threshold to draw prediction box in image in debug mode") add_arg('draw_thresh', float, 0.5, "Confidence score threshold to draw prediction box in image in debug mode")
# yapf: enable # yapf: enable
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册