From a4a0a30f10bbf2d1fa8ce2cb5e0b2aaaddf74cc2 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 11 Nov 2020 15:05:00 +0800 Subject: [PATCH] Update train_with_tensorboard.py --- train_with_tensorboard.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_with_tensorboard.py b/train_with_tensorboard.py index d85b4eb..dd1a732 100644 --- a/train_with_tensorboard.py +++ b/train_with_tensorboard.py @@ -207,9 +207,9 @@ if __name__ == "__main__": if Use_Data_Loader: train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic) val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False) - gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=4, pin_memory=True, + gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate) - gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=4,pin_memory=True, + gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4,pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate) else: gen = Generator(Batch_size, lines[:num_train], @@ -244,9 +244,9 @@ if __name__ == "__main__": if Use_Data_Loader: train_dataset = YoloDataset(lines[:num_train], (input_shape[0], input_shape[1]), mosaic=mosaic) val_dataset = YoloDataset(lines[num_train:], (input_shape[0], input_shape[1]), mosaic=False) - gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=4, pin_memory=True, + gen = DataLoader(train_dataset, shuffle=True, batch_size=Batch_size, num_workers=4, pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate) - gen_val = DataLoader(val_dataset, batch_size=Batch_size, num_workers=4,pin_memory=True, + gen_val = DataLoader(val_dataset, shuffle=True, batch_size=Batch_size, num_workers=4,pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate) else: gen = Generator(Batch_size, lines[:num_train], -- GitLab