From b6e495172d5ec5bf1fc1b694973dd0e8dc976a97 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 29 Jul 2020 11:50:20 +0800 Subject: [PATCH] Update dataloader.py --- utils/dataloader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/dataloader.py b/utils/dataloader.py index 100ba37..7e082de 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -104,8 +104,9 @@ class SSDDataset(Dataset): n = self.train_batches temp_index = index % n while True: - img, y = self.get_random_data(lines[temp_index], self.image_size[0:2]) + img, y = self.get_random_data(lines[index], self.image_size[0:2]) if len(y)==0: + index = (index + 1) % n continue boxes = np.array(y[:,:4],dtype=np.float32) boxes[:,0] = boxes[:,0]/self.image_size[1] @@ -114,9 +115,9 @@ class SSDDataset(Dataset): boxes[:,3] = boxes[:,3]/self.image_size[0] boxes = np.maximum(np.minimum(boxes,1),0) if ((boxes[:,3]-boxes[:,1])<=0).any() and ((boxes[:,2]-boxes[:,0])<=0).any(): + index = (index + 1) % n continue y = np.concatenate([boxes,y[:,-1:]],axis=-1) - temp_index = (temp_index + 1) % n break img = np.array(img, dtype=np.float32) -- GitLab