fid.py 10.2 KB
Newer Older
L
lzzyzlbb 已提交
1 2 3 4
# code was heavily based on https://github.com/mseitzer/pytorch-fid
# This implementation is licensed under the Apache License 2.0.
# Copyright (c) mseitzer

L
lijianshe02 已提交
5 6 7 8 9

import os
import fnmatch
import numpy as np
import cv2
W
wangna11BD 已提交
10
import paddle
M
Mark Ma 已提交
11
from PIL import Image
L
lijianshe02 已提交
12 13
from cv2 import imread
from scipy import linalg
F
FNRE 已提交
14 15 16
from .inception import InceptionV3
from paddle.utils.download import get_weights_path_from_url
from .builder import METRICS
L
lijianshe02 已提交
17

M
Mark Ma 已提交
18 19 20
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional. Only a missing package should activate the fallback;
    # a bare ``except`` would also hide unrelated failures such as
    # KeyboardInterrupt raised while importing.

    def tqdm(x):
        """Return ``x`` unchanged; stand-in used when tqdm is not installed."""
        return x
L
lijianshe02 已提交
24 25 26 27 28 29 30


""" based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
"""
"""
The pretrained InceptionV3 model was converted from PyTorch; the pretrained model URL is https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
"""
F
FNRE 已提交
31 32
# Location of the InceptionV3 weights that FID.__init__ fetches through
# get_weights_path_from_url when it is not given a premodel_path.
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"

L
LielinJiang 已提交
33

F
FNRE 已提交
34 35
@METRICS.register()
class FID(paddle.metric.Metric):
    """Frechet Inception Distance metric.

    Activations of generated and reference images are collected batch by
    batch through ``update`` and reduced to a single FID value by
    ``accumulate``.
    """

    def __init__(self,
                 batch_size=1,
                 use_GPU=True,
                 dims=2048,
                 premodel_path=None,
                 model=None):
        """Store the settings and load weights into the feature network.

        Args:
            batch_size: number of images per forward pass.
            use_GPU: stored and forwarded to ``calculate_inception_val``.
            dims: key into ``InceptionV3.BLOCK_INDEX_BY_DIM``; also the row
                width of the activation arrays.
            premodel_path: weights file to load. When ``None`` the file
                behind ``INCEPTIONV3_WEIGHT_URL`` is fetched instead.
            model: network to evaluate with. When ``None`` an
                ``InceptionV3`` is built for ``dims``.
        """
        self.batch_size = batch_size
        self.use_GPU = use_GPU
        self.dims = dims
        # Kept exactly as passed in: stays None when the default weights
        # are downloaded below.
        self.premodel_path = premodel_path

        network = model
        if network is None:
            network = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]],
                                  normalize_input=False)

        weights_file = premodel_path
        if weights_file is None:
            weights_file = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)

        self.model = network
        # The weights are loaded even when the caller supplied ``model``.
        self.model.load_dict(paddle.load(weights_file))
        self.model.eval()
        self.reset()

    def reset(self):
        """Drop every activation collected so far."""
        self.preds = []
        self.gts = []
        self.results = []

    def update(self, preds, gts):
        """Extract activations for one batch pair and queue them."""
        fake_act, real_act = calculate_inception_val(preds, gts,
                                                     self.batch_size,
                                                     self.model, self.use_GPU,
                                                     self.dims)
        self.preds.append(fake_act)
        self.gts.append(real_act)

    def accumulate(self):
        """Return the FID over all queued activations, then reset."""
        self.preds = np.concatenate(self.preds, axis=0)
        self.gts = np.concatenate(self.gts, axis=0)
        score = calculate_fid_given_img(self.preds, self.gts)
        self.reset()
        return score

    def name(self):
        """Return the metric's display name."""
        return 'FID'


L
lijianshe02 已提交
79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet distance between two Gaussians.

    The value returned is
    ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2))``.

    Args:
        mu1: mean of the first distribution (array-like or scalar).
        sigma1: covariance of the first distribution.
        mu2: mean of the second distribution.
        sigma2: covariance of the second distribution.
        eps: value added to both covariance diagonals when the matrix
            square root of their product is not finite.

    Returns:
        The distance as a float-like scalar.

    Raises:
        AssertionError: if the means or the covariances differ in shape.
        ValueError: if the matrix square root keeps an imaginary diagonal
            component larger than 1e-3.
    """
    # Normalise the inputs and actually use the normalised values; the
    # shape checks and the subtraction below then also work for lists and
    # plain scalars.
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        # Nudge both covariances away from singularity and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real
    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
            2 * tr_covmean)


F
FNRE 已提交
112
def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
    """Feed in-memory images through ``model`` and collect the activations.

    Args:
        img: indexable batch of images; sliced along its first axis.
        model: callable network; ``model(x)[0][0]`` is taken as the feature
            output for a batch.
        batch_size: number of images per forward pass (last batch may be
            smaller).
        dims: width of each activation row.
        use_gpu: accepted for interface compatibility; not read here.

    Returns:
        ``numpy`` array of shape ``(len(img), dims)``.
    """
    total = len(img)
    n_batches = (total + batch_size - 1) // batch_size
    activations = np.empty((total, dims))

    for batch_idx in range(n_batches):
        lo = batch_idx * batch_size
        hi = min(lo + batch_size, total)

        batch = img[lo:hi]
        # A second axis that is not 3 is treated as channels-last layout
        # and moved to channels-first.
        channels_first = batch.shape[1] == 3
        if not channels_first:
            batch = batch.transpose((0, 3, 1, 2))

        features = model(paddle.to_tensor(batch))[0][0]
        flat = features.reshape([hi - lo, -1])
        activations[lo:hi] = flat.cpu().numpy()

    return activations


L
lzzyzlbb 已提交
133
def _compute_statistic_of_img(act):
    """Return ``(mean, covariance)`` of ``act``, one sample per row."""
    return np.mean(act, axis=0), np.cov(act, rowvar=False)

L
LielinJiang 已提交
138

L
lzzyzlbb 已提交
139
def calculate_inception_val(img_fake,
                            img_real,
                            batch_size,
                            model,
                            use_gpu=True,
                            dims=2048):
    """Extract activations for generated and reference images.

    Both inputs go through ``_get_activations_from_ims`` with the same
    model and settings, generated images first.

    Returns:
        Tuple ``(act_fake, act_real)``.
    """
    return tuple(
        _get_activations_from_ims(images, model, batch_size, dims, use_gpu)
        for images in (img_fake, img_real))
L
lijianshe02 已提交
150

L
LielinJiang 已提交
151

L
lzzyzlbb 已提交
152
def calculate_fid_given_img(act_fake, act_real):
    """Compute the FID between two sets of activations.

    Each set is summarised by its mean and covariance, and the Frechet
    distance between the two summaries is returned.
    """
    fake_mu, fake_sigma = _compute_statistic_of_img(act_fake)
    real_mu, real_sigma = _compute_statistic_of_img(act_real)
    return _calculate_frechet_distance(fake_mu, fake_sigma, real_mu,
                                       real_sigma)


Q
qingqing01 已提交
160 161 162 163 164 165 166
def _get_activations(files,
                     model,
                     batch_size,
                     dims,
                     use_gpu,
                     premodel_path,
                     style=None):
    """Run image files through the Inception network, batch by batch.

    Args:
        files: sequence of image file paths.
        model: network used when ``style`` is not ``'stargan'``.
        batch_size: images per batch; clamped to ``len(files)``.
        dims: width of each activation row.
        use_gpu: accepted for interface compatibility; not read here.
        premodel_path: weights file (or, for ``'stargan'``, the path handed
            to ``inception_infer``).
        style: ``'stargan'`` switches to PIL loading with a 256 -> 299
            resize, ImageNet normalization and ``inception_infer``.

    Returns:
        ``numpy`` array of shape ``(n_used_imgs, dims)``. Trailing files
        that do not fill a whole batch are not processed.

    Raises:
        ValueError: if ``files`` is empty.
    """
    if len(files) == 0:
        # Without this guard the batch size is clamped to 0 below and the
        # batch count fails with an unexplained ZeroDivisionError.
        raise ValueError('No image files were given; cannot compute '
                         'activations')

    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the datasets size. '
               'Setting batch size to datasets size'))
        batch_size = len(files)

    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    is_stargan = style == 'stargan'
    if is_stargan:
        # imagenet normalization constants (loop invariant)
        mean = np.array([0.485, 0.456, 0.406]).astype('float32')
        std = np.array([0.229, 0.224, 0.225]).astype('float32')
    else:
        # The weights do not change between batches, so load them once
        # instead of on every iteration.
        # NOTE(review): ``paddle.guard`` and the two-value result of
        # ``paddle.load`` look like an older Paddle API - confirm against
        # the installed version.
        with paddle.guard():
            param_dict, _ = paddle.load(premodel_path)
            model.set_dict(param_dict)
            model.eval()

    pred_arr = np.empty((n_used_imgs, dims))
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = start + batch_size
        batch_files = files[start:end]

        # same as stargan-v2 official implementation: resize to 256 first, then resize to 299
        if is_stargan:
            img_list = []
            for f in batch_files:
                im = Image.open(str(f)).convert('RGB')
                if im.size[0] != 299:
                    # 2 is passed as the resampling filter code.
                    im = im.resize((256, 256), 2)
                    im = im.resize((299, 299), 2)

                img_list.append(np.array(im).astype('float32'))

            images = np.array(img_list)
        else:
            images = np.array(
                [imread(str(f)).astype(np.float32) for f in batch_files])

        if len(images.shape) != 4:
            # NOTE(review): this fallback keeps only the first file of the
            # batch and, assuming cvtColor returns a 2-D image, produces a
            # 3-D array that the 4-axis transpose below would reject -
            # confirm the intended handling.
            images = imread(str(files[start]))
            images = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
            images = np.array([images.astype(np.float32)])

        images = images.transpose((0, 3, 1, 2))
        images /= 255

        if is_stargan:
            images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]
            pred_arr[start:end] = inception_infer(images, premodel_path)
        else:
            with paddle.guard():
                images = paddle.to_tensor(images)
                pred = model(images)[0][0].numpy()

                pred_arr[start:end] = pred.reshape(end - start, -1)

    return pred_arr


M
Mark Ma 已提交
229
def inception_infer(x, model_path):
    """Run ``x`` through a saved inference model and return its first output.

    The model at ``model_path`` is loaded with
    ``paddle.static.load_inference_model`` on every call; ``x`` is fed to
    the first feed target.
    """
    executor = paddle.static.Executor()
    loaded = paddle.static.load_inference_model(model_path, executor)
    program, feed_names, fetch_vars = loaded
    outputs = executor.run(program,
                           feed={feed_names[0]: x},
                           fetch_list=fetch_vars)
    return outputs[0]


L
lijianshe02 已提交
239 240 241 242 243
def _calculate_activation_statistics(files,
                                     model,
                                     premodel_path,
                                     batch_size=50,
                                     dims=2048,
                                     use_gpu=False,
                                     style=None):
    """Return ``(mean, covariance)`` of the activations of ``files``.

    Activations come from ``_get_activations``; each row is treated as one
    sample.
    """
    activations = _get_activations(files, model, batch_size, dims, use_gpu,
                                   premodel_path, style)
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)


Q
qingqing01 已提交
253 254 255 256 257 258 259
def _compute_statistics_of_path(path,
                                model,
                                batch_size,
                                dims,
                                use_gpu,
                                premodel_path,
                                style=None):
    """Return ``(mu, sigma)`` for a statistics file or an image directory.

    Args:
        path: either a ``.npz`` file holding ``mu`` and ``sigma`` arrays, or
            a directory that is walked recursively for ``*.jpg`` and
            ``*.png`` files.
        model, batch_size, dims, use_gpu, premodel_path, style: forwarded
            to ``_calculate_activation_statistics`` for the directory case.

    Returns:
        Tuple ``(mu, sigma)``.
    """
    if path.endswith('.npz'):
        # The context manager closes the archive even when a key is missing.
        with np.load(path) as f:
            m, s = f['mu'][:], f['sigma'][:]
        return m, s

    files = []
    for root, _dirnames, filenames in os.walk(path):
        # Collect both extensions. Chaining the two filters with ``or``
        # dropped every PNG in a directory that also contained a JPG.
        for pattern in ('*.jpg', '*.png'):
            files.extend(
                os.path.join(root, filename)
                for filename in fnmatch.filter(filenames, pattern))
    return _calculate_activation_statistics(files, model, premodel_path,
                                            batch_size, dims, use_gpu, style)


def calculate_fid_given_paths(paths,
                              premodel_path,
                              batch_size,
                              use_gpu,
                              dims,
                              model=None,
                              style=None):
    """Compute the FID between the image sets at ``paths[0]`` and ``paths[1]``.

    Args:
        paths: sequence whose first two entries are image directories or
            ``.npz`` statistics files; every entry must exist.
        premodel_path: existing path to the pretrained weights.
        batch_size, use_gpu, dims, style: forwarded to
            ``_compute_statistics_of_path``.
        model: network to use; when ``None`` and ``style`` is not
            ``'stargan'`` an ``InceptionV3`` is built for ``dims``.

    Returns:
        The Frechet distance between the two sets' statistics.

    Raises:
        AssertionError: if ``premodel_path`` does not exist.
        RuntimeError: if any entry of ``paths`` does not exist.
    """
    assert os.path.exists(premodel_path), (
        'pretrain_model path {} is not exists! Please download it first'.
        format(premodel_path))

    for candidate in paths:
        if os.path.exists(candidate):
            continue
        raise RuntimeError('Invalid path: %s' % candidate)

    needs_inception = model is None and style != 'stargan'
    if needs_inception:
        with paddle.guard():
            model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]],
                                class_dim=1008)

    shared_args = (model, batch_size, dims, use_gpu, premodel_path, style)
    mu_a, sigma_a = _compute_statistics_of_path(paths[0], *shared_args)
    mu_b, sigma_b = _compute_statistics_of_path(paths[1], *shared_args)

    return _calculate_frechet_distance(mu_a, sigma_a, mu_b, sigma_b)