# REVIEW NOTES (commentary placed ahead of the first `diff --git` header; the patch text below is unchanged)
#
# Scope: unified git diff touching ppgan/datasets/__init__.py, ppgan/datasets/base_dataset.py,
#   ppgan/datasets/builder.py, ppgan/engine/trainer.py, ppgan/models/__init__.py,
#   ppgan/models/generators/__init__.py and ppgan/utils/visual.py.
# NOTE(review): line breaks and indentation inside the hunks appear to have been lost in this copy;
#   presumably it must be regenerated from the repository before it can be applied -- confirm.
#
# datasets/__init__.py, models/__init__.py, models/generators/__init__.py:
#   - Add imports of SRImageDataset, SRGANModel, SRModel and RRDBNet. The modules they come from
#     (sr_image_dataset, srgan_model, sr_model, rrdb_net) are not part of the hunks shown here.
#   - Both datasets/__init__.py and generators/__init__.py still carry "\ No newline at end of file".
#
# datasets/base_dataset.py (get_transform) together with utils/visual.py (tensor2img):
#   - Normalize((127.5,)*3, (127.5,)*3) is replaced by Normalize((0., 0., 0.), (255., 255., 255.));
#     the previous call is kept as commented-out code.
#   - tensor2img no longer computes (x + 1) / 2.0 * 255.0; it now clips to [0, 1] and multiplies by 255.0;
#     the previous expression is kept as commented-out code.
#   - NOTE(review): these two edits are coupled and are not switchable by configuration; presumably
#     every model using get_transform / tensor2img (CycleGANModel and Pix2PixModel are imported in
#     models/__init__.py) is affected -- confirm before merging.
#
# datasets/builder.py (build_dataloader):
#   - Adds only a commented-out debugging loop before `return dataloader`; no functional change.
#
# engine/trainer.py:
#   - `paddle.prepare_context()` becomes `paddle.distributed.prepare_context()`.
#   - Adds `import copy`; in the hunks shown it is referenced only from commented-out code.
#   - Adds `self.best_metric = {}`; it is not read anywhere in the hunks shown.
#   - `self.validate()` is inserted after the 'train one epoch time' log line and before
#     `self.model.lr_scheduler.step()`.
#   - New `validate` method: lazily builds `self.val_dataloader` from `self.cfg.dataset.val`, runs
#     `self.model.test()` per batch, accumulates PSNR / SSIM per image from
#     `current_visuals['output'][j]` and `current_visuals['gt'][j]`, divides each accumulated metric by
#     `len(self.val_dataloader.dataset)`, and logs the result with `self.current_epoch`.
#     * It reads `self.cfg.validate.metrics` and `self.cfg.dataset.val` with no guard visible here.
#       NOTE(review): configs without those entries would presumably fail once validate() runs -- confirm.
#     * `tensor2img(...)` is evaluated separately for the PSNR and the SSIM call on the same tensors.
#     * Two commented-out `print('debug...')` lines and a commented-out deepcopy/pop snippet remain
#       (the latter spells the key `pnsr`).
#   - The added log message in the weight-loading loop reads 'laod model ...' (typo for "load").
#   - `calculate_psnr` / `calculate_ssim` are imported from `..metric.psnr_ssim`, which is not part of
#     the hunks shown here.
#
diff --git a/ppgan/datasets/__init__.py b/ppgan/datasets/__init__.py index 9b807e9be0c83dda6415ebf01418cc77b8f463ba..0aeb70936b58125fb92d00ce5905e2608142f728 100644 --- a/ppgan/datasets/__init__.py +++ b/ppgan/datasets/__init__.py @@ -1,3 +1,4 @@ from .unpaired_dataset import UnpairedDataset from .single_dataset import SingleDataset from .paired_dataset import PairedDataset +from .sr_image_dataset import SRImageDataset \ No newline at end of file diff --git a/ppgan/datasets/base_dataset.py b/ppgan/datasets/base_dataset.py index 87e996925477c5fa096df48779327397ce22873b..c0036b40d441380ae3b66ed7b6ecb3e95e5841d1 100644 --- a/ppgan/datasets/base_dataset.py +++ b/ppgan/datasets/base_dataset.py @@ -95,6 +95,9 @@ def get_transform(cfg, if convert: transform_list += [transforms.Permute(to_rgb=True)] transform_list += [ - transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) + transforms.Normalize((0., 0., 0.), (255., 255., 255.)) ] + # transform_list += [ + # transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) + # ] return transforms.Compose(transform_list) diff --git a/ppgan/datasets/builder.py b/ppgan/datasets/builder.py index 62b5346795c1383d683926e46064b97ea8a14aee..284c774214371554915c784ecd1a33fbdf04a133 100644 --- a/ppgan/datasets/builder.py +++ b/ppgan/datasets/builder.py @@ -111,4 +111,8 @@ def build_dataloader(cfg, is_train=True): dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers) + # for i, item in enumerate(dataloader): + # print(i, item.keys()) + # # break + # print('dataset build success!') return dataloader diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index f7f456962e61c4358853bcf2365006fd340e68f6..2e1f839fdeff91755b8a0de54e278cb37f7f78fa 100644 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -1,5 +1,6 @@ import os import time +import copy import logging import paddle @@ -10,7 +11,7 @@ from ..datasets.builder import build_dataloader from ..models.builder import build_model from 
..utils.visual import tensor2img, save_image from ..utils.filesystem import save, load, makedirs - +from ..metric.psnr_ssim import calculate_psnr, calculate_ssim class Trainer: def __init__(self, cfg): @@ -45,9 +46,11 @@ class Trainer: # time count self.time_count = {} + self.best_metric = {} + def distributed_data_parallel(self): - strategy = paddle.prepare_context() + strategy = paddle.distributed.prepare_context() for name in self.model.model_names: if isinstance(name, str): net = getattr(self.model, 'net' + name) @@ -78,11 +81,61 @@ class Trainer: step_start_time = time.time() self.logger.info('train one epoch time: {}'.format(time.time() - start_time)) + self.validate() self.model.lr_scheduler.step() if epoch % self.weight_interval == 0: self.save(epoch, 'weight', keep=-1) self.save(epoch) + def validate(self): + if not hasattr(self, 'val_dataloader'): + self.val_dataloader = build_dataloader(self.cfg.dataset.val, is_train=False) + + metric_result = {} + + for i, data in enumerate(self.val_dataloader): + self.batch_id = i + + self.model.set_input(data) + self.model.test() + + visual_results = {} + current_paths = self.model.get_image_paths() + current_visuals = self.model.get_current_visuals() + + # print('debug1:', self.cfg.validate.metrics) + for j in range(len(current_paths)): + short_path = os.path.basename(current_paths[j]) + basename = os.path.splitext(short_path)[0] + for k, img_tensor in current_visuals.items(): + name = '%s_%s' % (basename, k) + visual_results.update({name: img_tensor[j]}) + # print('debug2:', self.cfg.validate.metrics) + if 'psnr' in self.cfg.validate.metrics: + # args = copy.deepcopy(self.cfg.validate.metrics.pnsr) + # args.pop('name') + if 'psnr' not in metric_result: + metric_result['psnr'] = calculate_psnr(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr) + else: + metric_result['psnr'] += calculate_psnr(tensor2img(current_visuals['output'][j]), 
tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.psnr) + if 'ssim' in self.cfg.validate.metrics: + if 'ssim' not in metric_result: + metric_result['ssim'] = calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim) + else: + metric_result['ssim'] += calculate_ssim(tensor2img(current_visuals['output'][j]), tensor2img(current_visuals['gt'][j]), **self.cfg.validate.metrics.ssim) + + self.visual('visual_val', visual_results=visual_results) + + if i % self.log_interval == 0: + self.logger.info('val iter: [%d/%d]' % + (i, len(self.val_dataloader))) + + for metric_name in metric_result.keys(): + metric_result[metric_name] /= len(self.val_dataloader.dataset) + + self.logger.info('Epoch {} validate end: {}'.format(self.current_epoch, metric_result)) + + def test(self): if not hasattr(self, 'test_dataloader'): self.test_dataloader = build_dataloader(self.cfg.dataset.test, @@ -210,5 +263,6 @@ class Trainer: for name in self.model.model_names: if isinstance(name, str): + self.logger.info('laod model {} {} params!'.format(self.cfg.model.name, 'net' + name)) net = getattr(self.model, 'net' + name) net.set_dict(state_dicts['net' + name]) diff --git a/ppgan/models/__init__.py b/ppgan/models/__init__.py index 621e8edf337837440e679d9969ffd614520268e2..1fb4e96098b6ea230c029c8c0f0ff7ad2eb5b139 100644 --- a/ppgan/models/__init__.py +++ b/ppgan/models/__init__.py @@ -1,4 +1,6 @@ from .base_model import BaseModel from .cycle_gan_model import CycleGANModel from .pix2pix_model import Pix2PixModel +from .srgan_model import SRGANModel +from .sr_model import SRModel diff --git a/ppgan/models/generators/__init__.py b/ppgan/models/generators/__init__.py index 840d716a1399053dbb621e4d4ede60aec4afde0b..15ac59d156f852e10fa4263fb4fd5b1fe9f7a976 100644 --- a/ppgan/models/generators/__init__.py +++ b/ppgan/models/generators/__init__.py @@ -1,2 +1,3 @@ from .resnet import ResnetGenerator -from .unet import 
UnetGenerator \ No newline at end of file +from .unet import UnetGenerator +from .rrdb_net import RRDBNet \ No newline at end of file diff --git a/ppgan/utils/visual.py b/ppgan/utils/visual.py index a50c59eb673a3a792e0712c94f898dee94e12e5d..56639acb52b6822d37bf3a0ba0cfee8522f4b42c 100644 --- a/ppgan/utils/visual.py +++ b/ppgan/utils/visual.py @@ -15,7 +15,9 @@ def tensor2img(input_image, imtype=np.uint8): image_numpy = image_numpy[0] if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = np.tile(image_numpy, (3, 1, 1)) - image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling + image_numpy = image_numpy.clip(0, 1) + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype)