提交 99d09216 编写于 作者: L LielinJiang

rm unused code

上级 a0a56e75
from collections import OrderedDict
import paddle
import paddle.nn as nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_optimizer
......@@ -13,8 +10,6 @@ from .losses import GANLoss
from .builder import MODELS
import importlib
import mmcv
import torch
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
......@@ -24,12 +19,11 @@ from .builder import MODELS
@MODELS.register()
class SRModel(BaseModel):
"""Base SR model for single image super-resolution."""
def __init__(self, cfg):
super(SRModel, self).__init__(cfg)
self.model_names = ['G']
self.netG = build_generator(cfg.model.generator)
self.visual_names = ['lq', 'output', 'gt']
......@@ -119,7 +113,7 @@ class SRModel(BaseModel):
def forward(self):
pass
def test(self):
"""Forward function used in test time.
"""
......@@ -137,111 +131,7 @@ class SRModel(BaseModel):
l_pix = self.criterionL1(self.output, self.gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
# if self.cri_perceptual:
# l_percep, l_style = self.cri_perceptual(self.output, self.gt)
# if l_percep is not None:
# l_total += l_percep
# loss_dict['l_percep'] = l_percep
# if l_style is not None:
# l_total += l_style
# loss_dict['l_style'] = l_style
l_total.backward()
self.loss_l_total = l_total
self.optimizer_G.step()
# self.log_dict = self.reduce_loss_dict(loss_dict)
# def get_current_visuals(self):
# out_dict = OrderedDict()
# out_dict['lq'] = self.lq.detach().cpu()
# out_dict['result'] = self.output.detach().cpu()
# if hasattr(self, 'gt'):
# out_dict['gt'] = self.gt.detach().cpu()
# return out_dict
# def test(self):
# self.net_g.eval()
# with torch.no_grad():
# self.output = self.net_g(self.lq)
# self.net_g.train()
# def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# logger = get_root_logger()
# logger.info('Only support single GPU validation.')
# self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
# def nondist_validation(self, dataloader, current_iter, tb_logger,
# save_img):
# dataset_name = dataloader.dataset.opt['name']
# with_metrics = self.opt['val'].get('metrics') is not None
# if with_metrics:
# self.metric_results = {
# metric: 0
# for metric in self.opt['val']['metrics'].keys()
# }
# pbar = ProgressBar(len(dataloader))
# for idx, val_data in enumerate(dataloader):
# img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
# self.feed_data(val_data)
# self.test()
# visuals = self.get_current_visuals()
# sr_img = tensor2img([visuals['result']])
# if 'gt' in visuals:
# gt_img = tensor2img([visuals['gt']])
# del self.gt
# # tentative for out of GPU memory
# del self.lq
# del self.output
# torch.cuda.empty_cache()
# if save_img:
# if self.opt['is_train']:
# save_img_path = osp.join(self.opt['path']['visualization'],
# img_name,
# f'{img_name}_{current_iter}.png')
# else:
# if self.opt['val']['suffix']:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["val"]["suffix"]}.png')
# else:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["name"]}.png')
# mmcv.imwrite(sr_img, save_img_path)
# if with_metrics:
# # calculate metrics
# opt_metric = deepcopy(self.opt['val']['metrics'])
# for name, opt_ in opt_metric.items():
# metric_type = opt_.pop('type')
# self.metric_results[name] += getattr(
# metric_module, metric_type)(sr_img, gt_img, **opt_)
# pbar.update(f'Test {img_name}')
# if with_metrics:
# for metric in self.metric_results.keys():
# self.metric_results[metric] /= (idx + 1)
# self._log_validation_metric_values(current_iter, dataset_name,
# tb_logger)
# def _log_validation_metric_values(self, current_iter, dataset_name,
# tb_logger):
# log_str = f'Validation {dataset_name}\n'
# for metric, value in self.metric_results.items():
# log_str += f'\t # {metric}: {value:.4f}\n'
# logger = get_root_logger()
# logger.info(log_str)
# if tb_logger:
# for metric, value in self.metric_results.items():
# tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
# def save(self, epoch, current_iter):
# self.save_network(self.net_g, 'net_g', current_iter)
# self.save_training_state(epoch, current_iter)
# import logging
from collections import OrderedDict
import paddle
import paddle.nn as nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from .generators.builder import build_generator
from .base_model import BaseModel
from .losses import GANLoss
from .builder import MODELS
# logger = logging.getLogger('base')
@MODELS.register()
......@@ -27,7 +22,6 @@ class SRGANModel(BaseModel):
# TODO: support srgan train.
if False:
# self.netD = build_discriminator(cfg.model.discriminator)
self.netG.train()
# self.netD.train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册