Unverified commit 97646db9, authored by W wangna11BD, committed by GitHub

fix train bug (#742)

* fix train bug

* fix aotgan

* fix aotgan

* fix use_shared_memory

* fix tipc log

* fix iters_per_epoch less than 1

* add TODO

* fix benchmark performance decline of more than 5%
Parent 9cdb6039
@@ -44,13 +44,13 @@ optimizer:
   optimG:
     name: Adam
     net_names:
-      - net_gen
+      - netG
     beta1: 0.5
     beta2: 0.999
   optimD:
     name: Adam
     net_names:
-      - net_des
+      - netD
     beta1: 0.5
     beta2: 0.999
...
@@ -165,7 +165,8 @@ dataset:
     # data loader
     use_shuffle: true
-    num_workers: 4
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     prefetch_mode: ~
@@ -180,7 +181,8 @@ dataset:
     mean: [0.5, 0.5, 0.5]
     std: [0.5, 0.5, 0.5]
     scale: 1
-    num_workers: 4
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 8
     phase: val
...
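
These num_workers changes (here and in the configs below) all apply the workaround named in the TODO: with num_workers: 0, paddle.io.DataLoader loads samples in the main process, so no loader subprocesses or shared-memory buffers are created, which sidesteps the out-of-memory failure hit when validating during training. A minimal sketch of the knob, using a toy dataset that is not part of this PR:

    import paddle
    from paddle.io import Dataset, DataLoader

    class ToyDataset(Dataset):
        """Tiny stand-in dataset; any paddle.io.Dataset behaves the same way."""
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return paddle.ones([3, 32, 32]), idx

    # num_workers=0: batches are produced in the main process, so no worker
    # subprocesses and no shared-memory buffers are involved.
    loader = DataLoader(ToyDataset(), batch_size=1, num_workers=0)

    # num_workers=4 would instead spawn 4 loader processes; use_shared_memory
    # (True by default) controls whether they hand batches over via shared memory.
    for img, idx in loader:
        pass
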
@@ -17,7 +17,8 @@ model:
 dataset:
   train:
     name: InvDNDataset
-    num_workers: 10
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 14 # 4 GPUs
     opt:
       phase: train
@@ -26,7 +27,8 @@ dataset:
       train_dir: data/SIDD_Medium_Srgb_Patches_512/train/
   test:
     name: InvDNDataset
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     opt:
       phase: test
...
-total_iters: 3200000
+total_iters: 400000
 output_dir: output_dir
 model:
@@ -17,14 +17,16 @@ dataset:
   train:
     name: NAFNetTrain
     rgb_dir: data/SIDD/train
-    num_workers: 16
-    batch_size: 8 # 1GPU
+    # TODO fix out of memory for val while training
+    num_workers: 0
+    batch_size: 8 # 8GPU
     img_options:
       patch_size: 256
   test:
     name: NAFNetVal
     rgb_dir: data/SIDD/val
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     img_options:
       patch_size: 256
@@ -34,10 +36,10 @@ export_model:
 lr_scheduler:
   name: CosineAnnealingRestartLR
-  learning_rate: !!float 125e-6 # num_gpu * 0.000125
-  periods: [3200000]
+  learning_rate: 0.001
+  periods: [400000]
   restart_weights: [1]
-  eta_min: !!float 1e-7
+  eta_min: !!float 8e-7
 validate:
   interval: 5000
...
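
A note on the "fix benchmark performance decline of more than 5%" item: the new NAFNet values read as the original 8-GPU schedule rather than the single-GPU rescaling that was there before (the batch_size comment changes from 1GPU to 8GPU). Under that assumption the numbers line up:

    learning rate:  8 GPUs x 0.000125 = 0.001            (matches the old "num_gpu * 0.000125" comment)
    total work:     8 GPUs x 400000 iters x batch 8  =  1 GPU x 3200000 iters x batch 8
                    (the cosine period is rescaled to 400000 accordingly)

This reading is inferred from the inline comments in the diff; the commit message does not spell it out.
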
-total_iters: 6400000
+total_iters: 420000
 output_dir: output_dir
 model:
@@ -20,8 +20,9 @@ model:
 dataset:
   train:
     name: SwinIRDataset
-    num_workers: 8
-    batch_size: 2 # 1GPU
+    # TODO fix out of memory for val while training
+    num_workers: 0
+    batch_size: 2 # 4GPU
     opt:
       phase: train
       n_channels: 3
@@ -31,7 +32,8 @@ dataset:
       dataroot_H: data/trainsets/trainH
   test:
     name: SwinIRDataset
-    num_workers: 1
+    # TODO fix out of memory for val while training
+    num_workers: 0
     batch_size: 1
     opt:
       phase: test
@@ -46,8 +48,8 @@ export_model:
 lr_scheduler:
   name: MultiStepDecay
-  learning_rate: 5e-5 # num_gpu * 5e-5
-  milestones: [3200000, 4800000, 5600000, 6000000, 6400000]
+  learning_rate: 2e-4
+  milestones: [210000, 305000, 345000, 385000, 420000]
   gamma: 0.5
 validate:
...
@@ -216,7 +216,7 @@ class Trainer:
         self.model.setup_train_mode(is_train=True)
         while self.current_iter < (self.total_iters + 1):
             self.current_epoch = iter_loader.epoch
-            self.inner_iter = self.current_iter % self.iters_per_epoch
+            self.inner_iter = self.current_iter % max(self.iters_per_epoch, 1)

             add_profiler_step(self.profiler_options)
...
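
The max(self.iters_per_epoch, 1) guard corresponds to the "fix iters_per_epoch less than 1" item: when the dataset yields fewer samples than one global batch, iters_per_epoch can come out as 0 and the modulo raises ZeroDivisionError on the first training iteration. A minimal standalone sketch of the failure and the guard (plain variables, not the Trainer class itself):

    # iters_per_epoch can be 0 for a dataset smaller than one global batch
    iters_per_epoch = 0
    current_iter = 1

    # unguarded version raises ZeroDivisionError:
    #   inner_iter = current_iter % iters_per_epoch
    inner_iter = current_iter % max(iters_per_epoch, 1)  # -> 0, training proceeds
    print(inner_iter)
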
@@ -91,8 +91,8 @@ class AOTGANModel(BaseModel):
         super(AOTGANModel, self).__init__()

         # define nets
-        self.nets['net_gen'] = build_generator(generator)
-        self.nets['net_des'] = build_discriminator(discriminator)
+        self.nets['netG'] = build_generator(generator)
+        self.nets['netD'] = build_discriminator(discriminator)
         self.net_vgg = build_criterion(criterion)

         self.adv_loss = Adversal()
@@ -111,9 +111,9 @@ class AOTGANModel(BaseModel):
     def forward(self):
         input_x = paddle.concat([self.img_masked, self.mask], 1)
-        self.pred_img = self.nets['net_gen'](input_x)
+        self.pred_img = self.nets['netG'](input_x)
         self.comp_img = (1 - self.mask) * self.img + self.mask * self.pred_img
-        self.visual_items['pred_img'] = self.pred_img
+        self.visual_items['pred_img'] = self.pred_img.detach()

     def train_iter(self, optimizers=None):
         self.forward()
@@ -121,7 +121,7 @@ class AOTGANModel(BaseModel):
         self.losses['l1'] = l1_loss * self.l1_weight
         self.losses['perceptual'] = perceptual_loss * self.perceptual_weight
         self.losses['style'] = style_loss * self.style_weight
-        dis_loss, gen_loss = self.adv_loss(self.nets['net_des'], self.comp_img, self.img, self.mask)
+        dis_loss, gen_loss = self.adv_loss(self.nets['netD'], self.comp_img, self.img, self.mask)
         self.losses['adv_g'] = gen_loss * self.adversal_weight
         loss_d_fake = dis_loss[0]
         loss_d_real = dis_loss[1]
...
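
The .detach() on the stored prediction is the other substantive change here: visual_items is used for visualization, and keeping the raw generator output in it keeps a reference to the output's autograd graph alive longer than needed. A minimal sketch of the difference, assuming nothing about AOTGANModel beyond what the diff shows:

    import paddle

    x = paddle.randn([1, 4])
    x.stop_gradient = False            # track gradients through x
    y = (x * 2).sum()                  # stand-in for a generator output

    visual_items = {}
    visual_items['raw'] = y            # still attached to the autograd graph
    visual_items['safe'] = y.detach()  # same values, cut off from the graph

    print(visual_items['raw'].stop_gradient)   # False
    print(visual_items['safe'].stop_gradient)  # True
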
@@ -2,7 +2,7 @@ tqdm
 PyYAML>=5.1
 scikit-image>=0.14.0
 scipy>=1.1.0
-opencv-python==4.6.0.66
+opencv-python<=4.6.0.66
 imageio==2.9.0
 imageio-ffmpeg
 librosa==0.8.1
...
@@ -67,19 +67,6 @@ FILENAME=$new_filename
 # MODE must be one of ['benchmark_train']
 MODE=$2
 PARAMS=$3
-REST_ARGS=$4
-
-# for log name
-to_static=""
-# parse "to_static" options and modify trainer into "to_static_trainer"
-if [ $REST_ARGS = "to_static" ] || [ $PARAMS = "to_static" ] ;then
-    to_static="d2sT_"
-    sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME
-    # clear PARAM contents
-    if [ $PARAMS = "to_static" ] ;then
-        PARAMS=""
-    fi
-fi

 IFS=$'\n'
 # parser params from train_benchmark.txt
@@ -162,6 +149,14 @@ else
     device_num_list=($device_num)
 fi

+# for log name
+to_static=""
+# parse "to_static" options and modify trainer into "to_static_trainer"
+if [[ ${model_type} = "dynamicTostatic" ]];then
+    to_static="d2sT_"
+    sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME
+fi
+
 IFS="|"
 for batch_size in ${batch_size_list[*]}; do
     for precision in ${fp_items_list[*]}; do
...