Unverified commit e1963bbc, authored by FNRE and committed by GitHub

add FID (#327)

* add FID
Parent 97c1d594
@@ -13,4 +13,5 @@
 # limitations under the License.
 from .psnr_ssim import PSNR, SSIM
+from .fid import FID
 from .builder import build_metric
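Re-exporting `FID` from the metrics package, together with the `@METRICS.register()` decorator added in `fid.py` below, makes the metric reachable both by direct import and by name through `build_metric`. A rough sketch of the registry route follows; the `ppgan.metrics` import path and the name-keyed config dict are assumptions about the surrounding project, not something this diff shows:

```python
# Sketch only: construct the newly registered metric from a config entry.
# Both the import path and the 'name' key convention are assumptions.
from ppgan.metrics import build_metric

fid_metric = build_metric({'name': 'FID', 'batch_size': 1, 'use_GPU': True})
```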
@@ -20,7 +20,9 @@ import paddle
 from PIL import Image
 from cv2 import imread
 from scipy import linalg
-from inception import InceptionV3
+from .inception import InceptionV3
+from paddle.utils.download import get_weights_path_from_url
+from .builder import METRICS

 try:
     from tqdm import tqdm
@@ -35,6 +37,42 @@ except:
 """
 inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
 """
+INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
+
+
+@METRICS.register()
+class FID(paddle.metric.Metric):
+    def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None):
+        self.batch_size = batch_size
+        self.use_GPU = use_GPU
+        self.dims = dims
+        self.premodel_path = premodel_path
+        if model is None:
+            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+            model = InceptionV3([block_idx])
+        if premodel_path is None:
+            premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
+        self.model = model
+        param_dict = paddle.load(premodel_path)
+        model.load_dict(param_dict)
+        model.eval()
+        self.reset()
+
+    def reset(self):
+        self.results = []
+
+    def update(self, preds, gts):
+        value = calculate_fid_given_img(preds, gts, self.batch_size,
+                                        self.model, self.use_GPU, self.dims)
+        self.results.append(value)
+
+    def accumulate(self):
+        if len(self.results) <= 0:
+            return 0.
+        return np.mean(self.results)
+
+    def name(self):
+        return 'FID'


 def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
@@ -71,13 +109,12 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
             2 * tr_covmean)


-def _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
-                              premodel_path):
+def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
     n_batches = (len(img) + batch_size - 1) // batch_size
     n_used_img = len(img)

     pred_arr = np.empty((n_used_img, dims))

     for i in tqdm(range(n_batches)):
         start = i * batch_size
         end = start + batch_size
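For reference, `_calculate_frechet_distance` above computes the standard FID between the two Gaussians fitted to the InceptionV3 activations of the generated and real image sets, with means mu1, mu2 and covariances Sigma1, Sigma2:

```latex
% Fréchet distance between N(\mu_1, \Sigma_1) and N(\mu_2, \Sigma_2)
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^{2}
             + \operatorname{Tr}\!\bigl( \Sigma_1 + \Sigma_2 - 2\,(\Sigma_1 \Sigma_2)^{1/2} \bigr)
```

`tr_covmean` in the return statement corresponds to Tr((Σ₁Σ₂)^½), computed with `scipy.linalg.sqrtm`; `eps` is the small diagonal offset conventionally added when that matrix square root is numerically singular. The body of the function is collapsed in this view, so read this as the standard interpretation rather than a quote of the code.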
@@ -89,19 +126,13 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
         images /= 255
         images = paddle.to_tensor(images)
-        param_dict, _ = paddle.load(premodel_path)
-        model.set_dict(param_dict)
-        model.eval()

         pred = model(images)[0][0]
-        pred_arr[start:end] = pred.reshape(end - start, -1)
+        pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
     return pred_arr


-def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu,
-                              premodel_path):
-    act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
-                                    premodel_path)
+def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu):
+    act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu)
     mu = np.mean(act, axis=0)
     sigma = np.cov(act, rowvar=False)
     return mu, sigma
@@ -110,22 +141,14 @@ def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu,
 def calculate_fid_given_img(img_fake,
                             img_real,
                             batch_size,
-                            use_gpu,
-                            dims,
-                            premodel_path,
-                            model=None):
-    assert os.path.exists(
-        premodel_path
-    ), 'pretrain_model path {} is not exists! Please download it first'.format(
-        premodel_path)
-    if model is None:
-        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
-        model = InceptionV3([block_idx])
+                            model,
+                            use_gpu = True,
+                            dims = 2048):

     m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims,
-                                       use_gpu, premodel_path)
+                                       use_gpu)
     m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims,
-                                       use_gpu, premodel_path)
+                                       use_gpu)

     fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
     return fid_value
...
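Taken together, the new `FID` class wraps these helpers behind the standard `paddle.metric.Metric` interface (`update` / `accumulate` / `reset` / `name`). A minimal usage sketch follows; the `ppgan.metrics.fid` import path and the expected input layout are assumptions, since neither is stated in this diff:

```python
# Illustrative sketch only. Assumptions: the module imports as
# ppgan.metrics.fid, and images are float32 NCHW batches in [0, 255]
# (the preprocessing lines of _get_activations_from_ims are collapsed above).
import numpy as np
from ppgan.metrics.fid import FID

fid = FID()  # batch_size=1, dims=2048; downloads the InceptionV3 weights on first use

# Placeholders standing in for generator outputs and reference images;
# a meaningful FID needs many real samples, not random noise.
fake = np.random.rand(8, 3, 299, 299).astype('float32') * 255
real = np.random.rand(8, 3, 299, 299).astype('float32') * 255

fid.update(fake, real)               # appends one FID value to fid.results
print(fid.name(), fid.accumulate())  # 'FID' and the mean of accumulated values
fid.reset()
```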
@@ -15,7 +15,7 @@
 import math
 import paddle
 import paddle.nn as nn
-from paddle.nn import Conv2D, AvgPool2D, MaxPool2D, BatchNorm, Linear
+from paddle.nn import Conv2D, AvgPool2D, MaxPool2D, BatchNorm, Linear, AdaptiveAvgPool2D

 __all__ = ['InceptionV3']
@@ -57,7 +57,7 @@ class InceptionV3(nn.Layer):
                                          3,
                                          padding=1,
                                          name='Conv2d_2b_3x3')
-        self.maxpool1 = MaxPool2D(pool_size=3, pool_stride=2)
+        self.maxpool1 = MaxPool2D(kernel_size=3, stride=2)

         block0 = [
             self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3,
@@ -69,7 +69,7 @@ class InceptionV3(nn.Layer):
         if self.last_needed_block >= 1:
             self.Conv2d_3b_1x1 = ConvBNLayer(64, 80, 1, name='Conv2d_3b_1x1')
             self.Conv2d_4a_3x3 = ConvBNLayer(80, 192, 3, name='Conv2d_4a_3x3')
-            self.maxpool2 = MaxPool2D(pool_size=3, pool_stride=2)
+            self.maxpool2 = MaxPool2D(kernel_size=3, stride=2)

             block1 = [self.Conv2d_3b_1x1, self.Conv2d_4a_3x3, self.maxpool2]
             self.blocks.append(nn.Sequential(*block1))
@@ -107,7 +107,7 @@ class InceptionV3(nn.Layer):
             self.Mixed_7a = InceptionD(768, name='Mixed_7a')
             self.Mixed_7b = Fid_inceptionE_1(1280, name='Mixed_7b')
             self.Mixed_7c = Fid_inceptionE_2(2048, name='Mixed_7c')
-            self.avgpool = AvgPool2D(global_pooling=True)
+            self.avgpool = AdaptiveAvgPool2D(output_size=1)

             block3 = [self.Mixed_7a, self.Mixed_7b, self.Mixed_7c, self.avgpool]
             self.blocks.append(nn.Sequential(*block3))
@@ -170,9 +170,9 @@ class InceptionA(nn.Layer):
                                            padding=1,
                                            name=name + '.branch3x3dbl_3')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        pool_features,
@@ -219,7 +219,7 @@ class InceptionB(nn.Layer):
                                            stride=2,
                                            name=name + '.branch3x3dbl_3')
-        self.branch_pool = MaxPool2D(pool_size=3, pool_stride=2)
+        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

     def forward(self, x):
         branch3x3 = self.branch3x3(x)
@@ -275,9 +275,9 @@ class InceptionC(nn.Layer):
                                            padding=(0, 3),
                                            name=name + '.branch7x7dbl_5')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
@@ -335,7 +335,7 @@ class InceptionD(nn.Layer):
                                          stride=2,
                                          name=name + '.branch7x7x3_4')
-        self.branch_pool = MaxPool2D(pool_size=3, pool_stride=2)
+        self.branch_pool = MaxPool2D(kernel_size=3, stride=2)

     def forward(self, x):
         branch3x3 = self.branch3x3_1(x)
@@ -391,9 +391,9 @@ class InceptionE(nn.Layer):
                                             padding=(1, 0),
                                             name=name + '.branch3x3dbl_3b')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
@@ -425,7 +425,7 @@ class InceptionAux(nn.Layer):
     def __init__(self, in_channels, num_classes, name=None):
         super(InceptionAux, self).__init__()
         self.num_classes = num_classes
-        self.pool0 = AvgPool2D(pool_size=5, pool_stride=3)
+        self.pool0 = AvgPool2D(kernel_size=5, stride=3)
         self.conv0 = ConvBNLayer(in_channels, 128, 1, name=name + '.conv0')
         self.conv1 = ConvBNLayer(128, 768, 5, name=name + '.conv1')
         self.pool1 = AvgPool2D(global_pooling=True)
@@ -475,9 +475,9 @@ class Fid_inceptionA(nn.Layer):
                                            padding=1,
                                            name=name + '.branch3x3dbl_3')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        pool_features,
@@ -544,9 +544,9 @@ class Fid_inceptionC(nn.Layer):
                                            padding=(0, 3),
                                            name=name + '.branch7x7dbl_5')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
@@ -614,9 +614,9 @@ class Fid_inceptionE_1(nn.Layer):
                                             padding=(1, 0),
                                             name=name + '.branch3x3dbl_3b')
-        self.branch_pool0 = AvgPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1,
+        self.branch_pool0 = AvgPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1,
                                       exclusive=True)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
@@ -685,9 +685,9 @@ class Fid_inceptionE_2(nn.Layer):
                                             padding=(1, 0),
                                             name=name + '.branch3x3dbl_3b')
         ### same with paper
-        self.branch_pool0 = MaxPool2D(pool_size=3,
-                                      pool_stride=1,
-                                      pool_padding=1)
+        self.branch_pool0 = MaxPool2D(kernel_size=3,
+                                      stride=1,
+                                      padding=1)
         self.branch_pool = ConvBNLayer(in_channels,
                                        192,
                                        1,
@@ -725,14 +725,13 @@ class ConvBNLayer(nn.Layer):
                  act='relu',
                  name=None):
         super(ConvBNLayer, self).__init__()
-        self.conv = Conv2D(num_channels=in_channels,
-                           num_filters=num_filters,
-                           filter_size=filter_size,
+        self.conv = Conv2D(in_channels=in_channels,
+                           out_channels=num_filters,
+                           kernel_size=filter_size,
                            stride=stride,
                            padding=padding,
                            groups=groups,
-                           act=None,
-                           param_attr=paddle.ParamAttr(name=name + ".conv.weight"),
+                           weight_attr=paddle.ParamAttr(name=name + ".conv.weight"),
                            bias_attr=False)
         self.bn = BatchNorm(num_filters,
                             act=act,
...
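The inception.py changes are a mechanical migration from Paddle 1.x fluid-style layer arguments to the Paddle 2.x `paddle.nn` API: pooling layers take `kernel_size`/`stride`/`padding` instead of `pool_size`/`pool_stride`/`pool_padding`, global average pooling becomes `AdaptiveAvgPool2D(output_size=1)`, and `Conv2D` takes `in_channels`/`out_channels`/`kernel_size`/`weight_attr` with no `act` argument (the activation stays on the following `BatchNorm`, as the unchanged context lines show). A small sketch of the equivalence, with arbitrarily chosen shapes:

```python
import paddle
import paddle.nn as nn

x = paddle.rand([1, 8, 6, 6])

# Paddle 2.x layers as used after this commit:
pool = nn.MaxPool2D(kernel_size=3, stride=2)      # was MaxPool2D(pool_size=3, pool_stride=2)
gap = nn.AdaptiveAvgPool2D(output_size=1)         # was AvgPool2D(global_pooling=True)
conv = nn.Conv2D(in_channels=8, out_channels=16,  # was Conv2D(num_channels=..., num_filters=...,
                 kernel_size=3, bias_attr=False)  #     filter_size=..., act=None, param_attr=...)

print(pool(x).shape)  # [1, 8, 2, 2]
print(gap(x).shape)   # [1, 8, 1, 1]
print(conv(x).shape)  # [1, 16, 4, 4]
```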