提交 770e7c8e 编写于 作者: Eric.Lee2021's avatar Eric.Lee2021 🚴🏻

update

上级 cdb27ee6
#coding:utf-8
import os
from yolov3 import Yolov3, Yolov3Tiny
from utils.parse_config import parse_data_cfg
from utils.torch_utils import select_device
......@@ -6,16 +7,13 @@ import torch
from torch.utils.data import DataLoader
from utils.datasets import LoadImagesAndLabels
from utils.utils import *
import os
import numpy as np
def set_learning_rate(optimizer, lr):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def train(data_cfg ='cfg/voc.data',
accumulate = 1):
device = select_device()
def train(data_cfg ='cfg/voc.data',accumulate = 1):
# Configure run
get_data_cfg = parse_data_cfg(data_cfg)#返回训练配置参数,类型:字典
......@@ -33,6 +31,9 @@ def train(data_cfg ='cfg/voc.data',
lr_step = str(get_data_cfg['lr_step'])
lr0 = float(get_data_cfg['lr0'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
device = select_device()
if multi_scale == 'True':
multi_scale = True
else:
......@@ -201,11 +202,11 @@ def train(data_cfg ='cfg/voc.data',
#-------------------------------------------------------------------------------
if __name__ == '__main__':
# train(data_cfg="cfg/hand.data")
# train(data_cfg = "cfg/face.data")
# train(data_cfg = "cfg/person.data")
train(data_cfg = "cfg/helmet.data")
# train(data_cfg = "cfg/helmet.data")
train(data_cfg = "cfg/transport.data")
print('well done ~ ')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册