#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import os
import fnmatch
import numpy as np
import cv2
import paddle
from PIL import Image
from cv2 import imread
from scipy import linalg
from .inception import InceptionV3
from paddle.utils.download import get_weights_path_from_url
from .builder import METRICS

try:
    from tqdm import tqdm
except ImportError:

    def tqdm(x):
        return x


""" based on https://github.com/mit-han-lab/gan-compression/blob/master/metric/fid_score.py
"""
"""
inceptionV3 pretrain model is convert from pytorch, pretrain_model url is https://paddle-gan-models.bj.bcebos.com/params_inceptionV3.tar.gz
"""
INCEPTIONV3_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/InceptionV3.pdparams"

@METRICS.register()
class FID(paddle.metric.Metric):
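    """Fréchet Inception Distance (FID) metric.

    Image batches passed to ``update`` are run through an InceptionV3 network;
    the per-batch FID between predictions and ground truths is stored and
    averaged by ``accumulate``.

    Minimal usage sketch (assumes ``fake_imgs`` and ``real_imgs`` are float32
    arrays in NCHW or NHWC layout with pixel values in [0, 255]):

        fid = FID(batch_size=8)
        fid.update(fake_imgs, real_imgs)
        print(fid.name(), fid.accumulate())
    """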
    def __init__(self,
                 batch_size=1,
                 use_GPU=True,
                 dims=2048,
                 premodel_path=None,
                 model=None):
        self.batch_size = batch_size
        self.use_GPU = use_GPU
        self.dims = dims
        self.premodel_path = premodel_path
        if model is None:
            block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
            model = InceptionV3([block_idx])
        if premodel_path is None:
            premodel_path = get_weights_path_from_url(INCEPTIONV3_WEIGHT_URL)
        self.model = model
        param_dict = paddle.load(premodel_path)
        model.load_dict(param_dict)
        model.eval()
        self.reset()
        
    def reset(self):
        self.results = []

    def update(self, preds, gts):
        value = calculate_fid_given_img(preds, gts, self.batch_size,
                                        self.model, self.use_GPU, self.dims)
        self.results.append(value)

    def accumulate(self):
        if len(self.results) <= 0:
            return 0.
        return np.mean(self.results)

    def name(self):
        return 'FID'




def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
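    """Numpy implementation of the Fréchet distance between two Gaussians:

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    where (mu1, sigma1) and (mu2, sigma2) are the means and covariances of the
    Inception activations computed for the two image sets.
    """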
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
            2 * tr_covmean)


def _get_activations_from_ims(img, model, batch_size, dims, use_gpu):
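    """Compute InceptionV3 pool activations for an in-memory image batch.

    Returns an array of shape (len(img), dims), one activation row per image.
    """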
    n_batches = (len(img) + batch_size - 1) // batch_size
    n_used_img = len(img)
    
    pred_arr = np.empty((n_used_img, dims))
    
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = start + batch_size
        if end > len(img):
            end = len(img)
        images = img[start:end]
        if images.shape[1] != 3:
            images = images.transpose((0, 3, 1, 2))
        images /= 255

        images = paddle.to_tensor(images)
        pred = model(images)[0][0]
        pred_arr[start:end] = pred.reshape([end - start, -1]).cpu().numpy()
    return pred_arr


def _compute_statistic_of_img(img, model, batch_size, dims, use_gpu):
    act = _get_activations_from_ims(img, model, batch_size, dims, use_gpu)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def calculate_fid_given_img(img_fake,
                            img_real,
                            batch_size,
                            model,
                            use_gpu=True,
                            dims=2048):
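    """Compute the FID between two in-memory image batches (fake vs. real)."""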

    m1, s1 = _compute_statistic_of_img(img_fake, model, batch_size, dims,
                                       use_gpu)
    m2, s2 = _compute_statistic_of_img(img_real, model, batch_size, dims,
                                       use_gpu)

    fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
    return fid_value


def _get_activations(files,
                     model,
                     batch_size,
                     dims,
                     use_gpu,
                     premodel_path,
                     style=None):
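    """Compute InceptionV3 activations for a list of image files.

    With ``style='stargan'`` images are resized to 256 and then to 299 (as in
    the official StarGAN v2 evaluation), normalized with ImageNet statistics,
    and run through the static inference model at ``premodel_path``; otherwise
    images are read with ``cv2.imread`` and fed to ``model``.
    """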
    if len(files) % batch_size != 0:
        print(('Warning: number of images is not a multiple of the '
               'batch size. Some samples are going to be ignored.'))
    if batch_size > len(files):
        print(('Warning: batch size is bigger than the datasets size. '
               'Setting batch size to datasets size'))
        batch_size = len(files)

    n_batches = len(files) // batch_size
    n_used_imgs = n_batches * batch_size

    pred_arr = np.empty((n_used_imgs, dims))
    for i in tqdm(range(n_batches)):
        start = i * batch_size
        end = start + batch_size

        # same as stargan-v2 official implementation: resize to 256 first, then resize to 299
        if style == 'stargan':
            img_list = []
            for f in files[start:end]:
                im = Image.open(str(f)).convert('RGB')
                if im.size[0] != 299:
                    im = im.resize((256, 256), 2)
                    im = im.resize((299, 299), 2)

                img_list.append(np.array(im).astype('float32'))

            images = np.array(img_list)
        else:
            images = np.array(
                [imread(str(f)).astype(np.float32) for f in files[start:end]])

        # fallback for inputs that do not stack into an NHWC batch (e.g.
        # grayscale images): re-read the first file and replicate the gray
        # channel so the transpose below still sees 3 channels
        if len(images.shape) != 4:
            images = imread(str(files[start]))
            images = cv2.cvtColor(images, cv2.COLOR_BGR2GRAY)
            images = cv2.cvtColor(images, cv2.COLOR_GRAY2RGB)
            images = np.array([images.astype(np.float32)])

        images = images.transpose((0, 3, 1, 2))
        images /= 255

        # imagenet normalization
        if style == 'stargan':
            mean = np.array([0.485, 0.456, 0.406]).astype('float32')
            std = np.array([0.229, 0.224, 0.225]).astype('float32')
            images[:] = (images[:] - mean[:, None, None]) / std[:, None, None]

        if style == 'stargan':
            pred_arr[start:end] = inception_infer(images, premodel_path)
        else:
            images = paddle.to_tensor(images)
            param_dict = paddle.load(premodel_path)
            model.set_dict(param_dict)
            model.eval()

            pred = model(images)[0][0].numpy()

            pred_arr[start:end] = pred.reshape(end - start, -1)

    return pred_arr


def inception_infer(x, model_path):
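    """Run a saved static-graph inference model at ``model_path`` on input ``x``."""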
    exe = paddle.static.Executor()
    [inference_program, feed_target_names,
     fetch_targets] = paddle.static.load_inference_model(model_path, exe)
    results = exe.run(inference_program,
                      feed={feed_target_names[0]: x},
                      fetch_list=fetch_targets)
    return results[0]


def _calculate_activation_statistics(files,
                                     model,
                                     premodel_path,
                                     batch_size=50,
                                     dims=2048,
                                     use_gpu=False,
                                     style=None):
    act = _get_activations(files, model, batch_size, dims, use_gpu,
                           premodel_path, style)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


def _compute_statistics_of_path(path,
                                model,
                                batch_size,
                                dims,
                                use_gpu,
                                premodel_path,
                                style=None):
    if path.endswith('.npz'):
        f = np.load(path)
        m, s = f['mu'][:], f['sigma'][:]
        f.close()
    else:
        files = []
        for root, dirnames, filenames in os.walk(path):
            for filename in (fnmatch.filter(filenames, '*.jpg') +
                             fnmatch.filter(filenames, '*.png')):
                files.append(os.path.join(root, filename))
        m, s = _calculate_activation_statistics(files, model, premodel_path,
                                                batch_size, dims, use_gpu,
                                                style)
    return m, s


def calculate_fid_given_paths(paths,
                              premodel_path,
                              batch_size,
                              use_gpu,
                              dims,
                              model=None,
                              style=None):
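    """Compute the FID between the two image directories (or precomputed .npz
    statistics files) given in ``paths`` using the InceptionV3 weights at
    ``premodel_path``. ``style='stargan'`` switches to the StarGAN v2
    preprocessing and static inference model.
    """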
    assert os.path.exists(
        premodel_path
    ), 'pretrained model path {} does not exist! Please download it first'.format(
        premodel_path)
    for p in paths:
        if not os.path.exists(p):
            raise RuntimeError('Invalid path: %s' % p)

    if model is None and style != 'stargan':
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        model = InceptionV3([block_idx], class_dim=1008)

    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims,
                                         use_gpu, premodel_path, style)
    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims,
                                         use_gpu, premodel_path, style)

    fid_value = _calculate_frechet_distance(m1, s1, m2, s2)
    return fid_value