未验证 提交 cefbb796 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #746 from qingqing01/ssd

Parallel training for MobileNet-SSD.
import paddle.v2 as paddle
import paddle.fluid as fluid
import os
import reader
import numpy as np
import load_model as load_model
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
import os
import numpy as np
import argparse
import functools
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
# yapf: disable
def train(train_file_list,
def train(args,
train_file_list,
val_file_list,
data_args,
learning_rate,
......@@ -25,16 +36,37 @@ def train(train_file_list,
difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
mbox_locs, mbox_confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(
mbox_locs, mbox_confs, box, box_var, nms_threshold=0.45)
loss_vec = fluid.layers.ssd_loss(mbox_locs, mbox_confs, gt_box, gt_label,
if args.parallel:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
image_ = pd.read_input(image)
gt_box_ = pd.read_input(gt_box)
gt_label_ = pd.read_input(gt_label)
difficult_ = pd.read_input(difficult)
locs, confs, box, box_var = mobile_net(image_, image_shape)
loss = fluid.layers.ssd_loss(locs, confs, gt_box_, gt_label_,
box, box_var)
pd.write_output(loss)
pd.write_output(locs)
pd.write_output(confs)
pd.write_output(box)
pd.write_output(box_var)
loss, locs, confs, box, box_var = pd()
loss = fluid.layers.reduce_sum(loss)
else:
locs, confs, box, box_var = mobile_net(image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, mbox_confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, mbox_confs, gt_box, gt_label,
box, box_var)
loss = fluid.layers.nn.reduce_sum(loss_vec)
loss = fluid.layers.reduce_sum(loss)
map_eval = None
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
map_eval = fluid.evaluator.DetectionMAP(
nmsed_out,
gt_label,
......@@ -56,7 +88,7 @@ def train(train_file_list,
optimizer.minimize(loss)
place = fluid.CUDAPlace(0)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
......@@ -97,16 +129,18 @@ def train(train_file_list,
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_args = reader.Settings(
data_dir='./data',
label_file='label_list',
resize_h=300,
resize_w=300,
mean_value=[127.5, 127.5, 127.5])
train(
train_file_list='./data/trainval.txt',
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
batch_size=32,
num_passes=300)
train(args,
train_file_list='./data/trainval.txt',
val_file_list='./data/test.txt',
data_args=data_args,
learning_rate=0.001,
batch_size=32,
num_passes=300)
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import numpy as np
from paddle.fluid import core
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册