未验证 提交 eedf9486 编写于 作者: J JiaQi Xu 提交者: GitHub

Update ssd_training.py

上级 1e83aab6
......@@ -103,7 +103,7 @@ class MultiBoxLoss(nn.Module):
class Generator(object):
def __init__(self,batch_size,
train_lines, val_lines, image_size,num_classes,
train_lines, image_size,num_classes,
saturation_var=0.5,
brightness_var=0.5,
contrast_var=0.5,
......@@ -115,9 +115,7 @@ class Generator(object):
aspect_ratio_range=[3./4., 4./3.]):
self.batch_size = batch_size
self.train_lines = train_lines
self.val_lines = val_lines
self.train_batches = len(train_lines)
self.val_batches = len(val_lines)
self.image_size = image_size
self.color_jitter = []
self.num_classes = num_classes
......@@ -232,14 +230,10 @@ class Generator(object):
new_targets = np.asarray(new_targets).reshape(-1, targets.shape[1])
return img, new_targets
def generate(self, train=True):
def generate(self):
while True:
if train:
shuffle(self.train_lines)
lines = self.train_lines
else:
shuffle(self.val_lines)
lines = self.val_lines
shuffle(self.train_lines)
lines = self.train_lines
inputs = []
targets = []
for annotation_line in lines:
......@@ -259,19 +253,18 @@ class Generator(object):
boxes = np.maximum(np.minimum(boxes,1),0)
y = np.concatenate([boxes,y[:,-1:]],axis=-1)
if train and self.do_crop:
if self.do_crop:
img, y = self.random_sized_crop(img, y)
img = imresize(img, self.image_size).astype('float32')
if train:
shuffle(self.color_jitter)
for jitter in self.color_jitter:
img = jitter(img)
if self.lighting_std:
img = self.lighting(img)
if self.hflip_prob > 0:
img, y = self.horizontal_flip(img, y)
if self.vflip_prob > 0:
img, y = self.vertical_flip(img, y)
shuffle(self.color_jitter)
for jitter in self.color_jitter:
img = jitter(img)
if self.lighting_std:
img = self.lighting(img)
if self.hflip_prob > 0:
img, y = self.horizontal_flip(img, y)
if self.vflip_prob > 0:
img, y = self.vertical_flip(img, y)
if len(y)==0:
continue
......@@ -282,4 +275,4 @@ class Generator(object):
tmp_targets = np.array(targets)
inputs = []
targets = []
yield tmp_inp, tmp_targets
\ No newline at end of file
yield tmp_inp, tmp_targets
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册