diff --git a/train.py b/train.py index e43aa75ec9b9914c11f9ed3b0857b12ea2a6bea4..f45ed30dd4f6490af7e6419b540b54ccf7109b3e 100644 --- a/train.py +++ b/train.py @@ -166,7 +166,7 @@ if __name__ == "__main__": gen_val = Generator(Batch_size, lines[num_train:], (input_shape[0], input_shape[1])).generate(mosaic = False) - epoch_size = int(max(1, num_train//Batch_size//2.5)) if mosaic else max(1, num_train//Batch_size) + epoch_size = max(1, num_train//Batch_size) epoch_size_val = num_val//Batch_size #------------------------------------# # 冻结一定部分训练 @@ -195,7 +195,7 @@ if __name__ == "__main__": gen_val = Generator(Batch_size, lines[num_train:], (input_shape[0], input_shape[1])).generate(mosaic = False) - epoch_size = int(max(1, num_train//Batch_size//2.5)) if mosaic else max(1, num_train//Batch_size) + epoch_size = max(1, num_train//Batch_size) epoch_size_val = num_val//Batch_size #------------------------------------# # 解冻后训练