未验证 提交 b6204126 编写于 作者: L LielinJiang 提交者: GitHub

Fix mprnet train bug and add docs (#506)

* fix mprnet train and add docs

* update config
上级 5bf728df
total_iters: 100000
# epoch: 3000 for total batch size=16
total_iters: 400000
output_dir: output_dir
model:
......@@ -15,38 +16,38 @@ dataset:
train:
name: MPRTrain
rgb_dir: 'data/GoPro/train'
num_workers: 16
batch_size: 4
num_workers: 4
batch_size: 2
img_options:
patch_size: 256
test:
name: MPRTrain
name: MPRVal
rgb_dir: 'data/GoPro/test'
num_workers: 16
batch_size: 4
num_workers: 4
batch_size: 2
img_options:
patch_size: 256
lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4
periods: [25000, 25000, 25000, 25000]
restart_weights: [1, 1, 1, 1]
learning_rate: !!float 1e-4
periods: [400000]
restart_weights: [1]
eta_min: !!float 1e-6
validate:
interval: 10
interval: 5000
save_img: false
metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 4
test_y_channel: True
test_y_channel: false
ssim:
name: SSIM
crop_border: 4
test_y_channel: True
test_y_channel: false
optimizer:
name: Adam
......@@ -59,7 +60,7 @@ optimizer:
epsilon: 1e-8
log_config:
interval: 10
interval: 100
visiual_interval: 5000
snapshot_config:
......
......@@ -130,6 +130,10 @@ The metrics are PSNR / SSIM.
| pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
Deblur models zoo
| model | GoPro | Download Link |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [link](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) -->
......
......@@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集
| paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 |
| torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 |
去模糊模型
| 模型 | GoPro | 下载地址 |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [链接](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) -->
......
......@@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager
from ..utils.profiler import add_profiler_step
class IterLoader:
def __init__(self, dataloader):
self._dataloader = dataloader
......@@ -429,15 +430,35 @@ class Trainer:
def load(self, weight_path):
state_dicts = load(weight_path)
for net_name, net in self.model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
def is_dict_in_dict_weight(state_dict):
if isinstance(state_dict, dict) and len(state_dict) > 0:
val = list(state_dict.values())[0]
if isinstance(val, dict):
return True
else:
return False
else:
self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name))
return False
if is_dict_in_dict_weight(state_dicts):
for net_name, net in self.model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
else:
self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name))
else:
assert len(self.model.nets
) == 1, 'checkpoint only contain weight of one net, \
but model contains more than one net!'
net_name, net = list(self.model.nets.items())[0]
net.set_state_dict(state_dicts)
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
def close(self):
"""
......
......@@ -249,23 +249,25 @@ class CalcStyleLoss():
class EdgeLoss():
def __init__(self):
k = paddle.to_tensor([[.05, .25, .4, .25, .05]])
self.kernel = paddle.matmul(k.t(),k).unsqueeze(0).tile([3,1,1,1])
self.kernel = paddle.matmul(k.t(), k).unsqueeze(0).tile([3, 1, 1, 1])
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, [kw//2, kh//2, kw//2, kh//2], mode='replicate')
img = F.pad(img, [kw // 2, kh // 2, kw // 2, kh // 2], mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample
new_filter = paddle.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample
filtered = self.conv_gauss(new_filter) # filter
filtered = self.conv_gauss(current) # filter
down = filtered[:, :, ::2, ::2] # downsample
new_filter = paddle.zeros_like(filtered)
new_filter.stop_gradient = True
new_filter[:, :, ::2, ::2] = down * 4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def __call__(self, x, y):
y.stop_gradient = True
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
\ No newline at end of file
return loss
......@@ -20,6 +20,7 @@ from .base_model import BaseModel
from .generators.builder import build_generator
from .criterions.builder import build_criterion
from ..modules.init import reset_parameters, init_weights
from ..utils.visual import tensor2img
@MODELS.register()
......@@ -50,12 +51,12 @@ class MPRModel(BaseModel):
def setup_input(self, input):
self.target = input[0]
self.input_ = input[1]
self.lq = input[1]
def train_iter(self, optims=None):
optims['optim'].clear_gradients()
restored = self.nets['generator'](self.input_)
restored = self.nets['generator'](self.lq)
loss_char = []
loss_edge = []
......@@ -75,5 +76,21 @@ class MPRModel(BaseModel):
self.losses['loss'] = loss.numpy()
def forward(self):
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)[0]
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.target):
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册