From eedf94867406383bea97199d0572b630a0d71df2 Mon Sep 17 00:00:00 2001 From: JiaQi Xu <47347516+bubbliiiing@users.noreply.github.com> Date: Fri, 20 Mar 2020 19:20:43 +0800 Subject: [PATCH] Update ssd_training.py --- nets/ssd_training.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/nets/ssd_training.py b/nets/ssd_training.py index 684822d..be609b3 100644 --- a/nets/ssd_training.py +++ b/nets/ssd_training.py @@ -103,7 +103,7 @@ class MultiBoxLoss(nn.Module): class Generator(object): def __init__(self,batch_size, - train_lines, val_lines, image_size,num_classes, + train_lines, image_size,num_classes, saturation_var=0.5, brightness_var=0.5, contrast_var=0.5, @@ -115,9 +115,7 @@ class Generator(object): aspect_ratio_range=[3./4., 4./3.]): self.batch_size = batch_size self.train_lines = train_lines - self.val_lines = val_lines self.train_batches = len(train_lines) - self.val_batches = len(val_lines) self.image_size = image_size self.color_jitter = [] self.num_classes = num_classes @@ -232,14 +230,10 @@ class Generator(object): new_targets = np.asarray(new_targets).reshape(-1, targets.shape[1]) return img, new_targets - def generate(self, train=True): + def generate(self): while True: - if train: - shuffle(self.train_lines) - lines = self.train_lines - else: - shuffle(self.val_lines) - lines = self.val_lines + shuffle(self.train_lines) + lines = self.train_lines inputs = [] targets = [] for annotation_line in lines: @@ -259,19 +253,18 @@ class Generator(object): boxes = np.maximum(np.minimum(boxes,1),0) y = np.concatenate([boxes,y[:,-1:]],axis=-1) - if train and self.do_crop: + if self.do_crop: img, y = self.random_sized_crop(img, y) img = imresize(img, self.image_size).astype('float32') - if train: - shuffle(self.color_jitter) - for jitter in self.color_jitter: - img = jitter(img) - if self.lighting_std: - img = self.lighting(img) - if self.hflip_prob > 0: - img, y = self.horizontal_flip(img, y) - if self.vflip_prob > 0: - img, y = self.vertical_flip(img, y) + shuffle(self.color_jitter) + for jitter in self.color_jitter: + img = jitter(img) + if self.lighting_std: + img = self.lighting(img) + if self.hflip_prob > 0: + img, y = self.horizontal_flip(img, y) + if self.vflip_prob > 0: + img, y = self.vertical_flip(img, y) if len(y)==0: continue @@ -282,4 +275,4 @@ class Generator(object): tmp_targets = np.array(targets) inputs = [] targets = [] - yield tmp_inp, tmp_targets \ No newline at end of file + yield tmp_inp, tmp_targets -- GitLab