Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleGAN
提交
99d09216
P
PaddleGAN
项目概览
PaddlePaddle
/
PaddleGAN
大约 1 年 前同步成功
通知
97
Star
7254
Fork
1210
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleGAN
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
99d09216
编写于
9月 16, 2020
作者:
L
LielinJiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm unused code
上级
a0a56e75
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
4 addition
and
120 deletion
+4
-120
ppgan/models/sr_model.py
ppgan/models/sr_model.py
+3
-113
ppgan/models/srgan_model.py
ppgan/models/srgan_model.py
+1
-7
未找到文件。
ppgan/models/sr_model.py
浏览文件 @
99d09216
from
collections
import
OrderedDict
import
paddle
import
paddle.nn
as
nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from
.generators.builder
import
build_generator
from
.discriminators.builder
import
build_discriminator
from
..solver
import
build_optimizer
...
...
@@ -13,8 +10,6 @@ from .losses import GANLoss
from
.builder
import
MODELS
import
importlib
import
mmcv
import
torch
from
collections
import
OrderedDict
from
copy
import
deepcopy
from
os
import
path
as
osp
...
...
@@ -24,12 +19,11 @@ from .builder import MODELS
@
MODELS
.
register
()
class
SRModel
(
BaseModel
):
"""Base SR model for single image super-resolution."""
def
__init__
(
self
,
cfg
):
super
(
SRModel
,
self
).
__init__
(
cfg
)
self
.
model_names
=
[
'G'
]
self
.
netG
=
build_generator
(
cfg
.
model
.
generator
)
self
.
visual_names
=
[
'lq'
,
'output'
,
'gt'
]
...
...
@@ -119,7 +113,7 @@ class SRModel(BaseModel):
def
forward
(
self
):
pass
def
test
(
self
):
"""Forward function used in test time.
"""
...
...
@@ -137,111 +131,7 @@ class SRModel(BaseModel):
l_pix
=
self
.
criterionL1
(
self
.
output
,
self
.
gt
)
l_total
+=
l_pix
loss_dict
[
'l_pix'
]
=
l_pix
# perceptual loss
# if self.cri_perceptual:
# l_percep, l_style = self.cri_perceptual(self.output, self.gt)
# if l_percep is not None:
# l_total += l_percep
# loss_dict['l_percep'] = l_percep
# if l_style is not None:
# l_total += l_style
# loss_dict['l_style'] = l_style
l_total
.
backward
()
self
.
loss_l_total
=
l_total
self
.
optimizer_G
.
step
()
# self.log_dict = self.reduce_loss_dict(loss_dict)
# def get_current_visuals(self):
# out_dict = OrderedDict()
# out_dict['lq'] = self.lq.detach().cpu()
# out_dict['result'] = self.output.detach().cpu()
# if hasattr(self, 'gt'):
# out_dict['gt'] = self.gt.detach().cpu()
# return out_dict
# def test(self):
# self.net_g.eval()
# with torch.no_grad():
# self.output = self.net_g(self.lq)
# self.net_g.train()
# def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# logger = get_root_logger()
# logger.info('Only support single GPU validation.')
# self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
# def nondist_validation(self, dataloader, current_iter, tb_logger,
# save_img):
# dataset_name = dataloader.dataset.opt['name']
# with_metrics = self.opt['val'].get('metrics') is not None
# if with_metrics:
# self.metric_results = {
# metric: 0
# for metric in self.opt['val']['metrics'].keys()
# }
# pbar = ProgressBar(len(dataloader))
# for idx, val_data in enumerate(dataloader):
# img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
# self.feed_data(val_data)
# self.test()
# visuals = self.get_current_visuals()
# sr_img = tensor2img([visuals['result']])
# if 'gt' in visuals:
# gt_img = tensor2img([visuals['gt']])
# del self.gt
# # tentative for out of GPU memory
# del self.lq
# del self.output
# torch.cuda.empty_cache()
# if save_img:
# if self.opt['is_train']:
# save_img_path = osp.join(self.opt['path']['visualization'],
# img_name,
# f'{img_name}_{current_iter}.png')
# else:
# if self.opt['val']['suffix']:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["val"]["suffix"]}.png')
# else:
# save_img_path = osp.join(
# self.opt['path']['visualization'], dataset_name,
# f'{img_name}_{self.opt["name"]}.png')
# mmcv.imwrite(sr_img, save_img_path)
# if with_metrics:
# # calculate metrics
# opt_metric = deepcopy(self.opt['val']['metrics'])
# for name, opt_ in opt_metric.items():
# metric_type = opt_.pop('type')
# self.metric_results[name] += getattr(
# metric_module, metric_type)(sr_img, gt_img, **opt_)
# pbar.update(f'Test {img_name}')
# if with_metrics:
# for metric in self.metric_results.keys():
# self.metric_results[metric] /= (idx + 1)
# self._log_validation_metric_values(current_iter, dataset_name,
# tb_logger)
# def _log_validation_metric_values(self, current_iter, dataset_name,
# tb_logger):
# log_str = f'Validation {dataset_name}\n'
# for metric, value in self.metric_results.items():
# log_str += f'\t # {metric}: {value:.4f}\n'
# logger = get_root_logger()
# logger.info(log_str)
# if tb_logger:
# for metric, value in self.metric_results.items():
# tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
# def save(self, epoch, current_iter):
# self.save_network(self.net_g, 'net_g', current_iter)
# self.save_training_state(epoch, current_iter)
ppgan/models/srgan_model.py
浏览文件 @
99d09216
# import logging
from
collections
import
OrderedDict
import
paddle
import
paddle.nn
as
nn
# import torch.nn.parallel as P
# from torch.nn.parallel import DataParallel, DistributedDataParallel
# import models.networks as networks
# import models.lr_scheduler as lr_scheduler
from
.generators.builder
import
build_generator
from
.base_model
import
BaseModel
from
.losses
import
GANLoss
from
.builder
import
MODELS
# logger = logging.getLogger('base')
@
MODELS
.
register
()
...
...
@@ -27,7 +22,6 @@ class SRGANModel(BaseModel):
# TODO: support srgan train.
if
False
:
# self.netD = build_discriminator(cfg.model.discriminator)
self
.
netG
.
train
()
# self.netD.train()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录