diff --git a/train.py b/train.py index 3cc450305fb98838f4e5fd532534204348f64ad5..8c71b8ba6abadd5b31518dbd9869f89145f1a04d 100644 --- a/train.py +++ b/train.py @@ -101,7 +101,7 @@ def train(fintune_model,image_size,lr0,path_data,model_exp): if flag_change_lr_cnt > 30: init_lr = init_lr*0.1 - set_learning_rate(optimizer, init_lr) + set_learning_rate(optim, init_lr) flag_change_lr_cnt = 0 loss_mean = 0. # 损失均值