提交 770e7c8e 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

update

上级 cdb27ee6
#coding:utf-8 #coding:utf-8
import os
from yolov3 import Yolov3, Yolov3Tiny from yolov3 import Yolov3, Yolov3Tiny
from utils.parse_config import parse_data_cfg from utils.parse_config import parse_data_cfg
from utils.torch_utils import select_device from utils.torch_utils import select_device
...@@ -6,16 +7,13 @@ import torch ...@@ -6,16 +7,13 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from utils.datasets import LoadImagesAndLabels from utils.datasets import LoadImagesAndLabels
from utils.utils import * from utils.utils import *
import os
import numpy as np import numpy as np
def set_learning_rate(optimizer, lr): def set_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group['lr'] = lr
def train(data_cfg ='cfg/voc.data', def train(data_cfg ='cfg/voc.data',accumulate = 1):
accumulate = 1):
device = select_device()
# Configure run # Configure run
get_data_cfg = parse_data_cfg(data_cfg)#返回训练配置参数,类型:字典 get_data_cfg = parse_data_cfg(data_cfg)#返回训练配置参数,类型:字典
...@@ -33,6 +31,9 @@ def train(data_cfg ='cfg/voc.data', ...@@ -33,6 +31,9 @@ def train(data_cfg ='cfg/voc.data',
lr_step = str(get_data_cfg['lr_step']) lr_step = str(get_data_cfg['lr_step'])
lr0 = float(get_data_cfg['lr0']) lr0 = float(get_data_cfg['lr0'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
device = select_device()
if multi_scale == 'True': if multi_scale == 'True':
multi_scale = True multi_scale = True
else: else:
...@@ -201,11 +202,11 @@ def train(data_cfg ='cfg/voc.data', ...@@ -201,11 +202,11 @@ def train(data_cfg ='cfg/voc.data',
#------------------------------------------------------------------------------- #-------------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
# train(data_cfg="cfg/hand.data") # train(data_cfg="cfg/hand.data")
# train(data_cfg = "cfg/face.data") # train(data_cfg = "cfg/face.data")
# train(data_cfg = "cfg/person.data") # train(data_cfg = "cfg/person.data")
train(data_cfg = "cfg/helmet.data") # train(data_cfg = "cfg/helmet.data")
train(data_cfg = "cfg/transport.data")
print('well done ~ ') print('well done ~ ')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册