From 770e7c8ed0b81858c2a3381af9113bc865dd208f Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Wed, 24 Feb 2021 18:42:42 +0800 Subject: [PATCH] update --- train.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/train.py b/train.py index 7a59cf5..6e81b88 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,5 @@ #coding:utf-8 +import os from yolov3 import Yolov3, Yolov3Tiny from utils.parse_config import parse_data_cfg from utils.torch_utils import select_device @@ -6,16 +7,13 @@ import torch from torch.utils.data import DataLoader from utils.datasets import LoadImagesAndLabels from utils.utils import * -import os import numpy as np def set_learning_rate(optimizer, lr): for param_group in optimizer.param_groups: param_group['lr'] = lr -def train(data_cfg ='cfg/voc.data', - accumulate = 1): - device = select_device() +def train(data_cfg ='cfg/voc.data',accumulate = 1): # Configure run get_data_cfg = parse_data_cfg(data_cfg)#返回训练配置参数,类型:字典 @@ -33,6 +31,9 @@ def train(data_cfg ='cfg/voc.data', lr_step = str(get_data_cfg['lr_step']) lr0 = float(get_data_cfg['lr0']) + os.environ['CUDA_VISIBLE_DEVICES'] = gpus + device = select_device() + if multi_scale == 'True': multi_scale = True else: @@ -201,11 +202,11 @@ def train(data_cfg ='cfg/voc.data', #------------------------------------------------------------------------------- if __name__ == '__main__': - # train(data_cfg="cfg/hand.data") # train(data_cfg = "cfg/face.data") # train(data_cfg = "cfg/person.data") - train(data_cfg = "cfg/helmet.data") + # train(data_cfg = "cfg/helmet.data") + train(data_cfg = "cfg/transport.data") print('well done ~ ') -- GitLab