未验证 提交 e1963bbc 编写于 作者: F FNRE 提交者: GitHub

add FID (#327)

* add FID
上级 97c1d594
......@@ -13,4 +13,5 @@
# limitations under the License.
from .psnr_ssim import PSNR, SSIM
from .fid import FID
from .builder import build_metric
......@@ -20,7 +20,9 @@ import paddle
from PIL import Image
from cv2 import imread
from scipy import linalg
from inception import InceptionV3
from .inception import InceptionV3
from paddle.utils.download import get_weights_path_from_url
from .builder import METRICS
try:
from tqdm import tqdm
......@@ -35,6 +37,42 @@ except:
"""
inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
"""
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"
@METRICS.register()
class FID(paddle.metric.Metric):
def __init__(self, batch_size=1, use_GPU=True, dims = 2048, premodel_path=None, model=None):
self.batch_size = batch_size
self.use_GPU = use_GPU
self.dims = dims
self.premodel_path = premodel_path
if model is None:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx])
if premodel_path is None:
premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
self.model = model
param_dict = paddle.load(premodel_path)
model.load_dict(param_dict)
model.eval()
self.reset()
def reset(self):
self.results = []
def update(self, preds, gts):
value = calculate_fid_given_img(preds, gts, self.batch_size, self.model, self.use_GPU, self.dims)
self.results.append(value)
def accumulate(self):
if len(self.results) <= 0:
return 0.
return np.mean(self.results)
def name(self):
return 'FID'
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
......@@ -71,8 +109,7 @@ def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
2 * tr_covmean)
def _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
premodel_path):
def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
n_batches = (len(img) + batch_size - 1) // batch_size
n_used_img = len(img)
......@@ -89,19 +126,13 @@ def _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
images /= 255
images = paddle.to_tensor(images)
param_dict, _ = paddle.load(premodel_path)
model.set_dict(param_dict)
model.eval()
pred = model(images)[0][0]
pred_arr[start:end] = pred.reshape(end - start, -1)
pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
return pred_arr
def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu,
premodel_path):
act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu,
premodel_path)
def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu):
act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu)
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
......@@ -110,22 +141,14 @@ def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu,
def calculate_fid_given_img(img_fake,
img_real,
batch_size,
use_gpu,
dims,
premodel_path,
model=None):
assert os.path.exists(
premodel_path
), 'pretrain_model path {} is not exists! Please download it first'.format(
premodel_path)
if model is None:
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx])
model,
use_gpu = True,
dims = 2048):
m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims,
use_gpu, premodel_path)
use_gpu)
m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims,
use_gpu, premodel_path)
use_gpu)
fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
return fid_value
......
......@@ -15,7 +15,7 @@
import math
import paddle
import paddle.nn as nn
from paddle.nn import Conv2D, AvgPool2D, MaxPool2D, BatchNorm, Linear
from paddle.nn import Conv2D, AvgPool2D, MaxPool2D, BatchNorm, Linear, AdaptiveAvgPool2D
__all__ = ['InceptionV3']
......@@ -57,7 +57,7 @@ class InceptionV3(nn.Layer):
3,
padding=1,
name='Conv2d_2b_3x3')
self.maxpool1 = MaxPool2D(pool_size=3, pool_stride=2)
self.maxpool1 = MaxPool2D(kernel_size=3, stride=2)
block0 = [
self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3,
......@@ -69,7 +69,7 @@ class InceptionV3(nn.Layer):
if self.last_needed_block >= 1:
self.Conv2d_3b_1x1 = ConvBNLayer(64, 80, 1, name='Conv2d_3b_1x1')
self.Conv2d_4a_3x3 = ConvBNLayer(80, 192, 3, name='Conv2d_4a_3x3')
self.maxpool2 = MaxPool2D(pool_size=3, pool_stride=2)
self.maxpool2 = MaxPool2D(kernel_size=3, stride=2)
block1 = [self.Conv2d_3b_1x1, self.Conv2d_4a_3x3, self.maxpool2]
self.blocks.append(nn.Sequential(*block1))
......@@ -107,7 +107,7 @@ class InceptionV3(nn.Layer):
self.Mixed_7a = InceptionD(768, name='Mixed_7a')
self.Mixed_7b = Fid_inceptionE_1(1280, name='Mixed_7b')
self.Mixed_7c = Fid_inceptionE_2(2048, name='Mixed_7c')
self.avgpool = AvgPool2D(global_pooling=True)
self.avgpool = AdaptiveAvgPool2D(output_size=1)
block3 = [self.Mixed_7a, self.Mixed_7b, self.Mixed_7c, self.avgpool]
self.blocks.append(nn.Sequential(*block3))
......@@ -170,9 +170,9 @@ class InceptionA(nn.Layer):
padding=1,
name=name + '.branch3x3dbl_3')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
pool_features,
......@@ -219,7 +219,7 @@ class InceptionB(nn.Layer):
stride=2,
name=name + '.branch3x3dbl_3')
self.branch_pool = MaxPool2D(pool_size=3, pool_stride=2)
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
branch3x3 = self.branch3x3(x)
......@@ -275,9 +275,9 @@ class InceptionC(nn.Layer):
padding=(0, 3),
name=name + '.branch7x7dbl_5')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
192,
......@@ -335,7 +335,7 @@ class InceptionD(nn.Layer):
stride=2,
name=name + '.branch7x7x3_4')
self.branch_pool = MaxPool2D(pool_size=3, pool_stride=2)
self.branch_pool = MaxPool2D(kernel_size=3, stride=2)
def forward(self, x):
branch3x3 = self.branch3x3_1(x)
......@@ -391,9 +391,9 @@ class InceptionE(nn.Layer):
padding=(1, 0),
name=name + '.branch3x3dbl_3b')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
192,
......@@ -425,7 +425,7 @@ class InceptionAux(nn.Layer):
def __init__(self, in_channels, num_classes, name=None):
super(InceptionAux, self).__init__()
self.num_classes = num_classes
self.pool0 = AvgPool2D(pool_size=5, pool_stride=3)
self.pool0 = AvgPool2D(kernel_size=5, stride=3)
self.conv0 = ConvBNLayer(in_channels, 128, 1, name=name + '.conv0')
self.conv1 = ConvBNLayer(128, 768, 5, name=name + '.conv1')
self.pool1 = AvgPool2D(global_pooling=True)
......@@ -475,9 +475,9 @@ class Fid_inceptionA(nn.Layer):
padding=1,
name=name + '.branch3x3dbl_3')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
pool_features,
......@@ -544,9 +544,9 @@ class Fid_inceptionC(nn.Layer):
padding=(0, 3),
name=name + '.branch7x7dbl_5')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
192,
......@@ -614,9 +614,9 @@ class Fid_inceptionE_1(nn.Layer):
padding=(1, 0),
name=name + '.branch3x3dbl_3b')
self.branch_pool0 = AvgPool2D(pool_size=3,
pool_stride=1,
pool_padding=1,
self.branch_pool0 = AvgPool2D(kernel_size=3,
stride=1,
padding=1,
exclusive=True)
self.branch_pool = ConvBNLayer(in_channels,
192,
......@@ -685,9 +685,9 @@ class Fid_inceptionE_2(nn.Layer):
padding=(1, 0),
name=name + '.branch3x3dbl_3b')
### same with paper
self.branch_pool0 = MaxPool2D(pool_size=3,
pool_stride=1,
pool_padding=1)
self.branch_pool0 = MaxPool2D(kernel_size=3,
stride=1,
padding=1)
self.branch_pool = ConvBNLayer(in_channels,
192,
1,
......@@ -725,14 +725,13 @@ class ConvBNLayer(nn.Layer):
act='relu',
name=None):
super(ConvBNLayer, self).__init__()
self.conv = Conv2D(num_channels=in_channels,
num_filters=num_filters,
filter_size=filter_size,
self.conv = Conv2D(in_channels=in_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
act=None,
param_attr=paddle.ParamAttr(name=name + ".conv.weight"),
weight_attr=paddle.ParamAttr(name=name + ".conv.weight"),
bias_attr=False)
self.bn = BatchNorm(num_filters,
act=act,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册