未验证 提交 9e783d8d 编写于 作者: L LielinJiang 提交者: GitHub

add lpips metric (#339)

* add lpips
上级 170a2b05
...@@ -119,6 +119,8 @@ validate: ...@@ -119,6 +119,8 @@ validate:
name: SSIM name: SSIM
crop_border: 4 crop_border: 4
test_y_channel: false test_y_channel: false
lpips:
name: LPIPSMetric
log_config: log_config:
interval: 100 interval: 100
......
...@@ -14,4 +14,5 @@ ...@@ -14,4 +14,5 @@
from .psnr_ssim import PSNR, SSIM from .psnr_ssim import PSNR, SSIM
from .fid import FID from .fid import FID
from .lpips import LPIPSMetric
from .builder import build_metric from .builder import build_metric
#Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from collections import namedtuple
import numpy as np
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
from ..modules import init
from ..models.criterions.perceptual_loss import PerceptualVGG
from .builder import METRICS
lpips = True
VGG16_TORCHVISION_URL = 'https://paddlegan.bj.bcebos.com/models/vgg16_from_torch.pdparams'
LINS_01_VGG_URL = 'https://paddlegan.bj.bcebos.com/models/lins_0.1_vgg.pdparams'
@METRICS.register()
class LPIPSMetric(paddle.metric.Metric):
"""Calculate LPIPS (Learned Perceptual Image Patch Similarity).
Ref: https://arxiv.org/abs/1801.03924
Args:
net (str): Type of backbone net. Default: 'vgg'.
version (str): Version of lpips method. Defalut: '0.1'.
mean (list): Sequence of means for each channel of input image. Default: None.
std (list): Sequence of standard deviations for each channel of input image. Default: None.
Returns:
float: lpips result.
"""
def __init__(self, net='vgg', version='0.1', mean=None, std=None):
self.net = net
self.version = version
self.loss_fn = LPIPS(net=net, version=version)
if mean is None:
self.mean = [0.5, 0.5, 0.5]
else:
self.mean = mean
if std is None:
self.std = [0.5, 0.5, 0.5]
else:
self.std = std
self.reset()
def reset(self):
self.results = []
def update(self, preds, gts):
if not isinstance(preds, (list, tuple)):
preds = [preds]
if not isinstance(gts, (list, tuple)):
gts = [gts]
for pred, gt in zip(preds, gts):
pred, gt = pred.astype(np.float32) / 255., gt.astype(
np.float32) / 255.
pred = paddle.vision.transforms.normalize(pred.transpose([2, 0, 1]),
self.mean, self.std)
gt = paddle.vision.transforms.normalize(gt.transpose([2, 0, 1]),
self.mean, self.std)
with paddle.no_grad():
value = self.loss_fn(
paddle.to_tensor(pred).unsqueeze(0),
paddle.to_tensor(gt).unsqueeze(0))
self.results.append(value.item())
def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
def name(self):
return 'LPIPS'
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2, 3], keepdim=keepdim)
# assumes scale factor is same for H and W
def upsample(in_tens, out_HW=(64, 64)):
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
scale_factor_H, scale_factor_W = 1. * out_HW[0] / in_H, 1. * out_HW[1] / in_W
return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W),
mode='bilinear',
align_corners=False)(in_tens)
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = paddle.sqrt(paddle.sum(in_feat**2, 1, keepdim=True))
return in_feat / (norm_factor + eps)
# Learned perceptual metric
class LPIPS(nn.Layer):
def __init__(self,
pretrained=True,
net='vgg',
version='0.1',
lpips=True,
spatial=False,
pnet_rand=False,
pnet_tune=False,
use_dropout=True,
model_path=None,
eval_mode=True,
verbose=True):
# lpips - [True] means with linear calibration on top of base network
# pretrained - [True] means load linear weights
super(LPIPS, self).__init__()
if (verbose):
print(
'Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'
% ('LPIPS' if lpips else 'baseline', net, version,
'on' if spatial else 'off'))
self.pnet_type = net
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
# false means baseline of just averaging all layers
self.lpips = lpips
self.version = version
self.scaling_layer = ScalingLayer()
if (self.pnet_type in ['vgg', 'vgg16']):
net_type = vgg16
self.chns = [64, 128, 256, 512, 512]
elif (self.pnet_type == 'alex'):
raise TypeError('alex not support now!')
elif (self.pnet_type == 'squeeze'):
raise TypeError('squeeze not support now!')
self.L = len(self.chns)
self.net = net_type(pretrained=True, requires_grad=False)
if (lpips):
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
# TODO: add alex and squeezenet
# 7 layers for squeezenet
self.lins = nn.LayerList(self.lins)
if (pretrained):
if (model_path is None):
model_path = get_weights_path_from_url(LINS_01_VGG_URL)
if (verbose):
print('Loading model from: %s' % model_path)
self.lins.set_state_dict(paddle.load(model_path))
if (eval_mode):
self.eval()
def forward(self, in0, in1, retPerLayer=False, normalize=False):
# turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
if normalize:
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (
self.scaling_layer(in0),
self.scaling_layer(in1)) if self.version == '0.1' else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(
outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk])**2
if (self.lpips):
if (self.spatial):
res = [
upsample(self.lins[kk].model(diffs[kk]),
out_HW=in0.shape[2:]) for kk in range(self.L)
]
else:
res = [
spatial_average(self.lins[kk].model(diffs[kk]),
keepdim=True) for kk in range(self.L)
]
else:
if (self.spatial):
res = [
upsample(diffs[kk].sum(dim=1, keepdim=True),
out_HW=in0.shape[2:]) for kk in range(self.L)
]
else:
res = [
spatial_average(diffs[kk].sum(dim=1, keepdim=True),
keepdim=True) for kk in range(self.L)
]
val = res[0]
for l in range(1, self.L):
val += res[l]
if (retPerLayer):
return (val, res)
else:
return val
class ScalingLayer(nn.Layer):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer(
'shift',
paddle.to_tensor([-.030, -.088, -.188]).reshape([1, 3, 1, 1]))
self.register_buffer(
'scale',
paddle.to_tensor([.458, .448, .450]).reshape([1, 3, 1, 1]))
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Layer):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [
nn.Dropout(),
] if (use_dropout) else []
layers += [
nn.Conv2D(chn_in, chn_out, 1, stride=1, padding=0, bias_attr=False),
]
self.model = nn.Sequential(*layers)
class vgg16(nn.Layer):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
self.vgg16 = PerceptualVGG(['3', '8', '15', '22', '29'], 'vgg16', False,
VGG16_TORCHVISION_URL)
if not requires_grad:
for param in self.parameters():
param.trainable = False
def forward(self, x):
out = self.vgg16(x)
vgg_outputs = namedtuple(
"VggOutputs",
['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(out['3'], out['8'], out['15'], out['22'], out['29'])
return out
...@@ -21,7 +21,6 @@ from abc import ABC, abstractmethod ...@@ -21,7 +21,6 @@ from abc import ABC, abstractmethod
from .criterions.builder import build_criterion from .criterions.builder import build_criterion
from ..solver import build_lr_scheduler, build_optimizer from ..solver import build_lr_scheduler, build_optimizer
from ..metrics import build_metric
from ..utils.visual import tensor2img from ..utils.visual import tensor2img
...@@ -133,6 +132,7 @@ class BaseModel(ABC): ...@@ -133,6 +132,7 @@ class BaseModel(ABC):
return self.optimizers return self.optimizers
def setup_metrics(self, cfg): def setup_metrics(self, cfg):
from ..metrics import build_metric
if isinstance(list(cfg.values())[0], dict): if isinstance(list(cfg.values())[0], dict):
for metric_name, cfg_ in cfg.items(): for metric_name, cfg_ in cfg.items():
self.metrics[metric_name] = build_metric(cfg_) self.metrics[metric_name] = build_metric(cfg_)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册