提交 1b2641aa 编写于 作者: B bigboss 提交者: qingqing01

Add reader and train logic for PyramidBox. (#927)

* training code of pyramidbox

* Fix code format.
上级 d3ff11c0
from PIL import Image, ImageEnhance, ImageDraw
from PIL import ImageFile
import numpy as np
import random
import math
ImageFile.LOAD_TRUNCATED_IMAGES = True #otherwise IOError raised image file is truncated
class sampler():
def __init__(self, max_sample, max_trial, min_scale, max_scale,
min_aspect_ratio, max_aspect_ratio, min_jaccard_overlap,
max_jaccard_overlap):
self.max_sample = max_sample
self.max_trial = max_trial
self.min_scale = min_scale
self.max_scale = max_scale
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
self.min_jaccard_overlap = min_jaccard_overlap
self.max_jaccard_overlap = max_jaccard_overlap
class bbox():
def __init__(self, xmin, ymin, xmax, ymax):
self.xmin = xmin
self.ymin = ymin
self.xmax = xmax
self.ymax = ymax
def bbox_area(src_bbox):
width = src_bbox.xmax - src_bbox.xmin
height = src_bbox.ymax - src_bbox.ymin
return width * height
def generate_sample(sampler):
scale = random.uniform(sampler.min_scale, sampler.max_scale)
min_aspect_ratio = max(sampler.min_aspect_ratio, (scale**2.0))
max_aspect_ratio = min(sampler.max_aspect_ratio, 1 / (scale**2.0))
aspect_ratio = random.uniform(min_aspect_ratio, max_aspect_ratio)
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = random.uniform(0, xmin_bound)
ymin = random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = bbox(xmin, ymin, xmax, ymax)
return sampled_bbox
def jaccard_overlap(sample_bbox, object_bbox):
if sample_bbox.xmin >= object_bbox.xmax or \
sample_bbox.xmax <= object_bbox.xmin or \
sample_bbox.ymin >= object_bbox.ymax or \
sample_bbox.ymax <= object_bbox.ymin:
return 0
intersect_xmin = max(sample_bbox.xmin, object_bbox.xmin)
intersect_ymin = max(sample_bbox.ymin, object_bbox.ymin)
intersect_xmax = min(sample_bbox.xmax, object_bbox.xmax)
intersect_ymax = min(sample_bbox.ymax, object_bbox.ymax)
intersect_size = (intersect_xmax - intersect_xmin) * (
intersect_ymax - intersect_ymin)
sample_bbox_size = bbox_area(sample_bbox)
object_bbox_size = bbox_area(object_bbox)
overlap = intersect_size / (
sample_bbox_size + object_bbox_size - intersect_size)
return overlap
def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
if sampler.min_jaccard_overlap == 0 and sampler.max_jaccard_overlap == 0:
return True
for i in range(len(bbox_labels)):
object_bbox = bbox(
bbox_labels[i][0],
bbox_labels[i][1], # tangxu @ 2018-05-17
bbox_labels[i][2],
bbox_labels[i][3])
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler.min_jaccard_overlap != 0 and \
overlap < sampler.min_jaccard_overlap:
continue
if sampler.max_jaccard_overlap != 0 and \
overlap > sampler.max_jaccard_overlap:
continue
return True
return False
def generate_batch_samples(batch_sampler, bbox_labels):
sampled_bbox = []
index = []
c = 0
for sampler in batch_sampler:
found = 0
for i in range(sampler.max_trial):
if found >= sampler.max_sample:
break
sample_bbox = generate_sample(sampler)
if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
sampled_bbox.append(sample_bbox)
found = found + 1
index.append(c)
c = c + 1
return sampled_bbox
def clip_bbox(src_bbox):
src_bbox.xmin = max(min(src_bbox.xmin, 1.0), 0.0)
src_bbox.ymin = max(min(src_bbox.ymin, 1.0), 0.0)
src_bbox.xmax = max(min(src_bbox.xmax, 1.0), 0.0)
src_bbox.ymax = max(min(src_bbox.ymax, 1.0), 0.0)
return src_bbox
def meet_emit_constraint(src_bbox, sample_bbox):
center_x = (src_bbox.xmax + src_bbox.xmin) / 2
center_y = (src_bbox.ymax + src_bbox.ymin) / 2
if center_x >= sample_bbox.xmin and \
center_x <= sample_bbox.xmax and \
center_y >= sample_bbox.ymin and \
center_y <= sample_bbox.ymax:
return True
return False
def transform_labels(bbox_labels, sample_bbox):
proj_bbox = bbox(0, 0, 0, 0)
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][0], bbox_labels[i][1],
bbox_labels[i][2], bbox_labels[i][3])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
sample_width = sample_bbox.xmax - sample_bbox.xmin
sample_height = sample_bbox.ymax - sample_bbox.ymin
proj_bbox.xmin = (object_bbox.xmin - sample_bbox.xmin) / sample_width
proj_bbox.ymin = (object_bbox.ymin - sample_bbox.ymin) / sample_height
proj_bbox.xmax = (object_bbox.xmax - sample_bbox.xmin) / sample_width
proj_bbox.ymax = (object_bbox.ymax - sample_bbox.ymin) / sample_height
proj_bbox = clip_bbox(proj_bbox)
if bbox_area(proj_bbox) > 0:
sample_label.append(bbox_labels[i][0])
sample_label.append(float(proj_bbox.xmin))
sample_label.append(float(proj_bbox.ymin))
sample_label.append(float(proj_bbox.xmax))
sample_label.append(float(proj_bbox.ymax))
#sample_label.append(bbox_labels[i][5])
sample_label = sample_label + bbox_labels[i][5:]
sample_labels.append(sample_label)
return sample_labels
def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
sample_bbox = clip_bbox(sample_bbox)
xmin = int(sample_bbox.xmin * image_width)
xmax = int(sample_bbox.xmax * image_width)
ymin = int(sample_bbox.ymin * image_height)
ymax = int(sample_bbox.ymax * image_height)
sample_img = img[ymin:ymax, xmin:xmax]
sample_labels = transform_labels(bbox_labels, sample_bbox)
return sample_img, sample_labels
def random_brightness(img, settings):
prob = random.uniform(0, 1)
if prob < settings._brightness_prob:
delta = random.uniform(-settings._brightness_delta,
settings._brightness_delta) + 1
img = ImageEnhance.Brightness(img).enhance(delta)
return img
def random_contrast(img, settings):
prob = random.uniform(0, 1)
if prob < settings._contrast_prob:
delta = random.uniform(-settings._contrast_delta,
settings._contrast_delta) + 1
img = ImageEnhance.Contrast(img).enhance(delta)
return img
def random_saturation(img, settings):
prob = random.uniform(0, 1)
if prob < settings._saturation_prob:
delta = random.uniform(-settings._saturation_delta,
settings._saturation_delta) + 1
img = ImageEnhance.Color(img).enhance(delta)
return img
def random_hue(img, settings):
prob = random.uniform(0, 1)
if prob < settings._hue_prob:
delta = random.uniform(-settings._hue_delta, settings._hue_delta)
img_hsv = np.array(img.convert('HSV'))
img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta
img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
return img
def distort_image(img, settings):
prob = random.uniform(0, 1)
# Apply different distort order
if prob > 0.5:
img = random_brightness(img, settings)
img = random_contrast(img, settings)
img = random_saturation(img, settings)
img = random_hue(img, settings)
else:
img = random_brightness(img, settings)
img = random_saturation(img, settings)
img = random_hue(img, settings)
img = random_contrast(img, settings)
return img
def expand_image(img, bbox_labels, img_width, img_height, settings):
prob = random.uniform(0, 1)
if prob < settings._expand_prob:
if settings._expand_max_ratio - 1 >= 0.01:
expand_ratio = random.uniform(1, settings._expand_max_ratio)
height = int(img_height * expand_ratio)
width = int(img_width * expand_ratio)
h_off = math.floor(random.uniform(0, height - img_height))
w_off = math.floor(random.uniform(0, width - img_width))
expand_bbox = bbox(-w_off / img_width, -h_off / img_height,
(width - w_off) / img_width,
(height - h_off) / img_height)
expand_img = np.ones((height, width, 3))
expand_img = np.uint8(expand_img * np.squeeze(settings._img_mean))
expand_img = Image.fromarray(expand_img)
expand_img.paste(img, (int(w_off), int(h_off)))
bbox_labels = transform_labels(bbox_labels, expand_bbox)
return expand_img, bbox_labels, width, height
return img, bbox_labels, img_width, img_height
...@@ -45,7 +45,7 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True): ...@@ -45,7 +45,7 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
class PyramidBox(object): class PyramidBox(object):
def __init__(self, data_shape, is_infer=False): def __init__(self, data_shape, is_infer=False, sub_network=False):
self.data_shape = data_shape self.data_shape = data_shape
self.min_sizes = [16., 32., 64., 128., 256., 512.] self.min_sizes = [16., 32., 64., 128., 256., 512.]
self.steps = [4., 8., 16., 32., 64., 128.] self.steps = [4., 8., 16., 32., 64., 128.]
...@@ -54,9 +54,10 @@ class PyramidBox(object): ...@@ -54,9 +54,10 @@ class PyramidBox(object):
# the base network is VGG with atrus layers # the base network is VGG with atrus layers
self._input() self._input()
self._vgg() self._vgg()
self._low_level_fpn() if sub_network:
self._cpm_module() self._low_level_fpn()
self._pyramidbox() self._cpm_module()
self._pyramidbox()
def _input(self): def _input(self):
self.image = fluid.layers.data( self.image = fluid.layers.data(
...@@ -66,6 +67,8 @@ class PyramidBox(object): ...@@ -66,6 +67,8 @@ class PyramidBox(object):
name='gt_box', shape=[4], dtype='float32', lod_level=1) name='gt_box', shape=[4], dtype='float32', lod_level=1)
self.gt_label = fluid.layers.data( self.gt_label = fluid.layers.data(
name='gt_label', shape=[1], dtype='int32', lod_level=1) name='gt_label', shape=[1], dtype='int32', lod_level=1)
self.difficult = fluid.layers.data(
name='gt_difficult', shape=[1], dtype='int32', lod_level=1)
def _vgg(self): def _vgg(self):
self.conv1 = conv_block(self.image, 2, [64] * 2, [3] * 2) self.conv1 = conv_block(self.image, 2, [64] * 2, [3] * 2)
...@@ -232,6 +235,39 @@ class PyramidBox(object): ...@@ -232,6 +235,39 @@ class PyramidBox(object):
self.prior_boxes = fluid.layers.concat(boxes) self.prior_boxes = fluid.layers.concat(boxes)
self.box_vars = fluid.layers.concat(vars) self.box_vars = fluid.layers.concat(vars)
def vgg_ssd(self, num_classes, image_shape): # tangxu
self.conv3_norm = self._l2_norm_scale(self.conv3)
self.conv4_norm = self._l2_norm_scale(self.conv4)
self.conv5_norm = self._l2_norm_scale(self.conv5)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[
self.conv3_norm, self.conv4_norm, self.conv5_norm, self.conv6,
self.conv7, self.conv8
],
image=self.image,
num_classes=num_classes,
# min_ratio=20,
# max_ratio=90,
min_sizes=[16.0, 32.0, 64.0, 128.0, 256.0, 512.0],
max_sizes=[[], [], [], [], [], []],
# max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[1.], [1.], [1.], [1.], [1.], [1.]],
steps=[4.0, 8.0, 16.0, 32.0, 64.0, 128.0],
base_size=image_shape[2],
offset=0.5,
flip=False)
# locs, confs, box, box_var = vgg_extra_net(num_classes, image, image_shape)
# nmsed_out = fluid.layers.detection_output(
# locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.gt_box,
self.gt_label, box, box_var)
loss = fluid.layers.reduce_sum(loss)
return loss
def train(self): def train(self):
face_loss = fluid.layers.ssd_loss( face_loss = fluid.layers.ssd_loss(
self.face_mbox_loc, self.face_mbox_conf, self.gt_box, self.gt_label, self.face_mbox_loc, self.face_mbox_conf, self.gt_box, self.gt_label,
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import image_util
from paddle.utils.image_util import *
import random
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
class Settings(object):
def __init__(self,
dataset=None,
data_dir=None,
label_file=None,
resize_h=300,
resize_w=300,
mean_value=[127.5, 127.5, 127.5],
apply_distort=True,
apply_expand=True,
ap_version='11point',
toy=0):
self._dataset = dataset
self._ap_version = ap_version
self._toy = toy
self._data_dir = data_dir
self._apply_distort = apply_distort
self._apply_expand = apply_expand
self._resize_height = resize_h
self._resize_width = resize_w
self._img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
'float32')
self._expand_prob = 0.5
self._expand_max_ratio = 4
self._hue_prob = 0.5
self._hue_delta = 18
self._contrast_prob = 0.5
self._contrast_delta = 0.5
self._saturation_prob = 0.5
self._saturation_delta = 0.5
self._brightness_prob = 0.5
self._brightness_delta = 0.125
@property
def dataset(self):
return self._dataset
@property
def ap_version(self):
return self._ap_version
@property
def toy(self):
return self._toy
@property
def apply_distort(self):
return self._apply_expand
@property
def apply_distort(self):
return self._apply_distort
@property
def data_dir(self):
return self._data_dir
@data_dir.setter
def data_dir(self, data_dir):
self._data_dir = data_dir
@property
def label_list(self):
return self._label_list
@property
def resize_h(self):
return self._resize_height
@property
def resize_w(self):
return self._resize_width
@property
def img_mean(self):
return self._img_mean
def preprocess(img, bbox_labels, mode, settings):
img_width, img_height = img.size
sampled_labels = bbox_labels
if mode == 'train':
if settings._apply_distort:
img = image_util.distort_image(img, settings)
if settings._apply_expand:
img, bbox_labels, img_width, img_height = image_util.expand_image(
img, bbox_labels, img_width, img_height, settings)
# sampling
batch_sampler = []
# hard-code here
batch_sampler.append(
image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0))
batch_sampler.append(
image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0))
sampled_bbox = image_util.generate_batch_samples(batch_sampler,
bbox_labels)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sampled_labels = image_util.crop_image(
img, bbox_labels, sampled_bbox[idx], img_width, img_height)
img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
img = np.array(img)
if mode == 'train':
mirror = int(random.uniform(0, 2))
if mirror == 1:
img = img[:, ::-1, :]
for i in xrange(len(sampled_labels)):
tmp = sampled_labels[i][1]
sampled_labels[i][1] = 1 - sampled_labels[i][3]
sampled_labels[i][3] = 1 - tmp
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img * 0.007843
return img, sampled_labels
def put_txt_in_dict(input_txt):
with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines()
dict_input_txt = {}
num_class = 0
for i in range(len(lines_input_txt)):
tmp_line_txt = lines_input_txt[i].strip('\n\t\r')
if '--' in tmp_line_txt:
if i != 0:
num_class += 1
dict_input_txt[num_class] = []
dict_name = tmp_line_txt
dict_input_txt[num_class].append(tmp_line_txt)
if '--' not in tmp_line_txt:
if len(tmp_line_txt) > 6:
# tmp_line_txt = tmp_line_txt[:-2]
split_str = tmp_line_txt.split(' ')
x1_min = float(split_str[0])
y1_min = float(split_str[1])
x2_max = float(split_str[2])
y2_max = float(split_str[3])
tmp_line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
dict_input_txt[num_class].append(tmp_line_txt)
else:
dict_input_txt[num_class].append(tmp_line_txt)
return dict_input_txt
def pyramidbox(settings, file_list, mode, shuffle):
dict_input_txt = {}
dict_input_txt = put_txt_in_dict(file_list)
def reader():
if mode == 'train' and shuffle:
random.shuffle(dict_input_txt)
for index_image in range(len(dict_input_txt)):
image_name = dict_input_txt[index_image][0] + '.jpg'
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
im = im.convert('RGB')
im_width, im_height = im.size
# layout: category_id | xmin | ymin | xmax | ymax | iscrowd
bbox_labels = []
for index_box in range(len(dict_input_txt[index_image])):
if index_box >= 2:
bbox_sample = []
temp_info_box = dict_input_txt[index_image][
index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
xmax = xmin + w
ymax = ymin + h
bbox_sample.append(float(xmin) / im_width)
bbox_sample.append(float(ymin) / im_height)
bbox_sample.append(float(xmax) / im_width)
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
boxes = sample_labels[:, 0:4]
lbls = [1] * len(boxes)
difficults = [1] * len(boxes)
yield im, boxes, lbls, difficults
return reader
def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def test(settings, file_list):
return pyramidbox(settings, file_list, 'test', False)
import os import os
import numpy as np import numpy as np
import time
import argparse import argparse
import functools import functools
import reader
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from pyramidbox import PyramidBox from pyramidbox import PyramidBox
...@@ -10,53 +12,127 @@ from utility import add_arguments, print_arguments ...@@ -10,53 +12,127 @@ from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.") add_arg('parallel', bool, True, "parallel")
add_arg('use_gpu', bool, True, "Whether use GPU.") add_arg('learning_rate', float, 0.0001, "Learning rate.")
add_arg('parallel', bool, True, "Parallel.") add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('pretrained_model', str, "./vgg_model/", "The init model path.") add_arg('num_passes', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('dataset', str, 'WIDERFACE', "coco2014, coco2017, and pascalvoc.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './vgg_model/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image height.")
#yapf: enable #yapf: enable
def train(args,
learning_rate,
batch_size,
pretrained_model):
network = PyramidBox([3, 640, 640]) def train(args, data_args, learning_rate, batch_size, pretrained_model,
face_loss, head_loss = network.train() num_passes):
loss = face_loss + head_loss
num_classes = 2
devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
devices_num = len(devices.split(","))
image_shape = [3, data_args.resize_h, data_args.resize_w]
test_program, face_map_eval, head_map_eval = network.test() network = PyramidBox(image_shape)
loss = network.vgg_ssd(num_classes, image_shape)
epocs = 12880 / batch_size
boundaries = [epocs * 100, epocs * 125, epocs * 150]
values = [
learning_rate, learning_rate * 0.1, learning_rate * 0.01,
learning_rate * 0.001
]
epocs = 19200 / batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
lr = learning_rate
values = [lr, lr * 0.5, lr * 0.25, lr * 0.1, lr * 0.01]
optimizer = fluid.optimizer.RMSProp( optimizer = fluid.optimizer.RMSProp(
learning_rate=fluid.layers.piecewise_decay(boundaries, values), learning_rate=fluid.layers.piecewise_decay(boundaries, values),
regularization=fluid.regularizer.L2Decay(0.00005), ) regularization=fluid.regularizer.L2Decay(0.0005),
)
optimizer.minimize(loss) optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
# fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe) # fluid.io.save_inference_model('./vgg_model/', ['image'], [loss], exe)
if pretrained_model: if pretrained_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_model, var.name))
print('Load pre-trained model.')
fluid.io.load_vars(exe, pretrained_model, predicate=if_exist) fluid.io.load_vars(exe, pretrained_model, predicate=if_exist)
# print(fluid.default_main_program()) if args.parallel:
# print(test_program) train_exe = fluid.ParallelExecutor(
# fluid.io.save_persistables(exe, "model") use_cuda=args.use_gpu, loss_name=loss.name)
train_reader = paddle.batch(
reader.train(data_args, train_file_list), batch_size=batch_size)
feeder = fluid.DataFeeder(
place=place,
feed_list=[
network.image, network.gt_box, network.gt_label, network.difficult
])
def save_model(postfix):
model_path = os.path.join(model_save_dir, postfix)
if os.path.isdir(model_path):
shutil.rmtree(model_path)
print 'save models to %s' % (model_path)
fluid.io.save_persistables(exe, model_path)
best_map = 0.
for pass_id in range(num_passes):
start_time = time.time()
prev_start_time = start_time
end_time = 0
for batch_id, data in enumerate(train_reader()):
prev_start_time = start_time
start_time = time.time()
if len(data) < devices_num: continue
if args.parallel:
loss_v, = train_exe.run(fetch_list=[loss.name],
feed=feeder.feed(data))
else:
loss_v, = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[loss])
end_time = time.time()
loss_v = np.mean(np.array(loss_v))
if batch_id % 1 == 0:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, loss_v, start_time - prev_start_time))
test(pass_id, best_map)
if pass_id % 10 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id))
print("Best test map {0}".format(best_map))
if __name__ == '__main__': if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
print_arguments(args) print_arguments(args)
train(args, data_dir = 'data/WIDERFACE/WIDER_train/images/'
learning_rate=0.01, train_file_list = 'label/train_gt_widerface.res'
batch_size=args.batch_size, val_file_list = 'label/val_gt_widerface.res'
pretrained_model=args.pretrained_model) model_save_dir = args.model_save_dir
data_args = reader.Settings(
dataset=args.dataset,
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
ap_version='11point')
train(
args,
data_args=data_args,
learning_rate=0.01,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
num_passes=args.num_passes)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册