diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py index bd255a7397b664553924b7b89ac59308f3396fa6..118f49fe11da1f72567c996dfd7f60ea694d7b96 100644 --- a/ppgan/models/sr_model.py +++ b/ppgan/models/sr_model.py @@ -1,10 +1,7 @@ from collections import OrderedDict import paddle import paddle.nn as nn -# import torch.nn.parallel as P -# from torch.nn.parallel import DataParallel, DistributedDataParallel -# import models.networks as networks -# import models.lr_scheduler as lr_scheduler + from .generators.builder import build_generator from .discriminators.builder import build_discriminator from ..solver import build_optimizer @@ -13,8 +10,6 @@ from .losses import GANLoss from .builder import MODELS import importlib -import mmcv -import torch from collections import OrderedDict from copy import deepcopy from os import path as osp @@ -24,12 +19,11 @@ from .builder import MODELS @MODELS.register() class SRModel(BaseModel): """Base SR model for single image super-resolution.""" - def __init__(self, cfg): super(SRModel, self).__init__(cfg) self.model_names = ['G'] - + self.netG = build_generator(cfg.model.generator) self.visual_names = ['lq', 'output', 'gt'] @@ -119,7 +113,7 @@ class SRModel(BaseModel): def forward(self): pass - + def test(self): """Forward function used in test time. """ @@ -137,111 +131,7 @@ class SRModel(BaseModel): l_pix = self.criterionL1(self.output, self.gt) l_total += l_pix loss_dict['l_pix'] = l_pix - # perceptual loss - # if self.cri_perceptual: - # l_percep, l_style = self.cri_perceptual(self.output, self.gt) - # if l_percep is not None: - # l_total += l_percep - # loss_dict['l_percep'] = l_percep - # if l_style is not None: - # l_total += l_style - # loss_dict['l_style'] = l_style l_total.backward() self.loss_l_total = l_total self.optimizer_G.step() - - # self.log_dict = self.reduce_loss_dict(loss_dict) - # def get_current_visuals(self): - # out_dict = OrderedDict() - # out_dict['lq'] = self.lq.detach().cpu() - # out_dict['result'] = self.output.detach().cpu() - # if hasattr(self, 'gt'): - # out_dict['gt'] = self.gt.detach().cpu() - # return out_dict - - # def test(self): - # self.net_g.eval() - # with torch.no_grad(): - # self.output = self.net_g(self.lq) - # self.net_g.train() - - # def dist_validation(self, dataloader, current_iter, tb_logger, save_img): - # logger = get_root_logger() - # logger.info('Only support single GPU validation.') - # self.nondist_validation(dataloader, current_iter, tb_logger, save_img) - - # def nondist_validation(self, dataloader, current_iter, tb_logger, - # save_img): - # dataset_name = dataloader.dataset.opt['name'] - # with_metrics = self.opt['val'].get('metrics') is not None - # if with_metrics: - # self.metric_results = { - # metric: 0 - # for metric in self.opt['val']['metrics'].keys() - # } - # pbar = ProgressBar(len(dataloader)) - - # for idx, val_data in enumerate(dataloader): - # img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] - # self.feed_data(val_data) - # self.test() - - # visuals = self.get_current_visuals() - # sr_img = tensor2img([visuals['result']]) - # if 'gt' in visuals: - # gt_img = tensor2img([visuals['gt']]) - # del self.gt - - # # tentative for out of GPU memory - # del self.lq - # del self.output - # torch.cuda.empty_cache() - - # if save_img: - # if self.opt['is_train']: - # save_img_path = osp.join(self.opt['path']['visualization'], - # img_name, - # f'{img_name}_{current_iter}.png') - # else: - # if self.opt['val']['suffix']: - # save_img_path = osp.join( - # self.opt['path']['visualization'], dataset_name, - # f'{img_name}_{self.opt["val"]["suffix"]}.png') - # else: - # save_img_path = osp.join( - # self.opt['path']['visualization'], dataset_name, - # f'{img_name}_{self.opt["name"]}.png') - # mmcv.imwrite(sr_img, save_img_path) - - # if with_metrics: - # # calculate metrics - # opt_metric = deepcopy(self.opt['val']['metrics']) - # for name, opt_ in opt_metric.items(): - # metric_type = opt_.pop('type') - # self.metric_results[name] += getattr( - # metric_module, metric_type)(sr_img, gt_img, **opt_) - # pbar.update(f'Test {img_name}') - - # if with_metrics: - # for metric in self.metric_results.keys(): - # self.metric_results[metric] /= (idx + 1) - - # self._log_validation_metric_values(current_iter, dataset_name, - # tb_logger) - - # def _log_validation_metric_values(self, current_iter, dataset_name, - # tb_logger): - # log_str = f'Validation {dataset_name}\n' - # for metric, value in self.metric_results.items(): - # log_str += f'\t # {metric}: {value:.4f}\n' - # logger = get_root_logger() - # logger.info(log_str) - # if tb_logger: - # for metric, value in self.metric_results.items(): - # tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) - - - # def save(self, epoch, current_iter): - # self.save_network(self.net_g, 'net_g', current_iter) - # self.save_training_state(epoch, current_iter) diff --git a/ppgan/models/srgan_model.py b/ppgan/models/srgan_model.py index ae5a46190597e07e4611d7999c3a1ab85fc5fe45..32ca581f8fcaac86fe3bf1bfeaca7617ecc0e06a 100644 --- a/ppgan/models/srgan_model.py +++ b/ppgan/models/srgan_model.py @@ -1,16 +1,11 @@ -# import logging from collections import OrderedDict import paddle import paddle.nn as nn -# import torch.nn.parallel as P -# from torch.nn.parallel import DataParallel, DistributedDataParallel -# import models.networks as networks -# import models.lr_scheduler as lr_scheduler + from .generators.builder import build_generator from .base_model import BaseModel from .losses import GANLoss from .builder import MODELS -# logger = logging.getLogger('base') @MODELS.register() @@ -27,7 +22,6 @@ class SRGANModel(BaseModel): # TODO: support srgan train. if False: # self.netD = build_discriminator(cfg.model.discriminator) - self.netG.train() # self.netD.train()