未验证 提交 b6204126 编写于 作者: L LielinJiang 提交者: GitHub

Fix mprnet train bug and add docs (#506)

* fix mprnet train and add docs

* update config
上级 5bf728df
total_iters: 100000 # epoch: 3000 for total batch size=16
total_iters: 400000
output_dir: output_dir output_dir: output_dir
model: model:
...@@ -15,38 +16,38 @@ dataset: ...@@ -15,38 +16,38 @@ dataset:
train: train:
name: MPRTrain name: MPRTrain
rgb_dir: 'data/GoPro/train' rgb_dir: 'data/GoPro/train'
num_workers: 16 num_workers: 4
batch_size: 4 batch_size: 2
img_options: img_options:
patch_size: 256 patch_size: 256
test: test:
name: MPRTrain name: MPRVal
rgb_dir: 'data/GoPro/test' rgb_dir: 'data/GoPro/test'
num_workers: 16 num_workers: 4
batch_size: 4 batch_size: 2
img_options: img_options:
patch_size: 256 patch_size: 256
lr_scheduler: lr_scheduler:
name: CosineAnnealingRestartLR name: CosineAnnealingRestartLR
learning_rate: !!float 2e-4 learning_rate: !!float 1e-4
periods: [25000, 25000, 25000, 25000] periods: [400000]
restart_weights: [1, 1, 1, 1] restart_weights: [1]
eta_min: !!float 1e-6 eta_min: !!float 1e-6
validate: validate:
interval: 10 interval: 5000
save_img: false save_img: false
metrics: metrics:
psnr: # metric name, can be arbitrary psnr: # metric name, can be arbitrary
name: PSNR name: PSNR
crop_border: 4 crop_border: 4
test_y_channel: True test_y_channel: false
ssim: ssim:
name: SSIM name: SSIM
crop_border: 4 crop_border: 4
test_y_channel: True test_y_channel: false
optimizer: optimizer:
name: Adam name: Adam
...@@ -59,7 +60,7 @@ optimizer: ...@@ -59,7 +60,7 @@ optimizer:
epsilon: 1e-8 epsilon: 1e-8
log_config: log_config:
interval: 10 interval: 100
visiual_interval: 5000 visiual_interval: 5000
snapshot_config: snapshot_config:
......
...@@ -130,6 +130,10 @@ The metrics are PSNR / SSIM. ...@@ -130,6 +130,10 @@ The metrics are PSNR / SSIM.
| pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 | | pan_x4 | 30.4574 / 0.8643 | 26.7204 / 0.7434 | 28.9187 / 0.8176 |
| drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - | | drns_x4 | 32.6684 / 0.8999 | 28.9037 / 0.7885 | - |
Deblur models zoo
| model | GoPro | Download Link |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [link](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) --> <!-- ![](../../imgs/horse2zebra.png) -->
......
...@@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集 ...@@ -120,6 +120,10 @@ paddle模型使用DIV2K数据集训练,torch模型使用df2k和DIV2K数据集
| paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 | | paddle | 30.4574 / 0.8643 | 26.7204 / 0.7434 |
| torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 | | torch | 30.2183 / 0.8643 | 26.8035 / 0.7445 |
去模糊模型
| 模型 | GoPro | 下载地址 |
|---|---|---|
| MPRNet | 33.4360 / 0.9410 | [链接](https://paddlegan.bj.bcebos.com/models/MPR_Deblurring.pdparams) |
<!-- ![](../../imgs/horse2zebra.png) --> <!-- ![](../../imgs/horse2zebra.png) -->
......
...@@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load ...@@ -29,6 +29,7 @@ from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager from ..utils.timer import TimeAverager
from ..utils.profiler import add_profiler_step from ..utils.profiler import add_profiler_step
class IterLoader: class IterLoader:
def __init__(self, dataloader): def __init__(self, dataloader):
self._dataloader = dataloader self._dataloader = dataloader
...@@ -429,6 +430,17 @@ class Trainer: ...@@ -429,6 +430,17 @@ class Trainer:
def load(self, weight_path): def load(self, weight_path):
state_dicts = load(weight_path) state_dicts = load(weight_path)
def is_dict_in_dict_weight(state_dict):
if isinstance(state_dict, dict) and len(state_dict) > 0:
val = list(state_dict.values())[0]
if isinstance(val, dict):
return True
else:
return False
else:
return False
if is_dict_in_dict_weight(state_dicts):
for net_name, net in self.model.nets.items(): for net_name, net in self.model.nets.items():
if net_name in state_dicts: if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name]) net.set_state_dict(state_dicts[net_name])
...@@ -438,6 +450,15 @@ class Trainer: ...@@ -438,6 +450,15 @@ class Trainer:
self.logger.warning( self.logger.warning(
'Can not find state dict of net {}. Skip load pretrained weight for net {}' 'Can not find state dict of net {}. Skip load pretrained weight for net {}'
.format(net_name, net_name)) .format(net_name, net_name))
else:
assert len(self.model.nets
) == 1, 'checkpoint only contain weight of one net, \
but model contains more than one net!'
net_name, net = list(self.model.nets.items())[0]
net.set_state_dict(state_dicts)
self.logger.info(
'Loaded pretrained weight for net {}'.format(net_name))
def close(self): def close(self):
""" """
......
...@@ -249,23 +249,25 @@ class CalcStyleLoss(): ...@@ -249,23 +249,25 @@ class CalcStyleLoss():
class EdgeLoss(): class EdgeLoss():
def __init__(self): def __init__(self):
k = paddle.to_tensor([[.05, .25, .4, .25, .05]]) k = paddle.to_tensor([[.05, .25, .4, .25, .05]])
self.kernel = paddle.matmul(k.t(),k).unsqueeze(0).tile([3,1,1,1]) self.kernel = paddle.matmul(k.t(), k).unsqueeze(0).tile([3, 1, 1, 1])
self.loss = CharbonnierLoss() self.loss = CharbonnierLoss()
def conv_gauss(self, img): def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, [kw//2, kh//2, kw//2, kh//2], mode='replicate') img = F.pad(img, [kw // 2, kh // 2, kw // 2, kh // 2], mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels) return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current): def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter filtered = self.conv_gauss(current) # filter
down = filtered[:,:,::2,::2] # downsample down = filtered[:, :, ::2, ::2] # downsample
new_filter = paddle.zeros_like(filtered) new_filter = paddle.zeros_like(filtered)
new_filter[:,:,::2,::2] = down*4 # upsample new_filter.stop_gradient = True
new_filter[:, :, ::2, ::2] = down * 4 # upsample
filtered = self.conv_gauss(new_filter) # filter filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered diff = current - filtered
return diff return diff
def __call__(self, x, y): def __call__(self, x, y):
y.stop_gradient = True
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss return loss
...@@ -20,6 +20,7 @@ from .base_model import BaseModel ...@@ -20,6 +20,7 @@ from .base_model import BaseModel
from .generators.builder import build_generator from .generators.builder import build_generator
from .criterions.builder import build_criterion from .criterions.builder import build_criterion
from ..modules.init import reset_parameters, init_weights from ..modules.init import reset_parameters, init_weights
from ..utils.visual import tensor2img
@MODELS.register() @MODELS.register()
...@@ -50,12 +51,12 @@ class MPRModel(BaseModel): ...@@ -50,12 +51,12 @@ class MPRModel(BaseModel):
def setup_input(self, input): def setup_input(self, input):
self.target = input[0] self.target = input[0]
self.input_ = input[1] self.lq = input[1]
def train_iter(self, optims=None): def train_iter(self, optims=None):
optims['optim'].clear_gradients() optims['optim'].clear_gradients()
restored = self.nets['generator'](self.input_) restored = self.nets['generator'](self.lq)
loss_char = [] loss_char = []
loss_edge = [] loss_edge = []
...@@ -75,5 +76,21 @@ class MPRModel(BaseModel): ...@@ -75,5 +76,21 @@ class MPRModel(BaseModel):
self.losses['loss'] = loss.numpy() self.losses['loss'] = loss.numpy()
def forward(self): def forward(self):
"""Run forward pass; called by both functions <train_iter> and <test_iter>."""
pass pass
def test_iter(self, metrics=None):
self.nets['generator'].eval()
with paddle.no_grad():
self.output = self.nets['generator'](self.lq)[0]
self.visual_items['output'] = self.output
self.nets['generator'].train()
out_img = []
gt_img = []
for out_tensor, gt_tensor in zip(self.output, self.target):
out_img.append(tensor2img(out_tensor, (0., 1.)))
gt_img.append(tensor2img(gt_tensor, (0., 1.)))
if metrics is not None:
for metric in metrics.values():
metric.update(out_img, gt_img)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册