提交 2aaef2e9 编写于 作者: L LielinJiang

add srmodel

上级 54896a27
from .unpaired_dataset import UnpairedDataset from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
\ No newline at end of file
...@@ -95,6 +95,9 @@ def get_transform(cfg, ...@@ -95,6 +95,9 @@ def get_transform(cfg,
if convert: if convert:
transform_list += [transforms.Permute(to_rgb=True)] transform_list += [transforms.Permute(to_rgb=True)]
transform_list += [ transform_list += [
transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) transforms.Normalize((0., 0., 0.), (255., 255., 255.))
] ]
# transform_list += [
# transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
# ]
return transforms.Compose(transform_list) return transforms.Compose(transform_list)
...@@ -111,4 +111,8 @@ def build_dataloader(cfg, is_train=True): ...@@ -111,4 +111,8 @@ def build_dataloader(cfg, is_train=True):
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers) dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)
# for i, item in enumerate(dataloader):
# print(i, item.keys())
# # break
# print('dataset build success!')
return dataloader return dataloader
import os import os
import time import time
import copy
import logging import logging
import paddle import paddle
...@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader ...@@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader
from ..models.builder import build_model from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs from ..utils.filesystem import save, load, makedirs
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
class Trainer: class Trainer:
def __init__(self, cfg): def __init__(self, cfg):
...@@ -45,9 +46,11 @@ class Trainer: ...@@ -45,9 +46,11 @@ class Trainer:
# time count # time count
self.time_count = {} self.time_count = {}
self.best_metric = {}
def distributed_data_parallel(self): def distributed_data_parallel(self):
strategy = paddle.prepare_context() strategy = paddle.distributed.prepare_context()
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
...@@ -78,11 +81,61 @@ class Trainer: ...@@ -78,11 +81,61 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - self.logger.info('train one epoch time: {}'.format(time.time() -
start_time)) start_time))
self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
if epoch % self.weight_interval == 0: if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1) self.save(epoch, 'weight', keep=-1)
self.save(epoch) self.save(epoch)
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val, is_train=False)
metric_result = {}
for i, data in enumerate(self.val_dataloader):
self.batch_id = i
self.model.set_input(data)
self.model.test()
visual_results = {}
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
# print('debug1:', self.cfg.validate.metrics)
for j in range(len(current_paths)):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
# print('debug2:', self.cfg.validate.metrics)
if 'psnr' in self.cfg.validate.metrics:
# args = copy.deepcopy(self.cfg.validate.metrics.pnsr)
# args.pop('name')
if 'psnr' not in metric_result:
metric_result['psnr'] = calculate_psnr(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr)
else:
metric_result['psnr'] += calculate_psnr(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr)
if 'ssim' in self.cfg.validate.metrics:
if 'ssim' not in metric_result:
metric_result['ssim'] = calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim)
else:
metric_result['ssim'] += calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim)
self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
self.logger.info('Epoch {} validate end: {}'.format(self.current_epoch, metric_result))
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test, self.test_dataloader = build_dataloader(self.cfg.dataset.test,
...@@ -210,5 +263,6 @@ class Trainer: ...@@ -210,5 +263,6 @@ class Trainer:
for name in self.model.model_names: for name in self.model.model_names:
if isinstance(name, str): if isinstance(name, str):
self.logger.info('laod model {} {} params!'.format(self.cfg.model.name, 'net' + name))
net = getattr(self.model, 'net' + name) net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name]) net.set_dict(state_dicts['net' + name])
from .base_model import BaseModel from .base_model import BaseModel
from .cycle_gan_model import CycleGANModel from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .resnet import ResnetGenerator from .resnet import ResnetGenerator
from .unet import UnetGenerator from .unet import UnetGenerator
\ No newline at end of file from .rrdb_net import RRDBNet
\ No newline at end of file
...@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8): ...@@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8):
image_numpy = image_numpy[0] image_numpy = image_numpy[0]
if image_numpy.shape[0] == 1: # grayscale to RGB if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling
image_numpy = image_numpy.clip(0, 1)
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling
else: # if it is a numpy array, do nothing else: # if it is a numpy array, do nothing
image_numpy = input_image image_numpy = input_image
return image_numpy.astype(imtype) return image_numpy.astype(imtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册