提交 99d09216 编写于 作者: L LielinJiang

rm unused code

上级 a0a56e75
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from .generators.builder import build_generator from .generators.builder import build_generator
from .discriminators.builder import build_discriminator from .discriminators.builder import build_discriminator
from ..solver import build_optimizer from ..solver import build_optimizer
...@@ -13,8 +10,6 @@ from .losses import GANLoss ...@@ -13,8 +10,6 @@ from .losses import GANLoss
from .builder import MODELS from .builder import MODELS
import importlib import importlib
import mmcv
import torch
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from os import path as osp from os import path as osp
...@@ -24,12 +19,11 @@ from .builder import MODELS ...@@ -24,12 +19,11 @@ from .builder import MODELS
@MODELS.register() @MODELS.register()
class SRModel(BaseModel): class SRModel(BaseModel):
"""Base SR model for single image super-resolution.""" """Base SR model for single image super-resolution."""
def __init__(self, cfg): def __init__(self, cfg):
super(SRModel, self).__init__(cfg) super(SRModel, self).__init__(cfg)
self.model_names = ['G'] self.model_names = ['G']
self.netG = build_generator(cfg.model.generator) self.netG = build_generator(cfg.model.generator)
self.visual_names = ['lq', 'output', 'gt'] self.visual_names = ['lq', 'output', 'gt']
...@@ -119,7 +113,7 @@ class SRModel(BaseModel): ...@@ -119,7 +113,7 @@ class SRModel(BaseModel):
def forward(self): def forward(self):
pass pass
def test(self): def test(self):
"""Forward function used in test time. """Forward function used in test time.
""" """
...@@ -137,111 +131,7 @@ class SRModel(BaseModel): ...@@ -137,111 +131,7 @@ class SRModel(BaseModel):
l_pix = self.criterionL1(self.output, self.gt) l_pix = self.criterionL1(self.output, self.gt)
l_total += l_pix l_total += l_pix
loss_dict['l_pix'] = l_pix loss_dict['l_pix'] = l_pix
# perceptual loss
# if self.cri_perceptual:
# l_percep, l_style = self.cri_perceptual(self.output, self.gt)
# if l_percep is not None:
# l_total += l_percep
# loss_dict['l_percep'] = l_percep
# if l_style is not None:
# l_total += l_style
# loss_dict['l_style'] = l_style
l_total.backward() l_total.backward()
self.loss_l_total = l_total self.loss_l_total = l_total
self.optimizer_G.step() self.optimizer_G.step()
# self.log_dict = self.reduce_loss_dict(loss_dict)
# def get_current_visuals(self):
# out_dict = OrderedDict()
# out_dict['lq'] = self.lq.detach().cpu()
# out_dict['result'] = self.output.detach().cpu()
# if hasattr(self, 'gt'):
# out_dict['gt'] = self.gt.detach().cpu()
# return out_dict
# def test(self):
# self.net_g.eval()
# with torch.no_grad():
# self.output = self.net_g(self.lq)
# self.net_g.train()
# def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# logger = get_root_logger()
# logger.info('Only support single GPU validation.')
# self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
# def nondist_validation(self, dataloader, current_iter, tb_logger,
# save_img):
# dataset_name = dataloader.dataset.opt['name']
# with_metrics = self.opt['val'].get('metrics') is not None
# if with_metrics:
# self.metric_results = {
# metric: 0
# for metric in self.opt['val']['metrics'].keys()
# }
# pbar = ProgressBar(len(dataloader))
# for idx, val_data in enumerate(dataloader):
# img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
# self.feed_data(val_data)
# self.test()
# visuals = self.get_current_visuals()
# sr_img = tensor2img([visuals['result']])
# if 'gt' in visuals:
# gt_img = tensor2img([visuals['gt']])
# del self.gt
# # tentative for out of GPU memory
# del self.lq
# del self.output
# torch.cuda.empty_cache()
# if save_img:
# if self.opt['is_train']:
# save_img_path = osp.join(self.opt['path']['visualization'],
# img_name,
# f'{img_name}_{current_iter}.png')
# else:
# if self.opt['val']['suffix']:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["val"]["suffix"]}.png')
# else:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["name"]}.png')
# mmcv.imwrite(sr_img, save_img_path)
# if with_metrics:
# # calculate metrics
# opt_metric = deepcopy(self.opt['val']['metrics'])
# for name, opt_ in opt_metric.items():
# metric_type = opt_.pop('type')
# self.metric_results[name] += getattr(
# metric_module, metric_type)(sr_img, gt_img, **opt_)
# pbar.update(f'Test {img_name}')
# if with_metrics:
# for metric in self.metric_results.keys():
# self.metric_results[metric] /= (idx + 1)
# self._log_validation_metric_values(current_iter, dataset_name,
# tb_logger)
# def _log_validation_metric_values(self, current_iter, dataset_name,
# tb_logger):
# log_str = f'Validation {dataset_name}\n'
# for metric, value in self.metric_results.items():
# log_str += f'\t # {metric}: {value:.4f}\n'
# logger = get_root_logger()
# logger.info(log_str)
# if tb_logger:
# for metric, value in self.metric_results.items():
# tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
# def save(self, epoch, current_iter):
# self.save_network(self.net_g, 'net_g', current_iter)
# self.save_training_state(epoch, current_iter)
# import logging
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from .generators.builder import build_generator from .generators.builder import build_generator
from .base_model import BaseModel from .base_model import BaseModel
from .losses import GANLoss from .losses import GANLoss
from .builder import MODELS from .builder import MODELS
# logger = logging.getLogger('base')
@MODELS.register() @MODELS.register()
...@@ -27,7 +22,6 @@ class SRGANModel(BaseModel): ...@@ -27,7 +22,6 @@ class SRGANModel(BaseModel):
# TODO: support srgan train. # TODO: support srgan train.
if False: if False:
# self.netD = build_discriminator(cfg.model.discriminator) # self.netD = build_discriminator(cfg.model.discriminator)
self.netG.train() self.netG.train()
# self.netD.train() # self.netD.train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册