Unverified commit 97646db9, authored by wangna11BD and committed by GitHub

fix train bug (#742)

* fix train bug

* fix aotgan

* fix aotgan

* fix use_shared_memory

* fix tipc log

* fix iters_per_epoch less than 1

* add TODO

* fix benchmark performance decline of more than 5%
Parent 9cdb6039
@@ -44,13 +44,13 @@ optimizer:
   optimG:
     name: Adam
     net_names:
-      - net_gen
+      - netG
     beta1: 0.5
     beta2: 0.999
   optimD:
     name: Adam
     net_names:
-      - net_des
+      - netD
     beta1: 0.5
     beta2: 0.999
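The renamed entries have to match the keys the model registers in self.nets (the AOTGANModel change further below renames them the same way). A minimal, hypothetical sketch of that coupling; the helper below is illustrative, not PaddleGAN's actual optimizer-building code:

def collect_parameters(nets, net_names):
    # Illustrative only: optimizer `net_names` entries are treated as keys
    # into the model's `nets` dict, so the config and the model must agree.
    params = []
    for name in net_names:
        if name not in nets:
            raise KeyError(f"net_names entry '{name}' not found in model.nets "
                           f"(available: {list(nets)})")
        params.extend(nets[name].parameters())
    return params

# With the model now registering 'netG' and 'netD', a config that still
# listed 'net_gen'/'net_des' would fail this lookup.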
@@ -165,7 +165,8 @@ dataset:
     # data loader
     use_shuffle: true
-    num_workers: 4
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     prefetch_mode: ~
@@ -180,7 +181,8 @@ dataset:
     mean: [0.5, 0.5, 0.5]
     std: [0.5, 0.5, 0.5]
     scale: 1
-    num_workers: 4
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 8
     phase: val
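For context on the num_workers changes in this commit: with paddle.io.DataLoader, num_workers: 0 loads batches in the main process instead of spawning worker subprocesses, so no shared memory is used for inter-process transfer and the per-worker memory overhead disappears, at the cost of loading throughput. A minimal sketch, using a stand-in dataset rather than anything from this repo:

import numpy as np
from paddle.io import Dataset, DataLoader

class RandomPairs(Dataset):
    # Stand-in dataset: eight random image/label pairs.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return np.random.rand(3, 64, 64).astype("float32"), np.int64(idx % 2)

# num_workers=0 keeps loading in the main process; use_shared_memory only
# has an effect when worker processes exist.
loader = DataLoader(RandomPairs(), batch_size=2, num_workers=0,
                    use_shared_memory=False, shuffle=False)

for img, label in loader:
    print(img.shape, label.numpy())   # [2, 3, 64, 64] per batch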
@@ -17,7 +17,8 @@ model:
 dataset:
   train:
     name: InvDNDataset
-    num_workers: 10
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 14 # 4 GPUs
     opt:
       phase: train
@@ -26,7 +27,8 @@ dataset:
       train_dir: data/SIDD_Medium_Srgb_Patches_512/train/
   test:
     name: InvDNDataset
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     opt:
       phase: test
-total_iters: 3200000
+total_iters: 400000
 output_dir: output_dir
 model:
@@ -17,14 +17,16 @@ dataset:
   train:
     name: NAFNetTrain
     rgb_dir: data/SIDD/train
-    num_workers: 16
-    batch_size: 8 # 1GPU
+    # TODO fix out of memory for val while training
+    num_workers: 0
+    batch_size: 8 # 8GPU
    img_options:
      patch_size: 256
   test:
     name: NAFNetVal
     rgb_dir: data/SIDD/val
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     img_options:
       patch_size: 256
@@ -34,10 +36,10 @@ export_model:
 lr_scheduler:
   name: CosineAnnealingRestartLR
-  learning_rate: !!float 125e-6 # num_gpu * 0.000125
-  periods: [3200000]
+  learning_rate: 0.001
+  periods: [400000]
   restart_weights: [1]
-  eta_min: !!float 1e-7
+  eta_min: !!float 8e-7
 validate:
   interval: 5000
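Read together, the NAFNet numbers look like a switch from a 1-GPU to an 8-GPU schedule: 0.000125 × 8 = 0.001 for the learning rate, 1e-7 × 8 = 8e-7 for eta_min, and 3,200,000 / 8 = 400,000 for total_iters and the cosine restart period, which also matches the batch_size comment changing from "1GPU" to "8GPU". This is an inference from the old "num_gpu * 0.000125" comment and the arithmetic, not something the commit message states.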
-total_iters: 6400000
+total_iters: 420000
 output_dir: output_dir
 model:
@@ -20,8 +20,9 @@ model:
 dataset:
   train:
     name: SwinIRDataset
-    num_workers: 8
-    batch_size: 2 # 1GPU
+    # TODO fix out of memory for val while training
+    num_workers: 0
+    batch_size: 2 # 4GPU
     opt:
       phase: train
       n_channels: 3
@@ -31,7 +32,8 @@ dataset:
       dataroot_H: data/trainsets/trainH
   test:
     name: SwinIRDataset
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     opt:
       phase: test
@@ -46,8 +48,8 @@ export_model:
 lr_scheduler:
   name: MultiStepDecay
-  learning_rate: 5e-5 # num_gpu * 5e-5
-  milestones: [3200000, 4800000, 5600000, 6000000, 6400000]
+  learning_rate: 2e-4
+  milestones: [210000, 305000, 345000, 385000, 420000]
   gamma: 0.5
 validate:
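Similarly for SwinIR, the old "num_gpu * 5e-5" comment suggests the new learning_rate of 2e-4 corresponds to 4 GPUs (4 × 5e-5 = 2e-4), matching the batch_size comment changing to "4GPU"; total_iters drops from 6,400,000 to 420,000 and the MultiStepDecay milestones are rescaled so the last one lands on the new total.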
@@ -216,7 +216,7 @@ class Trainer:
         self.model.setup_train_mode(is_train=True)
         while self.current_iter < (self.total_iters + 1):
             self.current_epoch = iter_loader.epoch
-            self.inner_iter = self.current_iter % self.iters_per_epoch
+            self.inner_iter = self.current_iter % max(self.iters_per_epoch, 1)
             add_profiler_step(self.profiler_options)
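The max(..., 1) guard matters when the training set yields fewer samples than one batch, so iters_per_epoch can come out as 0 and the modulo raises ZeroDivisionError (the "fix iters_per_epoch less than 1" item above). A minimal sketch of the failure mode with made-up numbers; the integer division below is an assumption about how iters_per_epoch could end up at 0, not this repo's exact computation:

dataset_len, batch_size = 3, 8
iters_per_epoch = dataset_len // batch_size          # 0 when dataset < one batch

current_iter = 5
# inner_iter = current_iter % iters_per_epoch        # would raise ZeroDivisionError
inner_iter = current_iter % max(iters_per_epoch, 1)  # clamps the divisor to 1
print(inner_iter)                                    # 0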
@@ -91,8 +91,8 @@ class AOTGANModel(BaseModel):
         super(AOTGANModel, self).__init__()
         # define nets
-        self.nets['net_gen'] = build_generator(generator)
-        self.nets['net_des'] = build_discriminator(discriminator)
+        self.nets['netG'] = build_generator(generator)
+        self.nets['netD'] = build_discriminator(discriminator)
         self.net_vgg = build_criterion(criterion)
         self.adv_loss = Adversal()
@@ -111,9 +111,9 @@ class AOTGANModel(BaseModel):
     def forward(self):
         input_x = paddle.concat([self.img_masked, self.mask], 1)
-        self.pred_img = self.nets['net_gen'](input_x)
+        self.pred_img = self.nets['netG'](input_x)
         self.comp_img = (1 - self.mask) * self.img + self.mask * self.pred_img
-        self.visual_items['pred_img'] = self.pred_img
+        self.visual_items['pred_img'] = self.pred_img.detach()
     def train_iter(self, optimizers=None):
         self.forward()
@@ -121,7 +121,7 @@ class AOTGANModel(BaseModel):
         self.losses['l1'] = l1_loss * self.l1_weight
         self.losses['perceptual'] = perceptual_loss * self.perceptual_weight
         self.losses['style'] = style_loss * self.style_weight
-        dis_loss, gen_loss = self.adv_loss(self.nets['net_des'], self.comp_img, self.img, self.mask)
+        dis_loss, gen_loss = self.adv_loss(self.nets['netD'], self.comp_img, self.img, self.mask)
         self.losses['adv_g'] = gen_loss * self.adversal_weight
         loss_d_fake = dis_loss[0]
         loss_d_real = dis_loss[1]
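Storing pred_img.detach() as the visual keeps the saved tensor out of the autograd graph, so the graph built during forward can be freed after the backward pass instead of staying alive through the visualization dict. A minimal sketch of the difference, independent of this model:

import paddle

x = paddle.rand([1, 3, 8, 8])
x.stop_gradient = False
pred = x * 2.0                          # stand-in for a generator output

visuals = {}
visuals["kept"] = pred                  # still referenced by the autograd graph
visuals["detached"] = pred.detach()     # plain tensor, no graph reference

print(visuals["kept"].stop_gradient)        # False: gradients still tracked
print(visuals["detached"].stop_gradient)    # True: safe to hold for logging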
@@ -2,7 +2,7 @@ tqdm
 PyYAML>=5.1
 scikit-image>=0.14.0
 scipy>=1.1.0
-opencv-python==4.6.0.66
+opencv-python<=4.6.0.66
 imageio==2.9.0
 imageio-ffmpeg
 librosa==0.8.1
@@ -67,19 +67,6 @@ FILENAME=$new_filename
 # MODE must be one of ['benchmark_train']
 MODE=$2
 PARAMS=$3
-REST_ARGS=$4
-# for log name
-to_static=""
-# parse "to_static" options and modify trainer into "to_static_trainer"
-if [ $REST_ARGS = "to_static" ] || [ $PARAMS = "to_static" ] ;then
-    to_static="d2sT_"
-    sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME
-    # clear PARAM contents
-    if [ $PARAMS = "to_static" ] ;then
-        PARAMS=""
-    fi
-fi
 IFS=$'\n'
 # parser params from train_benchmark.txt
@@ -162,6 +149,14 @@ else
     device_num_list=($device_num)
 fi
+# for log name
+to_static=""
+# parse "to_static" options and modify trainer into "to_static_trainer"
+if [[ ${model_type} = "dynamicTostatic" ]];then
+    to_static="d2sT_"
+    sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME
+fi
 IFS="|"
 for batch_size in ${batch_size_list[*]}; do
     for precision in ${fp_items_list[*]}; do