From a0a56e75e1d6b3a914502b1554c647c7d8386591 Mon Sep 17 00:00:00 2001 From: LielinJiang Date: Wed, 16 Sep 2020 12:05:30 +0000 Subject: [PATCH] rm unused code --- ppgan/models/srgan_model.py | 347 +----------------------------------- 1 file changed, 7 insertions(+), 340 deletions(-) diff --git a/ppgan/models/srgan_model.py b/ppgan/models/srgan_model.py index 762e44a..ae5a461 100644 --- a/ppgan/models/srgan_model.py +++ b/ppgan/models/srgan_model.py @@ -12,126 +12,24 @@ from .losses import GANLoss from .builder import MODELS # logger = logging.getLogger('base') + @MODELS.register() class SRGANModel(BaseModel): def __init__(self, cfg): super(SRGANModel, self).__init__(cfg) - # if opt['dist']: - # self.rank = torch.distributed.get_rank() - # else: - # self.rank = -1 # non dist training - # train_opt = opt['train'] - # define networks and load pretrained models + # define networks self.model_names = ['G'] - + self.netG = build_generator(cfg.model.generator) self.visual_names = ['LQ', 'GT', 'fake_H'] - # self.netG = networks.define_G(opt).to(self.device) - # if opt['dist']: - # self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) - # else: - # self.netG = DataParallel(self.netG) - if False:#self.is_train: - self.netD = build_discriminator(cfg.model.discriminator) - # if opt['dist']: - # self.netD = DistributedDataParallel(self.netD, - # device_ids=[torch.cuda.current_device()]) - # else: - # self.netD = DataParallel(self.netD) + # TODO: support srgan train. + if False: + # self.netD = build_discriminator(cfg.model.discriminator) self.netG.train() - self.netD.train() - - # define losses, optimizer and scheduler - # if self.is_train: - # pass - # G pixel loss - # if train_opt['pixel_weight'] > 0: - # l_pix_type = train_opt['pixel_criterion'] - # if l_pix_type == 'l1': - # self.cri_pix = nn.L1Loss().to(self.device) - # elif l_pix_type == 'l2': - # self.cri_pix = nn.MSELoss().to(self.device) - # else: - # raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) - # self.l_pix_w = train_opt['pixel_weight'] - # else: - # # logger.info('Remove pixel loss.') - # self.cri_pix = None - - # # G feature loss - # if train_opt['feature_weight'] > 0: - # l_fea_type = train_opt['feature_criterion'] - # if l_fea_type == 'l1': - # self.cri_fea = nn.L1Loss().to(self.device) - # elif l_fea_type == 'l2': - # self.cri_fea = nn.MSELoss().to(self.device) - # else: - # raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) - # self.l_fea_w = train_opt['feature_weight'] - # else: - # logger.info('Remove feature loss.') - # self.cri_fea = None - # if self.cri_fea: # load VGG perceptual loss - # self.netF = networks.define_F(opt, use_bn=False).to(self.device) - # if opt['dist']: - # self.netF = DistributedDataParallel(self.netF, - # device_ids=[torch.cuda.current_device()]) - # else: - # self.netF = DataParallel(self.netF) - - # # GD gan loss - # self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) - # self.l_gan_w = train_opt['gan_weight'] - # # D_update_ratio and D_init_iters - # self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 - # self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 - - # # optimizers - # # G - # wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 - # optim_params = [] - # for k, v in self.netG.named_parameters(): # can optimize for a part of the model - # if v.requires_grad: - # optim_params.append(v) - # else: - # if self.rank <= 0: - # logger.warning('Params [{:s}] will not optimize.'.format(k)) - # self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], - # weight_decay=wd_G, - # betas=(train_opt['beta1_G'], train_opt['beta2_G'])) - # self.optimizers.append(self.optimizer_G) - # # D - # wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 - # self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], - # weight_decay=wd_D, - # betas=(train_opt['beta1_D'], train_opt['beta2_D'])) - # self.optimizers.append(self.optimizer_D) - - # # schedulers - # if train_opt['lr_scheme'] == 'MultiStepLR': - # for optimizer in self.optimizers: - # self.schedulers.append( - # lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], - # restarts=train_opt['restarts'], - # weights=train_opt['restart_weights'], - # gamma=train_opt['lr_gamma'], - # clear_state=train_opt['clear_state'])) - # elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': - # for optimizer in self.optimizers: - # self.schedulers.append( - # lr_scheduler.CosineAnnealingLR_Restart( - # optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], - # restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) - # else: - # raise NotImplementedError('MultiStepLR learning rate scheme is enough.') - - # self.log_dict = OrderedDict() - - # self.print_network() # print network - # self.load() # load G and D if needed + # self.netD.train() def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. @@ -150,239 +48,8 @@ class SRGANModel(BaseModel): if 'A_paths' in input: self.image_paths = input['A_paths'] - # def feed_data(self, data, need_GT=True): - # self.var_L = data['LQ'].to(self.device) # LQ - # if need_GT: - # self.var_H = data['GT'].to(self.device) # GT - # input_ref = data['ref'] if 'ref' in data else data['GT'] - # self.var_ref = input_ref.to(self.device) - def forward(self): self.fake_H = self.netG(self.LQ) def optimize_parameters(self, step): pass - # # G - # for p in self.netD.parameters(): - # p.requires_grad = False - - # self.optimizer_G.zero_grad() - # self.fake_H = self.netG(self.var_L.detach()) - - # l_g_total = 0 - # if step % self.D_update_ratio == 0 and step > self.D_init_iters: - # if self.cri_pix: # pixel loss - # l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) - # l_g_total += l_g_pix - # if self.cri_fea: # feature loss - # real_fea = self.netF(self.var_H).detach() - # fake_fea = self.netF(self.fake_H) - # l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) - # l_g_total += l_g_fea - - # pred_g_fake = self.netD(self.fake_H) - # if self.opt['train']['gan_type'] == 'gan': - # l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) - # elif self.opt['train']['gan_type'] == 'ragan': - # pred_d_real = self.netD(self.var_ref).detach() - # l_g_gan = self.l_gan_w * ( - # self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + - # self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 - # l_g_total += l_g_gan - - # l_g_total.backward() - # self.optimizer_G.step() - - # # D - # for p in self.netD.parameters(): - # p.requires_grad = True - - # self.optimizer_D.zero_grad() - # l_d_total = 0 - # pred_d_real = self.netD(self.var_ref) - # pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G - # if self.opt['train']['gan_type'] == 'gan': - # l_d_real = self.cri_gan(pred_d_real, True) - # l_d_fake = self.cri_gan(pred_d_fake, False) - # l_d_total = l_d_real + l_d_fake - # elif self.opt['train']['gan_type'] == 'ragan': - # l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) - # l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) - # l_d_total = (l_d_real + l_d_fake) / 2 - - # l_d_total.backward() - # self.optimizer_D.step() - - # # set log - # if step % self.D_update_ratio == 0 and step > self.D_init_iters: - # if self.cri_pix: - # self.log_dict['l_g_pix'] = l_g_pix.item() - # # self.log_dict['l_g_mean_color'] = l_g_mean_color.item() - # if self.cri_fea: - # self.log_dict['l_g_fea'] = l_g_fea.item() - # self.log_dict['l_g_gan'] = l_g_gan.item() - - # self.log_dict['l_d_real'] = l_d_real.item() - # self.log_dict['l_d_fake'] = l_d_fake.item() - # self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) - # self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) - - # def test(self): - # self.netG.eval() - # with torch.no_grad(): - # self.fake_H = self.netG(self.var_L) - # self.netG.train() - - # def back_projection(self): - # lr_error = self.var_L - torch.nn.functional.interpolate(self.fake_H, - # scale_factor=1/self.opt['scale'], - # mode='bicubic', - # align_corners=False) - # us_error = torch.nn.functional.interpolate(lr_error, - # scale_factor=self.opt['scale'], - # mode='bicubic', - # align_corners=False) - # self.fake_H += self.opt['back_projection_lamda'] * us_error - # torch.clamp(self.fake_H, 0, 1) - - # def test_chop(self): - # self.netG.eval() - # with torch.no_grad(): - # self.fake_H = self.forward_chop(self.var_L) - # self.netG.train() - - # def forward_chop(self, *args, shave=10, min_size=160000): - # # scale = 1 if self.input_large else self.scale[self.idx_scale] - # scale = self.opt['scale'] - # n_GPUs = min(torch.cuda.device_count(), 4) - # args = [a.squeeze().unsqueeze(0) for a in args] - - # # height, width - # h, w = args[0].size()[-2:] - # # print('len(args)', len(args)) - # # print('args[0].size()', args[0].size()) - - # top = slice(0, h//2 + shave) - # bottom = slice(h - h//2 - shave, h) - # left = slice(0, w//2 + shave) - # right = slice(w - w//2 - shave, w) - # x_chops = [torch.cat([ - # a[..., top, left], - # a[..., top, right], - # a[..., bottom, left], - # a[..., bottom, right] - # ]) for a in args] - # # print('len(x_chops)', len(x_chops)) - # # print('x_chops[0].size()', x_chops[0].size()) - - # y_chops = [] - # if h * w < 4 * min_size: - # for i in range(0, 4, n_GPUs): - # x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] - # # print(len(x)) - # # print(x[0].size()) - # y = P.data_parallel(self.netG, *x, range(n_GPUs)) - # if not isinstance(y, list): y = [y] - # if not y_chops: - # y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] - # else: - # for y_chop, _y in zip(y_chops, y): - # y_chop.extend(_y.chunk(n_GPUs, dim=0)) - # else: - - # # print(x_chops[0].size()) - # for p in zip(*x_chops): - # # print('len(p)', len(p)) - # # print('p[0].size()', p[0].size()) - # y = self.forward_chop(*p, shave=shave, min_size=min_size) - # if not isinstance(y, list): y = [y] - # if not y_chops: - # y_chops = [[_y] for _y in y] - # else: - # for y_chop, _y in zip(y_chops, y): y_chop.append(_y) - - # h *= scale - # w *= scale - # top = slice(0, h//2) - # bottom = slice(h - h//2, h) - # bottom_r = slice(h//2 - h, None) - # left = slice(0, w//2) - # right = slice(w - w//2, w) - # right_r = slice(w//2 - w, None) - - # # batch size, number of color channels - # b, c = y_chops[0][0].size()[:-2] - # y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] - # for y_chop, _y in zip(y_chops, y): - # _y[..., top, left] = y_chop[0][..., top, left] - # _y[..., top, right] = y_chop[1][..., top, right_r] - # _y[..., bottom, left] = y_chop[2][..., bottom_r, left] - # _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] - - # if len(y) == 1: - # y = y[0] - - # return y - - # def get_current_log(self): - # return self.log_dict - - # def get_current_visuals(self, need_GT=True): - # out_dict = OrderedDict() - # out_dict['LQ'] = self.var_L.detach()[0].float().cpu() - # out_dict['SR'] = self.fake_H.detach()[0].float().cpu() - # if need_GT: - # out_dict['GT'] = self.var_H.detach()[0].float().cpu() - # return out_dict - - # def print_network(self): - # # Generator - # s, n = self.get_network_description(self.netG) - # if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): - # net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, - # self.netG.module.__class__.__name__) - # else: - # net_struc_str = '{}'.format(self.netG.__class__.__name__) - # if self.rank <= 0: - # logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) - # logger.info(s) - # if self.is_train: - # # Discriminator - # s, n = self.get_network_description(self.netD) - # if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, - # DistributedDataParallel): - # net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, - # self.netD.module.__class__.__name__) - # else: - # net_struc_str = '{}'.format(self.netD.__class__.__name__) - # if self.rank <= 0: - # logger.info('Network D structure: {}, with parameters: {:,d}'.format( - # net_struc_str, n)) - # logger.info(s) - - # if self.cri_fea: # F, Perceptual Network - # s, n = self.get_network_description(self.netF) - # if isinstance(self.netF, nn.DataParallel) or isinstance( - # self.netF, DistributedDataParallel): - # net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, - # self.netF.module.__class__.__name__) - # else: - # net_struc_str = '{}'.format(self.netF.__class__.__name__) - # if self.rank <= 0: - # logger.info('Network F structure: {}, with parameters: {:,d}'.format( - # net_struc_str, n)) - # logger.info(s) - - # def load(self): - # load_path_G = self.opt['path']['pretrain_model_G'] - # if load_path_G is not None: - # logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) - # self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) - # load_path_D = self.opt['path']['pretrain_model_D'] - # if self.opt['is_train'] and load_path_D is not None: - # logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) - # self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) - - # def save(self, iter_step): - # self.save_network(self.netG, 'G', iter_step) - # self.save_network(self.netD, 'D', iter_step) -- GitLab