提交 2aaef2e9 编写于 作者: L LielinJiang

add srmodel

上级 54896a27
from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
\ No newline at end of file
......@@ -95,6 +95,9 @@ def get_transform(cfg,
if convert:
transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
transforms.Normalize((0., 0., 0.), (255., 255., 255.))
]
# transform_list += [
# transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
# ]
return transforms.Compose(transform_list)
......@@ -111,4 +111,8 @@ def build_dataloader(cfg, is_train=True):
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
# for i, item in enumerate(dataloader):
# print(i, item.keys())
# # break
# print('dataset build success!')
return dataloader
import os
import time
import copy
import logging
import paddle
......@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
class Trainer:
def __init__(self, cfg):
......@@ -45,9 +46,11 @@ class Trainer:
# time count
self.time_count = {}
self.best_metric = {}
def distributed_data_parallel(self):
strategy = paddle.prepare_context()
strategy = paddle.distributed.prepare_context()
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
......@@ -78,11 +81,61 @@ class Trainer:
step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.validate()
self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1)
self.save(epoch)
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val, is_train=False)
metric_result = {}
for i, data in enumerate(self.val_dataloader):
self.batch_id = i
self.model.set_input(data)
self.model.test()
visual_results = {}
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
# print('debug1:', self.cfg.validate.metrics)
for j in range(len(current_paths)):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
# print('debug2:', self.cfg.validate.metrics)
if 'psnr' in self.cfg.validate.metrics:
# args = copy.deepcopy(self.cfg.validate.metrics.pnsr)
# args.pop('name')
if 'psnr' not in metric_result:
metric_result['psnr'] = calculate_psnr(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr)
else:
metric_result['psnr'] += calculate_psnr(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr)
if 'ssim' in self.cfg.validate.metrics:
if 'ssim' not in metric_result:
metric_result['ssim'] = calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim)
else:
metric_result['ssim'] += calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim)
self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
self.logger.info('Epoch {} validate end: {}'.format(self.current_epoch, metric_result))
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
......@@ -210,5 +263,6 @@ class Trainer:
for name in self.model.model_names:
if isinstance(name, str):
self.logger.info('laod model {} {} params!'.format(self.cfg.model.name, 'net' + name))
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
from .base_model import BaseModel
from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .resnet import ResnetGenerator
from .unet import UnetGenerator
\ No newline at end of file
from .unet import UnetGenerator
from .rrdb_net import RRDBNet
\ No newline at end of file
......@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8):
image_numpy = image_numpy[0]
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
# image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
image_numpy = image_numpy.clip(0, 1)
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册