Unverified commit 09452a90, authored by qingqing01 and committed by GitHub

Data anchor sampling. (#998)

* Add data anchor sampling in PyramidBox paper.
Parent e15da197
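Background for reviewers unfamiliar with the paper: data-anchor-sampling crops and rescales a training image so that one randomly chosen face ends up close to one of the detector's anchor sizes (16-512 px), which over-samples small faces. A rough, self-contained sketch of the scale-selection idea (illustrative only; the helper name choose_anchor_scale and its exact sampling rule are not part of this commit):

import math
import random

ANCHOR_SCALES = [16, 32, 64, 128, 256, 512]


def choose_anchor_scale(face_w, face_h, scales=ANCHOR_SCALES):
    # Assign the face to the nearest anchor scale by its size.
    size = math.sqrt(face_w * face_h)
    nearest = min(range(len(scales)), key=lambda i: abs(scales[i] - size))
    # Pick a target anchor no larger than one step above the nearest one,
    # which biases the resized crop towards making faces smaller.
    target = random.randint(0, min(nearest + 1, len(scales) - 1))
    # Jitter the chosen scale within [scale / 2, 2 * scale].
    return random.uniform(scales[target] / 2.0, scales[target] * 2.0)


if __name__ == '__main__':
    print(choose_anchor_scale(28, 35))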
@@ -100,6 +100,76 @@ def generate_sample(sampler, image_width, image_height):
    return sampled_bbox
def data_anchor_sampling(sampler, bbox_labels, image_width, image_height,
scale_array, resize_width, resize_height):
num_gt = len(bbox_labels)
# np.random.randint range: [low, high)
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
wid = image_width * (norm_xmax - norm_xmin)
hei = image_height * (norm_ymax - norm_ymin)
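        # Find which interval of the anchor scales the chosen face's area
        # falls into; range_size indexes the interval's upper scale.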
range_size = 0
for scale_ind in range(0, len(scale_array) - 1):
area = wid * hei
if area > scale_array[scale_ind] ** 2 and area < \
scale_array[scale_ind + 1] ** 2:
range_size = scale_ind + 1
break
scale_choose = 0.0
if range_size == 0:
rand_idx_size = range_size + 1
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size)
rand_idx_size = rng_rand_size % range_size
scale_choose = random.uniform(scale_array[rand_idx_size] / 2.0,
2.0 * scale_array[rand_idx_size])
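        # Side length of the square crop in the original image; after the crop
        # is resized to resize_width, the chosen face is roughly scale_choose
        # pixels wide.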
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
h_off_orig = 0.0
if sample_bbox_size < max(image_height, image_width):
if wid <= sample_bbox_size:
w_off_orig = random.uniform(xmin + wid - sample_bbox_size, xmin)
else:
w_off_orig = random.uniform(xmin, xmin + wid - sample_bbox_size)
if hei <= sample_bbox_size:
h_off_orig = random.uniform(ymin + hei - sample_bbox_size, ymin)
else:
h_off_orig = random.uniform(ymin, ymin + hei - sample_bbox_size)
else:
w_off_orig = random.uniform(image_width - sample_bbox_size, 0.0)
h_off_orig = random.uniform(image_height - sample_bbox_size, 0.0)
w_off_orig = math.floor(w_off_orig)
h_off_orig = math.floor(h_off_orig)
# Figure out top left coordinates.
w_off = 0.0
h_off = 0.0
w_off = float(w_off_orig / image_width)
h_off = float(h_off_orig / image_height)
sampled_bbox = bbox(w_off, h_off,
w_off + float(sample_bbox_size / image_width),
h_off + float(sample_bbox_size / image_height))
return sampled_bbox
def jaccard_overlap(sample_bbox, object_bbox):
    if sample_bbox.xmin >= object_bbox.xmax or \
            sample_bbox.xmax <= object_bbox.xmin or \
@@ -161,8 +231,6 @@ def satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
def generate_batch_samples(batch_sampler, bbox_labels, image_width,
                           image_height):
    sampled_bbox = []
index = []
c = 0
    for sampler in batch_sampler:
        found = 0
        for i in range(sampler.max_trial):
@@ -172,8 +240,24 @@ def generate_batch_samples(batch_sampler, bbox_labels, image_width,
            if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
                sampled_bbox.append(sample_bbox)
                found = found + 1
    return sampled_bbox
c = c + 1
def generate_batch_random_samples(batch_sampler, bbox_labels, image_width,
image_height, scale_array, resize_width,
resize_height):
sampled_bbox = []
for sampler in batch_sampler:
found = 0
for i in range(sampler.max_trial):
if found >= sampler.max_sample:
break
sample_bbox = data_anchor_sampling(
sampler, bbox_labels, image_width, image_height, scale_array,
resize_width, resize_height)
if satisfy_sample_constraint(sampler, sample_bbox, bbox_labels):
sampled_bbox.append(sample_bbox)
found = found + 1
    return sampled_bbox

@@ -243,42 +327,79 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
    xmax = int(sample_bbox.xmax * image_width)
    ymin = int(sample_bbox.ymin * image_height)
    ymax = int(sample_bbox.ymax * image_height)
    sample_img = img[ymin:ymax, xmin:xmax]
    sample_labels = transform_labels(bbox_labels, sample_bbox)
    return sample_img, sample_labels
def crop_image_sampling(img, bbox_labels, sample_bbox, image_width,
image_height, resize_width, resize_height):
# no clipping here
xmin = int(sample_bbox.xmin * image_width)
xmax = int(sample_bbox.xmax * image_width)
ymin = int(sample_bbox.ymin * image_height)
ymax = int(sample_bbox.ymax * image_height)
w_off = xmin
h_off = ymin
width = xmax - xmin
height = ymax - ymin
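    # The sampled bbox can extend beyond the image borders, so compute its
    # intersection with the image (cross_*) and where that intersection lands
    # inside the zero-padded output crop (roi_*).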
cross_xmin = max(0.0, float(w_off))
cross_ymin = max(0.0, float(h_off))
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
cross_width = cross_xmax - cross_xmin
cross_height = cross_ymax - cross_ymin
roi_xmin = 0 if w_off >= 0 else abs(w_off)
roi_ymin = 0 if h_off >= 0 else abs(h_off)
roi_width = cross_width
roi_height = cross_height
    # NumPy images are indexed [row (y), col (x)] and slice bounds must be
    # integers, so build the padded crop in (height, width, channel) order.
    sample_img = np.zeros((height, width, 3))
    sample_img[int(roi_ymin):int(roi_ymin + roi_height),
               int(roi_xmin):int(roi_xmin + roi_width)] = \
        img[int(cross_ymin):int(cross_ymin + cross_height),
            int(cross_xmin):int(cross_xmin + cross_width)]
sample_img = cv2.resize(
sample_img, (resize_width, resize_height), interpolation=cv2.INTER_AREA)
sample_labels = transform_labels(bbox_labels, sample_bbox)
return sample_img, sample_labels
def random_brightness(img, settings):
    prob = random.uniform(0, 1)
    if prob < settings.brightness_prob:
        delta = random.uniform(-settings.brightness_delta,
                               settings.brightness_delta) + 1
        img = ImageEnhance.Brightness(img).enhance(delta)
    return img

def random_contrast(img, settings):
    prob = random.uniform(0, 1)
    if prob < settings.contrast_prob:
        delta = random.uniform(-settings.contrast_delta,
                               settings.contrast_delta) + 1
        img = ImageEnhance.Contrast(img).enhance(delta)
    return img

def random_saturation(img, settings):
    prob = random.uniform(0, 1)
    if prob < settings.saturation_prob:
        delta = random.uniform(-settings.saturation_delta,
                               settings.saturation_delta) + 1
        img = ImageEnhance.Color(img).enhance(delta)
    return img

def random_hue(img, settings):
    prob = random.uniform(0, 1)
    if prob < settings.hue_prob:
        delta = random.uniform(-settings.hue_delta, settings.hue_delta)
        img_hsv = np.array(img.convert('HSV'))
        img_hsv[:, :, 0] = img_hsv[:, :, 0] + delta
        img = Image.fromarray(img_hsv, mode='HSV').convert('RGB')
@@ -303,9 +424,9 @@ def distort_image(img, settings):
def expand_image(img, bbox_labels, img_width, img_height, settings):
    prob = random.uniform(0, 1)
    if prob < settings.expand_prob:
        if settings.expand_max_ratio - 1 >= 0.01:
            expand_ratio = random.uniform(1, settings.expand_max_ratio)
            height = int(img_height * expand_ratio)
            width = int(img_width * expand_ratio)
            h_off = math.floor(random.uniform(0, height - img_height))
@@ -314,7 +435,7 @@ def expand_image(img, bbox_labels, img_width, img_height, settings):
                               (width - w_off) / img_width,
                               (height - h_off) / img_height)
            expand_img = np.ones((height, width, 3))
            expand_img = np.uint8(expand_img * np.squeeze(settings.img_mean))
            expand_img = Image.fromarray(expand_img)
            expand_img.paste(img, (int(w_off), int(h_off)))
            bbox_labels = transform_labels(bbox_labels, expand_bbox)
......
@@ -22,6 +22,7 @@ import xml.etree.ElementTree
import os
import time
import copy
import random
class Settings(object):
@@ -36,101 +37,80 @@ class Settings(object):
                 apply_expand=True,
                 ap_version='11point',
                 toy=0):
        self.dataset = dataset
        self.ap_version = ap_version
        self.toy = toy
        self.data_dir = data_dir
        self.apply_distort = apply_distort
        self.apply_expand = apply_expand
        self.resize_height = resize_h
        self.resize_width = resize_w
        self.img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
            'float32')
        self.expand_prob = 0.5
        self.expand_max_ratio = 4
        self.hue_prob = 0.5
        self.hue_delta = 18
        self.contrast_prob = 0.5
        self.contrast_delta = 0.5
        self.saturation_prob = 0.5
        self.saturation_delta = 0.5
        self.brightness_prob = 0.5
        # brightness_delta is the brightness delta normalized by 256,
        # i.e. 32 / 256 = 0.125
        self.brightness_delta = 0.125
        self.scale = 0.007843  # 1 / 127.5
        self.data_anchor_sampling_prob = 0.5
def dataset(self):
return self._dataset
@property
def ap_version(self):
return self._ap_version
@property
def toy(self):
return self._toy
@property
def apply_expand(self):
return self._apply_expand
@property
def apply_distort(self):
return self._apply_distort
@property
def data_dir(self):
return self._data_dir
@data_dir.setter
def data_dir(self, data_dir):
self._data_dir = data_dir
@property
def label_list(self):
return self._label_list
@property
def resize_h(self):
return self._resize_height
@property
def resize_w(self):
return self._resize_width
@property
def img_mean(self):
return self._img_mean
def preprocess(img, bbox_labels, mode, settings):
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings.apply_distort:
            img = image_util.distort_image(img, settings)
        if settings.apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
        # sampling
        batch_sampler = []
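        # Randomly choose between PyramidBox data-anchor-sampling and the
        # original SSD-style batch samplers defined below.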
prob = random.uniform(0., 1.)
if prob > settings.data_anchor_sampling_prob:
scale_array = np.array([16, 32, 64, 128, 256, 512])
batch_sampler.append(
image_util.sampler(1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2,
0.0, True))
sampled_bbox = image_util.generate_batch_random_samples(
batch_sampler, bbox_labels, img_width, img_height, scale_array,
settings.resize_width, settings.resize_height)
img = np.array(img)
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sampled_labels = image_util.crop_image_sampling(
img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height)
img = Image.fromarray(img)
else:
            # hard-code here
            batch_sampler.append(
                image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            batch_sampler.append(
                image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            batch_sampler.append(
                image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            batch_sampler.append(
                image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            batch_sampler.append(
                image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            sampled_bbox = image_util.generate_batch_samples(
                batch_sampler, bbox_labels, img_width, img_height)
@@ -141,7 +121,9 @@ def preprocess(img, bbox_labels, mode, settings):
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height)
                img = Image.fromarray(img)
img = img.resize((settings.resize_w, settings.resize_h), Image.ANTIALIAS)
img = img.resize((settings.resize_width, settings.resize_height),
Image.ANTIALIAS)
    img = np.array(img)
    if mode == 'train':
@@ -160,7 +142,7 @@ def preprocess(img, bbox_labels, mode, settings):
    img = img[[2, 1, 0], :, :]
    img = img.astype('float32')
    img -= settings.img_mean
    img = img * settings.scale
    return img, sampled_labels
@@ -180,7 +162,6 @@ def put_txt_in_dict(input_txt):
        dict_input_txt[num_class].append(tmp_line_txt)
        if '--' not in tmp_line_txt:
            if len(tmp_line_txt) > 6:
# tmp_line_txt = tmp_line_txt[:-2]
                split_str = tmp_line_txt.split(' ')
                x1_min = float(split_str[0])
                y1_min = float(split_str[1])
@@ -288,8 +269,8 @@ def infer(settings, image_path):
        if img.mode == 'L':
            img = img.convert('RGB')
        im_width, im_height = img.size
        if settings.resize_width and settings.resize_height:
            img = img.resize((settings.resize_width, settings.resize_height),
                             Image.ANTIALIAS)
        img = np.array(img)
        # HWC to CHW
@@ -300,9 +281,7 @@ def infer(settings, image_path):
        img = img[[2, 1, 0], :, :]
        img = img.astype('float32')
        img -= settings.img_mean
        img = img * settings.scale
        return np.array([img])
img = np.array(img)
return img
    return batch_reader
@@ -21,28 +21,35 @@ add_arg('batch_size', int, 12, "Minibatch size.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('model_save_dir', str, 'output', "The path to save model.")
add_arg('model_save_dir', str, 'model', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image width.")
#yapf: enable


def train(args, config, train_file_list, optimizer_method):
    learning_rate = args.learning_rate
batch_size = args.batch_size
num_passes = args.num_passes
height = args.resize_h
width = args.resize_w
use_gpu = args.use_gpu
use_pyramidbox = args.use_pyramidbox
model_save_dir = args.model_save_dir
pretrained_model = args.pretrained_model
    num_classes = 2
image_shape = [3, height, width]
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
image_shape = [3, data_args.resize_h, data_args.resize_w]
    fetches = []
    network = PyramidBox(image_shape, num_classes,
                         sub_network=use_pyramidbox)
    if use_pyramidbox:
        face_loss, head_loss, loss = network.train()
        fetches = [face_loss, head_loss]
    else:
@@ -70,9 +77,9 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
    )
    optimizer.minimize(loss)
    #fluid.memory_optimize(fluid.default_main_program())
    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
@@ -80,7 +87,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
    if pretrained_model:
        if pretrained_model.isdigit():
            start_pass = int(pretrained_model) + 1
            pretrained_model = os.path.join(model_save_dir, pretrained_model)
            print("Resume from %s " %(pretrained_model))
        if not os.path.exists(pretrained_model):
@@ -92,10 +99,10 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
    if args.parallel:
        train_exe = fluid.ParallelExecutor(
            use_cuda=use_gpu, loss_name=loss.name)
    train_reader = paddle.batch(
        reader.train(config, train_file_list), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=network.feeds())

    def save_model(postfix):
@@ -143,22 +150,12 @@ if __name__ == '__main__':
    data_dir = 'data/WIDERFACE/WIDER_train/images/'
    train_file_list = 'label/train_gt_widerface.res'
val_file_list = 'label/val_gt_widerface.res'
model_save_dir = args.model_save_dir
    config = reader.Settings(
dataset=args.dataset,
        data_dir=data_dir,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        apply_expand=False,
        mean_value=[104., 117., 123],
        ap_version='11point')
    train(args, config, train_file_list, optimizer_method="momentum")
args,
data_args=data_args,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
num_passes=args.num_passes,
optimizer_method="momentum")