From 81f8a49d373bff4cc9b6f7853b4b9b3fcbe8f3bc Mon Sep 17 00:00:00 2001 From: "Eric.Lee2021" <305141918@qq.com> Date: Wed, 24 Feb 2021 00:11:01 +0800 Subject: [PATCH] update --- train.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/train.py b/train.py index 244e26c..3cc4503 100644 --- a/train.py +++ b/train.py @@ -25,14 +25,16 @@ def set_learning_rate(optimizer, lr): def train(fintune_model,image_size,lr0,path_data,model_exp): - # dataset + # config 训练配置 + max_epoch = 1000 n_classes = 19 n_img_per_gpu = 16 n_workers = 8 cropsize = [int(image_size*0.85),int(image_size*0.85)] + # DataLoader 数据迭代器 ds = FaceMask(path_data,img_size = image_size, cropsize=cropsize, mode='train') - # sampler = torch.utils.data.distributed.DistributedSampler(ds) + dl = DataLoader(ds, batch_size = n_img_per_gpu, shuffle = True, @@ -42,19 +44,19 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): # model ignore_idx = -100 - + # 构建模型 use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if use_cuda else "cpu") net = BiSeNet(n_classes=n_classes) net = net.to(device) - + # 加载预训练模型 if os.access(fintune_model,os.F_OK) and (fintune_model is not None):# checkpoint chkpt = torch.load(fintune_model, map_location=device) net.load_state_dict(chkpt) print('load fintune model : {}'.format(fintune_model)) else: print('no fintune model') - + # 构建损失函数 score_thres = 0.7 n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16 LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx) @@ -65,15 +67,14 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): momentum = 0.9 weight_decay = 5e-4 lr_start = lr0 - max_epoch = 1000 - + # 构建优化器 optim = Optimizer.SGD( net.parameters(), lr = lr_start, momentum = momentum, weight_decay = weight_decay) - ## train loop + # train loop msg_iter = 50 loss_avg = [] st = glob_st = time.time() @@ -85,7 +86,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): best_loss = np.inf loss_mean = 0. # 损失均值 loss_idx = 0. # 损失计算计数器 - + # 训练 print('start training ~') it = 0 for epoch in range(max_epoch): @@ -126,12 +127,10 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): loss.backward() optim.step() - if it % msg_iter == 0: print('epoch <{}/{}> -->> <{}/{}> -> iter {} : loss {:.5f}, loss_mean :{:.5f}, best_loss :{:.5f},lr :{:.6f},batch_size : {}'.\ format(epoch,max_epoch,i,int(ds.__len__()/n_img_per_gpu),it,loss.item(),loss_mean/loss_idx,best_loss,init_lr,n_img_per_gpu)) - # print(msg) if (it) % 500 == 0: state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict() @@ -140,7 +139,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): torch.save(state, model_exp+'fp_{}_epoch-{}.pth'.format(image_size,epoch)) if __name__ == "__main__": - image_size = 512 + image_size = 256 lr0 = 1e-4 model_exp = './model_exp/' path_data = './CelebAMask-HQ/' -- GitLab