未验证 提交 f1adbbc5 编写于 作者: B Bubbliiiing 提交者: GitHub

Add files via upload

上级 2ca5a9e2
from nets.ssd import get_ssd
from nets.ssd_training import Generator,MultiBoxLoss
from torch.utils.data import DataLoader
from utils.dataloader import ssd_dataset_collate, SSDDataset
from utils.config import Config
from torchsummary import summary
from torch.autograd import Variable
......@@ -17,20 +19,26 @@ def adjust_learning_rate(optimizer, lr, gamma, step):
return lr
if __name__ == "__main__":
Batch_size = 4
# ------------------------------------#
# 先冻结一部分权重训练
# 后解冻全部权重训练
# 先大学习率
# 后小学习率
# ------------------------------------#
lr = 1e-4
lr = 1e-5
freeze_lr = 1e-5
Cuda = True
Start_iter = 0
Freeze_epoch = 25
Epoch = 50
Batch_size = 4
#-------------------------------#
# Dataloder的使用
#-------------------------------#
Use_Data_Loader = True
model = get_ssd("train",Config["num_classes"])
print('Loading weights into state dict...')
......@@ -56,8 +64,13 @@ if __name__ == "__main__":
np.random.seed(None)
num_train = len(lines)
gen = Generator(Batch_size, lines,
(Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate()
if Use_Data_Loader:
train_dataset = SSDDataset(lines[:num_train], (Config["min_dim"], Config["min_dim"]))
gen = DataLoader(train_dataset, batch_size=Batch_size, num_workers=8, pin_memory=True,
drop_last=True, collate_fn=ssd_dataset_collate)
else:
gen = Generator(Batch_size, lines,
(Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate()
criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5,
False, Cuda)
......@@ -80,7 +93,6 @@ if __name__ == "__main__":
for iteration, batch in enumerate(gen):
if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad():
if Cuda:
......@@ -105,7 +117,8 @@ if __name__ == "__main__":
print('\nEpoch:'+ str(epoch+1) + '/' + str(Freeze_epoch))
print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1),conf_loss/(iteration+1)), end=' ')
print('Saving state, iter:', str(epoch+1))
torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))
......@@ -125,7 +138,6 @@ if __name__ == "__main__":
for iteration, batch in enumerate(gen):
if iteration >= epoch_size:
break
start_time = time.time()
images, targets = batch[0], batch[1]
with torch.no_grad():
if Cuda:
......@@ -152,4 +164,4 @@ if __name__ == "__main__":
print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1),conf_loss/(iteration+1)), end=' ')
print('Saving state, iter:', str(epoch+1))
torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))
torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册