未验证 提交 635ca16d 编写于 作者: W whs 提交者: GitHub

Merge pull request #852 from qingqing01/ssd_pl_exe

Add eval.py and fix bug for MobileNet-SSD.
......@@ -18,17 +18,19 @@ add_arg('batch_size', int, 256, "Minibatch size.")
add_arg('num_layers', int, 50, "How many layers for SE-ResNeXt model.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('parallel_exe', bool, True, "Whether to use ParallelExecutor to train or not.")
def train_paralle_do(args,
learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
# yapf: enable
def train_parallel_do(args,
learning_rate,
batch_size,
num_passes,
init_model=None,
model_save_dir='model',
parallel=True,
use_nccl=True,
lr_strategy=None,
layers=50):
class_dim = 1000
image_shape = [3, 224, 224]
......@@ -62,6 +64,8 @@ def train_paralle_do(args,
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
inference_program = fluid.default_main_program().clone(for_test=True)
if lr_strategy is None:
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
......@@ -76,12 +80,9 @@ def train_paralle_do(args,
momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4))
inference_program = fluid.default_main_program().clone(for_test=True)
opts = optimizer.minimize(avg_cost)
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(inference_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......@@ -154,6 +155,7 @@ def train_paralle_do(args,
os.makedirs(model_path)
fluid.io.save_persistables(exe, model_path)
def train_parallel_exe(args,
learning_rate,
batch_size,
......@@ -195,7 +197,6 @@ def train_parallel_exe(args,
if args.with_mem_opt:
fluid.memory_optimize(fluid.default_main_program())
fluid.memory_optimize(test_program)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......@@ -210,9 +211,7 @@ def train_parallel_exe(args,
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
test_exe = fluid.ParallelExecutor(
use_cuda=True,
main_program=test_program,
share_vars_from=train_exe)
use_cuda=True, main_program=test_program, share_vars_from=train_exe)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
......@@ -221,9 +220,8 @@ def train_parallel_exe(args,
test_info = [[], [], []]
for batch_id, data in enumerate(train_reader()):
t1 = time.time()
loss, acc1, acc5 = train_exe.run(
fetch_list,
feed_dict=feeder.feed(data))
loss, acc1, acc5 = train_exe.run(fetch_list,
feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
......@@ -245,9 +243,8 @@ def train_parallel_exe(args,
train_acc5 = np.array(train_info[2]).mean()
for data in test_reader():
t1 = time.time()
loss, acc1, acc5 = test_exe.run(
fetch_list,
feed_dict=feeder.feed(data))
loss, acc1, acc5 = test_exe.run(fetch_list,
feed_dict=feeder.feed(data))
t2 = time.time()
period = t2 - t1
loss = np.mean(np.array(loss))
......@@ -281,8 +278,6 @@ def train_parallel_exe(args,
fluid.io.save_persistables(exe, model_path)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
......@@ -300,12 +295,13 @@ if __name__ == '__main__':
# layers: 50, 152
layers = args.num_layers
method = train_parallel_exe if args.parallel_exe else train_parallel_do
method(args,
learning_rate=0.1,
batch_size=batch_size,
num_passes=120,
init_model=None,
parallel=True,
use_nccl=True,
lr_strategy=lr_strategy,
layers=layers)
method(
args,
learning_rate=0.1,
batch_size=batch_size,
num_passes=120,
init_model=None,
parallel=True,
use_nccl=True,
lr_strategy=lr_strategy,
layers=layers)
......@@ -6,3 +6,4 @@ pretrained/ssd_mobilenet_v1_coco
pretrained/mobilenet_v1_imagenet.tar.gz
pretrained/mobilenet_v1_imagenet
log*
*.log
import os
import time
import numpy as np
import argparse
import functools
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
# Command-line options of the evaluation script. `add_arg` is
# utility.add_arguments with `argparser` pre-bound to `parser`; each call
# below passes (name, type, default, help text).
# NOTE(review): description=__doc__ is None unless the module has a docstring.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('data_dir', str, '', "The data root path.")
add_arg('test_list', str, '', "The testing data lists.")
add_arg('label_file', str, '', "The label file, which save the real name and is only used for Pascal VOC.")
add_arg('model_dir', str, '', "The model path.")
add_arg('ap_version', str, '11point', "11point or integral")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
# NOTE(review): the trailing numbers look like alternative per-channel means
# kept for reference -- confirm before relying on them.
add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
# yapf: enable
def eval(args, data_args, test_list, batch_size, model_dir=None):
    """Evaluate a MobileNet-SSD model and print its accumulated detection mAP.

    The detection network is built in the default main program, cloned for
    testing, and a DetectionMAP evaluator is attached to the clone. Variables
    found in ``model_dir`` are loaded, then every batch produced by
    ``reader.test`` is run and the accumulated mAP is printed.

    Args:
        args: Parsed command-line namespace; ``ap_version`` and ``use_gpu``
            are read from it.
        data_args: Reader settings; ``dataset``, ``resize_h`` and
            ``resize_w`` are read here and the object is passed to
            ``reader.test``.
        test_list: Test data list handed to ``reader.test``.
        batch_size: Number of samples per batch.
        model_dir: Directory with saved variables. Nothing is loaded when it
            is empty or None.

    Raises:
        ValueError: If ``data_args.dataset`` is neither 'coco' nor
            'pascalvoc'.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]

    # Fail early with a clear message instead of an unbound `num_classes`
    # further down when the dataset name is not recognised.
    class_counts = {'coco': 81, 'pascalvoc': 21}
    if data_args.dataset not in class_counts:
        raise ValueError("Unsupported dataset '{0}', expected 'coco' or "
                         "'pascalvoc'.".format(data_args.dataset))
    num_classes = class_counts[data_args.dataset]

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    gt_box = fluid.layers.data(
        name='gt_box', shape=[4], dtype='float32', lod_level=1)
    gt_label = fluid.layers.data(
        name='gt_label', shape=[1], dtype='int32', lod_level=1)
    difficult = fluid.layers.data(
        name='gt_difficult', shape=[1], dtype='int32', lod_level=1)

    locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=0.45)
    # NOTE(review): the loss is never fetched below; it is kept so the program
    # is built exactly as before -- confirm whether evaluation needs it.
    loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
    loss = fluid.layers.reduce_sum(loss)

    test_program = fluid.default_main_program().clone(for_test=True)
    with fluid.program_guard(test_program):
        map_eval = fluid.evaluator.DetectionMAP(
            nmsed_out,
            gt_label,
            gt_box,
            difficult,
            num_classes,
            overlap_threshold=0.5,
            evaluate_difficult=False,
            ap_version=args.ap_version)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    if model_dir:

        def if_exist(var):
            # Load only the variables that actually have a file in model_dir.
            return os.path.exists(os.path.join(model_dir, var.name))

        fluid.io.load_vars(exe, model_dir, predicate=if_exist)

    test_reader = paddle.batch(
        reader.test(data_args, test_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(
        place=place, feed_list=[image, gt_box, gt_label, difficult])

    _, accum_map = map_eval.get_map_var()
    map_eval.reset(exe)
    # Stays None when the reader yields nothing, so the summary below cannot
    # hit an unbound name.
    test_map = None
    for idx, data in enumerate(test_reader()):
        test_map = exe.run(test_program,
                           feed=feeder.feed(data),
                           fetch_list=[accum_map])
        if idx % 50 == 0:
            print("Batch {0}, map {1}".format(idx, test_map[0]))
    if test_map is None:
        print("Test model {0}, no test batches were produced".format(
            model_dir))
        return
    print("Test model {0}, map {1}".format(model_dir, test_map[0]))
if __name__ == '__main__':
    # Script entry point: parse the options, echo them, build the reader
    # settings and run the evaluation.
    args = parser.parse_args()
    print_arguments(args)

    # Per-channel means in B, G, R order.
    channel_means = [args.mean_value_B, args.mean_value_G, args.mean_value_R]
    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=args.data_dir,
        label_file=args.label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=channel_means)

    eval(args,
         data_args=data_args,
         test_list=args.test_list,
         batch_size=args.batch_size,
         model_dir=args.model_dir)
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import image_util
from paddle.utils.image_util import *
import random
......@@ -23,12 +22,19 @@ import xml.etree.ElementTree
import os
import time
import copy
import functools
class Settings(object):
def __init__(self, dataset, toy, data_dir, label_file, resize_h, resize_w,
mean_value, apply_distort, apply_expand):
def __init__(self,
dataset=None,
data_dir=None,
label_file=None,
resize_h=300,
resize_w=300,
mean_value=[127.5, 127.5, 127.5],
apply_distort=True,
apply_expand=True,
toy=0):
self._dataset = dataset
self._toy = toy
self._data_dir = data_dir
......@@ -38,8 +44,6 @@ class Settings(object):
for line in open(label_fpath):
self._label_list.append(line.strip())
self._thread = 2
self._buf_size = 2048
self._apply_distort = apply_distort
self._apply_expand = apply_expand
self._resize_height = resize_h
......@@ -98,15 +102,94 @@ class Settings(object):
return self._img_mean
def process_image(sample, settings, mode):
img = Image.open(sample[0])
if img.mode == 'L':
img = img.convert('RGB')
def preprocess(img, bbox_labels, mode, settings):
    """Convert a PIL image and its box labels into a normalized CHW array.

    In 'train' mode the image is optionally distorted and expanded, cropped
    by one randomly chosen sampled box, and mirrored horizontally with
    probability one half. In every mode it is resized to
    (settings.resize_w, settings.resize_h), transposed from HWC to CHW,
    channel-reversed, mean-subtracted and scaled.

    Args:
        img: PIL image; callers convert grayscale images to RGB first.
        bbox_labels: List of per-box lists. Indices 1 and 3 are treated as
            the normalized xmin and xmax when mirroring.
        mode: 'train' enables augmentation; any other value skips it.
        settings: Settings object providing _apply_distort, _apply_expand,
            resize_w, resize_h and img_mean.

    Returns:
        Tuple (img, sampled_labels): the float32 image array and the labels
        that remain after augmentation.
    """
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings._apply_distort:
            img = image_util.distort_image(img, settings)
        if settings._apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
            # Re-bind after expansion: if no crop is sampled below, the
            # returned labels must be the expanded ones, not the originals.
            sampled_labels = bbox_labels
        # sampling (the sampler parameters are hard-coded here)
        batch_sampler = [
            image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0),
            image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0),
        ]
        sampled_bbox = image_util.generate_batch_samples(batch_sampler,
                                                         bbox_labels)

        img = np.array(img)
        if len(sampled_bbox) > 0:
            idx = int(random.uniform(0, len(sampled_bbox)))
            img, sampled_labels = image_util.crop_image(
                img, bbox_labels, sampled_bbox[idx], img_width, img_height)

        # Back to a PIL image so the resize below works in every mode.
        img = Image.fromarray(img)
    img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
    img = np.array(img)

    if mode == 'train':
        mirror = int(random.uniform(0, 2))
        if mirror == 1:
            img = img[:, ::-1, :]
            # A horizontal flip maps x to 1 - x, which also swaps xmin/xmax.
            for label in sampled_labels:
                label[1], label[3] = 1 - label[3], 1 - label[1]
    # HWC to CHW
    if len(img.shape) == 3:
        img = np.swapaxes(img, 1, 2)
        img = np.swapaxes(img, 1, 0)
    # RGB to BGR
    img = img[[2, 1, 0], :, :]
    img = img.astype('float32')
    img -= settings.img_mean
    # 0.007843 is approximately 1 / 127.5.
    img = img * 0.007843
    return img, sampled_labels
def coco(settings, file_list, mode, shuffle):
# cocoapi
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
if mode == 'train' or mode == 'test':
if settings.dataset == 'coco':
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd | origin_coco_bbox | segmentation | area | image_id | annotation_id
coco = COCO(file_list)
image_ids = coco.getImgIds()
images = coco.loadImgs(image_ids)
category_ids = coco.getCatIds()
category_names = [item['name'] for item in coco.loadCats(category_ids)]
if not settings.toy == 0:
images = images[:settings.toy] if len(images) > settings.toy else images
print("{} on {} with {} images".format(mode, settings.dataset, len(images)))
def reader():
if mode == 'train' and shuffle:
random.shuffle(images)
for image in images:
image_name = image['file_name']
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd |
# origin_coco_bbox | segmentation | area | image_id | annotation_id
bbox_labels = []
annIds = coco.getAnnIds(imgIds=image['id'])
anns = coco.loadAnns(annIds)
......@@ -119,21 +202,47 @@ def process_image(sample, settings, mode):
xmin, ymin, w, h = bbox
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(float(xmin) / img_width)
bbox_sample.append(float(ymin) / img_height)
bbox_sample.append(float(xmax) / img_width)
bbox_sample.append(float(ymax) / img_height)
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_sample.append(float(ann['iscrowd']))
#bbox_sample.append(ann['bbox'])
#bbox_sample.append(ann['segmentation'])
#bbox_sample.append(ann['area'])
#bbox_sample.append(ann['image_id'])
#bbox_sample.append(ann['id'])
bbox_labels.append(bbox_sample)
elif settings.dataset == 'pascalvoc':
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = sample_labels[:, 0].astype('int32')
difficults = sample_labels[:, -1].astype('int32')
yield im, boxes, lbls, difficults
return reader
def pascalvoc(settings, file_list, mode, shuffle):
flist = open(file_list)
images = [line.strip() for line in flist]
if not settings.toy == 0:
images = images[:settings.toy] if len(images) > settings.toy else images
print("{} on {} with {} images".format(mode, settings.dataset, len(images)))
def reader():
if mode == 'train' and shuffle:
random.shuffle(images)
for image in images:
image_path, label_path = image.split()
image_path = os.path.join(settings.data_dir, image_path)
label_path = os.path.join(settings.data_dir, label_path)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
# layout: label | xmin | ymin | xmax | ymax | difficult
bbox_labels = []
root = xml.etree.ElementTree.parse(sample[1]).getroot()
root = xml.etree.ElementTree.parse(label_path).getroot()
for object in root.findall('object'):
bbox_sample = []
# start from 1
......@@ -141,124 +250,22 @@ def process_image(sample, settings, mode):
float(settings.label_list.index(object.find('name').text)))
bbox = object.find('bndbox')
difficult = float(object.find('difficult').text)
bbox_sample.append(float(bbox.find('xmin').text) / img_width)
bbox_sample.append(float(bbox.find('ymin').text) / img_height)
bbox_sample.append(float(bbox.find('xmax').text) / img_width)
bbox_sample.append(float(bbox.find('ymax').text) / img_height)
bbox_sample.append(float(bbox.find('xmin').text) / im_width)
bbox_sample.append(float(bbox.find('ymin').text) / im_height)
bbox_sample.append(float(bbox.find('xmax').text) / im_width)
bbox_sample.append(float(bbox.find('ymax').text) / im_height)
bbox_sample.append(difficult)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 1:5]
lbls = sample_labels[:, 0].astype('int32')
difficults = sample_labels[:, -1].astype('int32')
yield im, boxes, lbls, difficults
sample_labels = bbox_labels
if mode == 'train':
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels, img_width, img_height = image_util.expand_image(
img, bbox_labels, img_width, img_height, settings)
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 0.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0))
""" random crop """
sampled_bbox = image_util.generate_batch_samples(batch_sampler,
bbox_labels)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sample_labels = image_util.crop_image(
img, bbox_labels, sampled_bbox[idx], img_width, img_height)
img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
img = np.array(img)
if mode == 'train':
mirror = int(random.uniform(0, 2))
if mirror == 1:
img = img[:, ::-1, :]
for i in xrange(len(sample_labels)):
tmp = sample_labels[i][1]
sample_labels[i][1] = 1 - sample_labels[i][3]
sample_labels[i][3] = 1 - tmp
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img.flatten()
img = img * 0.007843
sample_labels = np.array(sample_labels)
if mode == 'train' or mode == 'test':
if len(sample_labels) != 0:
return img.astype(
'float32'), sample_labels[:, 1:5], sample_labels[:, 0].astype(
'int32'), sample_labels[:, -1].astype('int32')
elif mode == 'infer':
return img.astype('float32')
def _reader_creator(settings, file_list, mode, shuffle):
def reader():
if settings.dataset == 'coco':
# cocoapi
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
coco = COCO(file_list)
image_ids = coco.getImgIds()
images = coco.loadImgs(image_ids)
category_ids = coco.getCatIds()
category_names = [
item['name'] for item in coco.loadCats(category_ids)
]
elif settings.dataset == 'pascalvoc':
flist = open(file_list)
images = [line.strip() for line in flist]
if not settings.toy == 0:
images = images[:settings.toy] if len(
images) > settings.toy else images
print("{} on {} with {} images".format(mode, settings.dataset,
len(images)))
if shuffle:
random.shuffle(images)
for image in images:
if settings.dataset == 'coco':
image_name = image['file_name']
image_path = os.path.join(settings.data_dir, image_name)
yield [image_path]
elif settings.dataset == 'pascalvoc':
if mode == 'train' or mode == 'test':
image_path, label_path = image.split()
image_path = os.path.join(settings.data_dir, image_path)
label_path = os.path.join(settings.data_dir, label_path)
yield image_path, label_path
elif mode == 'infer':
image_path = os.path.join(settings.data_dir, image)
yield [image_path]
mapper = functools.partial(process_image, mode=mode, settings=settings)
return paddle.reader.xmap_readers(mapper, reader, settings._thread,
settings._buf_size)
return reader
def draw_bounding_box_on_image(image,
......@@ -301,9 +308,9 @@ def train(settings, file_list, shuffle=True):
elif '2017' in file_list:
sub_dir = "train2017"
train_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
return _reader_creator(train_settings, file_list, 'train', shuffle)
elif settings.dataset == 'pascalvoc':
return _reader_creator(settings, file_list, 'train', shuffle)
return coco(train_settings, file_list, 'train', shuffle)
else:
return pascalvoc(settings, file_list, 'train', shuffle)
def test(settings, file_list):
......@@ -315,10 +322,29 @@ def test(settings, file_list):
elif '2017' in file_list:
sub_dir = "val2017"
test_settings.data_dir = os.path.join(settings.data_dir, sub_dir)
return _reader_creator(test_settings, file_list, 'test', False)
elif settings.dataset == 'pascalvoc':
return _reader_creator(settings, file_list, 'test', False)
return coco(test_settings, file_list, 'test', False)
else:
return pascalvoc(settings, file_list, 'test', False)
def infer(settings, file_list):
return _reader_creator(settings, file_list, 'infer', False)
def infer(settings, image_path):
    """Create a reader that yields one preprocessed image for inference.

    Args:
        settings: Settings object providing resize_w, resize_h and img_mean.
        image_path: Path of the image file to load.

    Returns:
        A no-argument generator function. It yields a single float32 array:
        the image resized to (settings.resize_w, settings.resize_h),
        transposed from HWC to CHW, channel-reversed, mean-subtracted and
        scaled by 0.007843.
    """

    def reader():
        # The original bound the opened image to `im` but then used `img`,
        # which failed on the first resize; one name is used throughout.
        img = Image.open(image_path)
        if img.mode == 'L':
            img = img.convert('RGB')
        img = img.resize((settings.resize_w, settings.resize_h),
                         Image.ANTIALIAS)
        img = np.array(img)
        # HWC to CHW
        if len(img.shape) == 3:
            img = np.swapaxes(img, 1, 2)
            img = np.swapaxes(img, 1, 0)
        # RGB to BGR
        img = img[[2, 1, 0], :, :]
        img = img.astype('float32')
        img -= settings.img_mean
        # 0.007843 is approximately 1 / 127.5.
        img = img * 0.007843
        yield img

    return reader
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
import os
import time
import numpy as np
import argparse
import functools
import shutil
import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('num_passes', int, 25, "Epoch number.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('parallel', bool, True, "Whether use parallel training.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_nccl', bool, False, "Whether use NCCL.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('use_nccl', bool, False, "Whether to use NCCL or not.")
add_arg('dataset', str, 'pascalvoc', "coco or pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, 'pretrained/ssd_mobilenet_v1_coco/', "The init model path.")
add_arg('apply_distort', bool, True, "Whether apply distort")
add_arg('apply_expand', bool, False, "Whether appley expand")
add_arg('resize_h', int, 300, "resize image size")
add_arg('resize_w', int, 300, "resize image size")
add_arg('mean_value_B', float, 127.5, "mean value which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value which will be subtracted") #103.94
add_arg('apply_distort', bool, True, "Whether apply distort")
add_arg('apply_expand', bool, True, "Whether appley expand")
add_arg('ap_version', str, '11point', "11point or integral")
add_arg('resize_h', int, 300, "The resized image height.")
add_arg('resize_w', int, 300, "The resized image width.")
add_arg('mean_value_B', float, 127.5, "mean value for B channel which will be subtracted") #123.68
add_arg('mean_value_G', float, 127.5, "mean value for G channel which will be subtracted") #116.78
add_arg('mean_value_R', float, 127.5, "mean value for R channel which will be subtracted") #103.94
add_arg('is_toy', int, 0, "Toy for quick debug, 0 means using all data, while n means using only n sample")
# yapf: disable
# yapf: enable
def parallel_do(args,
......@@ -93,7 +96,7 @@ def parallel_do(args,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version='integral')
ap_version=args.ap_version)
if data_args.dataset == 'coco':
# learning rate decay in 12, 19 pass, respectively
......@@ -115,8 +118,10 @@ def parallel_do(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_reader = paddle.batch(
......@@ -130,7 +135,7 @@ def parallel_do(args,
_, accum_map = map_eval.get_map_var()
map_eval.reset(exe)
test_map = None
for _, data in enumerate(test_reader()):
for data in test_reader():
test_map = exe.run(test_program,
feed=feeder.feed(data),
fetch_list=[accum_map])
......@@ -173,6 +178,9 @@ def parallel_exe(args,
elif data_args.dataset == 'pascalvoc':
num_classes = 21
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
gt_box = fluid.layers.data(
name='gt_box', shape=[4], dtype='float32', lod_level=1)
......@@ -184,8 +192,7 @@ def parallel_exe(args,
locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
nmsed_out = fluid.layers.detection_output(
locs, confs, box, box_var, nms_threshold=0.45)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box,
box_var)
loss = fluid.layers.ssd_loss(locs, confs, gt_box, gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
test_program = fluid.default_main_program().clone(for_test=True)
......@@ -198,17 +205,23 @@ def parallel_exe(args,
num_classes,
overlap_threshold=0.5,
evaluate_difficult=False,
ap_version='integral')
ap_version=args.ap_version)
if data_args.dataset == 'coco':
# learning rate decay in 12, 19 pass, respectively
if '2014' in train_file_list:
boundaries = [82783 / batch_size * 12, 82783 / batch_size * 19]
epocs = 82783 / batch_size
boundaries = [epocs * 12, epocs * 19]
elif '2017' in train_file_list:
boundaries = [118287 / batch_size * 12, 118287 / batch_size * 19]
# Batches per pass (118287 is presumably the COCO 2017 training-set size);
# the learning rate decays at passes 12 and 19. The original line read
# `epcos`, an undefined name.
epocs = 118287 / batch_size
boundaries = [epocs * 12, epocs * 19]
elif data_args.dataset == 'pascalvoc':
boundaries = [40000, 60000]
values = [learning_rate, learning_rate * 0.5, learning_rate * 0.25]
epocs = 19200 / batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
]
optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), )
......@@ -220,12 +233,14 @@ def parallel_exe(args,
exe.run(fluid.default_startup_program())
if pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
train_exe = fluid.ParallelExecutor(use_cuda=args.use_gpu,
loss_name=loss.name)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_gpu, loss_name=loss.name)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
......@@ -234,36 +249,48 @@ def parallel_exe(args,
feeder = fluid.DataFeeder(
place=place, feed_list=[image, gt_box, gt_label, difficult])
def test(pass_id):
def save_model(postfix):
    """Save all persistable variables to <model_save_dir>/<postfix>.

    An existing directory at that path is removed first, so the saved model
    never mixes files from an earlier save.

    Args:
        postfix: Sub-directory name under model_save_dir.
    """
    model_path = os.path.join(model_save_dir, postfix)
    if os.path.isdir(model_path):
        shutil.rmtree(model_path)
    # Call form prints the same text on Python 2 and also parses on Python 3.
    print('save models to %s' % (model_path))
    fluid.io.save_persistables(exe, model_path)
best_map = 0.


def test(pass_id, best_map):
    """Evaluate on the test set and return the best mAP seen so far.

    Runs every test batch through test_program and reads the accumulated mAP.
    When it beats best_map, the model is saved as 'best_model'.

    Args:
        pass_id: Index of the training pass, used only in the log line.
        best_map: Best mAP observed before this call.

    Returns:
        The larger of best_map and this pass's mAP. The caller must rebind
        its own variable to it: assigning to the parameter here does not
        change the caller's value, and Python 2 has no `nonlocal`.
    """
    _, accum_map = map_eval.get_map_var()
    map_eval.reset(exe)
    test_map = None
    for data in test_reader():
        test_map = exe.run(test_program,
                           feed=feeder.feed(data),
                           fetch_list=[accum_map])
    if test_map is None:
        # The reader produced no batches; there is nothing to compare.
        print("Test {0}, no test batches were produced".format(pass_id))
        return best_map
    if test_map[0] > best_map:
        best_map = test_map[0]
        save_model('best_model')
    print("Test {0}, map {1}".format(pass_id, test_map[0]))
    return best_map


for pass_id in range(num_passes):
    start_time = time.time()
    prev_start_time = start_time
    end_time = 0
    for batch_id, data in enumerate(train_reader()):
        prev_start_time = start_time
        start_time = time.time()
        # Skip batches with fewer samples than devices.
        if len(data) < devices_num:
            continue
        loss_v, = train_exe.run(fetch_list=[loss.name],
                                feed_dict=feeder.feed(data))
        end_time = time.time()
        loss_v = np.mean(np.array(loss_v))
        if batch_id % 20 == 0:
            print("Pass {0}, batch {1}, loss {2}, time {3}".format(
                pass_id, batch_id, loss_v, start_time - prev_start_time))
    # Keep the returned value; otherwise best_map stays 0, 'best_model' is
    # overwritten every pass and the final report is wrong.
    best_map = test(pass_id, best_map)
    if pass_id % 10 == 0 or pass_id == num_passes - 1:
        save_model(str(pass_id))
print("Best test map {0}".format(best_map))
if __name__ == '__main__':
args = parser.parse_args()
......@@ -282,22 +309,23 @@ if __name__ == '__main__':
data_args = reader.Settings(
dataset=args.dataset,
toy=args.is_toy,
data_dir=data_dir,
label_file=label_file,
apply_distort=args.apply_distort,
apply_expand=args.apply_expand,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R])
mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
toy=args.is_toy)
#method = parallel_do
method = parallel_exe
method(args,
train_file_list=train_file_list,
val_file_list=val_file_list,
data_args=data_args,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_passes=args.num_passes,
model_save_dir=model_save_dir,
pretrained_model=args.pretrained_model)
method(
args,
train_file_list=train_file_list,
val_file_list=val_file_list,
data_args=data_args,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_passes=args.num_passes,
model_save_dir=model_save_dir,
pretrained_model=args.pretrained_model)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册