diff --git a/ppgan/metric/README.md b/ppgan/metric/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d27e99d639bfed7b63ff90fdd1d54e12b45f78e0 --- /dev/null +++ b/ppgan/metric/README.md @@ -0,0 +1,10 @@ +English (./README.md) + +# Usage + +To compute the FID score between two datasets, where images of each dataset are contained in an individual folder: + +``` +wget https://paddlegan.bj.bcebos.com/InceptionV3.pdparams +python test_fid_score.py --image_data_path1 /path/to/dataset1 --image_data_path2 /path/to/dataset2 --inference_model ./InceptionV3.pdparams +``` diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fc8059e2658768b1f07436f2bca6e08446014c --- /dev/null +++ b/ppgan/metric/compute_fid.py @@ -0,0 +1,224 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
def _calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the Frechet distance between two Gaussian distributions.

    The value computed is
    ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 . sigma2))``.

    Args:
        mu1: Mean of the first distribution (scalar, sequence or ndarray).
        sigma1: Covariance of the first distribution (scalar, nested
            sequence or ndarray).
        mu2: Mean of the second distribution, same shape as ``mu1``.
        sigma2: Covariance of the second distribution, same shape as
            ``sigma1``.
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product is not finite.

    Returns:
        The Frechet distance as a scalar.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the matrix square root has a diagonal imaginary part
            larger than ``1e-3``.
    """
    # Normalise the inputs *before* touching ``.shape`` so that plain lists
    # and scalars are accepted, not only ndarrays.
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # The single-return form of sqrtm is used here (as the fallback below
    # already did); the ``disp`` keyword is deprecated in recent SciPy
    # releases, and both forms return the same matrix.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) -
            2 * tr_covmean)
def calculate_fid_given_img(img_fake,
                            img_real,
                            batch_size,
                            use_gpu,
                            dims,
                            premodel_path,
                            model=None):
    """Compute the FID between two in-memory image collections.

    Args:
        img_fake: Generated images, forwarded to the statistics helper.
        img_real: Reference images, forwarded to the statistics helper.
        batch_size: Number of images fed to the network per step.
        use_gpu: Forwarded unchanged to the statistics helper.
        dims: Feature dimensionality; also selects the Inception block via
            ``InceptionV3.BLOCK_INDEX_BY_DIM`` when no model is given.
        premodel_path: Path of the pretrained Inception weights; must exist.
        model: Optional ready-made network; built on demand when ``None``.

    Returns:
        The Frechet distance between the two sets of activation statistics.
    """
    assert os.path.exists(premodel_path), (
        'pretrain_model path {} is not exists! Please download it first'.
        format(premodel_path))

    if model is None:
        model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]])

    # Fake images are processed first, then the real ones.
    (mu_fake, cov_fake), (mu_real, cov_real) = [
        _compute_statistic_of_img(images, model, batch_size, dims, use_gpu,
                                  premodel_path)
        for images in (img_fake, img_real)
    ]

    return _calculate_frechet_distance(mu_fake, cov_fake, mu_real, cov_real)
def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu,
                                premodel_path):
    """Return the activation mean and covariance for one dataset location.

    Args:
        path: Either an ``.npz`` archive holding pre-computed ``mu`` and
            ``sigma`` arrays, or a directory that is searched recursively
            for ``*.jpg`` and ``*.png`` files.
        model: Network forwarded to ``_calculate_activation_statistics``.
        batch_size: Batch size forwarded to the statistics helper.
        dims: Feature dimensionality forwarded to the statistics helper.
        use_gpu: Forwarded unchanged to the statistics helper.
        premodel_path: Pretrained weight path forwarded to the helper.

    Returns:
        A ``(mu, sigma)`` tuple.

    Raises:
        RuntimeError: If ``path`` is a directory without any matching image.
    """
    if path.endswith('.npz'):
        # The context manager closes the archive even when a key is missing.
        with np.load(path) as stats:
            m, s = stats['mu'][:], stats['sigma'][:]
        return m, s

    files = []
    for root, _dirnames, filenames in os.walk(path):
        # Collect every supported extension; chaining the two filters with
        # ``or`` would drop all PNGs of a directory that also holds a JPG.
        for pattern in ('*.jpg', '*.png'):
            files.extend(
                os.path.join(root, filename)
                for filename in fnmatch.filter(filenames, pattern))

    if not files:
        # Fail early with a clear message instead of a division by zero in
        # the batching code further down.
        raise RuntimeError('No *.jpg or *.png images found under: %s' % path)

    # os.walk order is filesystem dependent; sort for reproducible batches.
    files.sort()

    m, s = _calculate_activation_statistics(files, model, premodel_path,
                                            batch_size, dims, use_gpu)
    return m, s
Please download it first'.format( + premodel_path) + for p in paths: + if not os.path.exists(p): + raise RuntimeError('Invalid path: %s' % p) + + if model is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + model = InceptionV3([block_idx], class_dim=1008) + + m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, + use_gpu, premodel_path) + m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, + use_gpu, premodel_path) + + fid_value = _calculate_frechet_distance(m1, s1, m2, s2) + return fid_value diff --git a/ppgan/metric/inception.py b/ppgan/metric/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..35a2866005f1df574a35a6b11a02c737445da81c --- /dev/null +++ b/ppgan/metric/inception.py @@ -0,0 +1,757 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
class InceptionV3(fluid.dygraph.Layer):
    """Inception-v3 network that returns selected intermediate feature maps.

    The network is split into four sequential blocks (indices 0-3). Only the
    blocks up to the highest requested index are built and executed.

    Args:
        output_blocks: Iterable of block indices whose outputs are returned
            by ``forward``. Must not contain an index above 3.
        class_dim: Number of classes of the auxiliary classifier; only used
            when ``aux_logits`` is true.
        aux_logits: Whether to build and run the auxiliary classifier after
            block 2.
        resize_input: Whether to bilinearly resize inputs to 299x299.
        normalize_input: Whether to apply ``x * 2 - 1`` to the input
            (presumably mapping [0, 1] images to [-1, 1] -- confirm with the
            callers' preprocessing).
    """
    DEFAULT_BLOCK_INDEX = 3
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX, ),
                 class_dim=1000,
                 aux_logits=False,
                 resize_input=True,
                 normalize_input=True):
        # The default is an immutable tuple so that it can never be shared
        # and mutated between instances.
        super(InceptionV3, self).__init__()
        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)
        self.class_dim = class_dim
        self.aux_logits = aux_logits

        assert self.last_needed_block <= 3, 'Last possible output block index is 3'
        # Plain list of Sequential wrappers; every wrapped layer is also
        # assigned as an attribute of this layer below.
        self.blocks = []

        ### block0
        self.Conv2d_1a_3x3 = ConvBNLayer(3,
                                         32,
                                         3,
                                         stride=2,
                                         name='Conv2d_1a_3x3')
        self.Conv2d_2a_3x3 = ConvBNLayer(32, 32, 3, name='Conv2d_2a_3x3')
        self.Conv2d_2b_3x3 = ConvBNLayer(32,
                                         64,
                                         3,
                                         padding=1,
                                         name='Conv2d_2b_3x3')
        self.maxpool1 = Pool2D(pool_size=3, pool_stride=2, pool_type='max')

        block0 = [
            self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3,
            self.maxpool1
        ]
        self.blocks.append(fluid.dygraph.Sequential(*block0))

        ### block1
        if self.last_needed_block >= 1:
            self.Conv2d_3b_1x1 = ConvBNLayer(64, 80, 1, name='Conv2d_3b_1x1')
            self.Conv2d_4a_3x3 = ConvBNLayer(80, 192, 3, name='Conv2d_4a_3x3')
            self.maxpool2 = Pool2D(pool_size=3, pool_stride=2, pool_type='max')
            block1 = [self.Conv2d_3b_1x1, self.Conv2d_4a_3x3, self.maxpool2]
            self.blocks.append(fluid.dygraph.Sequential(*block1))

        ### block2
        ### Mixed_5b 5c 5d
        if self.last_needed_block >= 2:
            self.Mixed_5b = Fid_inceptionA(192,
                                           pool_features=32,
                                           name='Mixed_5b')
            self.Mixed_5c = Fid_inceptionA(256,
                                           pool_features=64,
                                           name='Mixed_5c')
            self.Mixed_5d = Fid_inceptionA(288,
                                           pool_features=64,
                                           name='Mixed_5d')

            ### Mixed_6
            self.Mixed_6a = InceptionB(288, name='Mixed_6a')
            self.Mixed_6b = Fid_inceptionC(768, c7=128, name='Mixed_6b')
            self.Mixed_6c = Fid_inceptionC(768, c7=160, name='Mixed_6c')
            self.Mixed_6d = Fid_inceptionC(768, c7=160, name='Mixed_6d')
            self.Mixed_6e = Fid_inceptionC(768, c7=192, name='Mixed_6e')

            block2 = [
                self.Mixed_5b, self.Mixed_5c, self.Mixed_5d, self.Mixed_6a,
                self.Mixed_6b, self.Mixed_6c, self.Mixed_6d, self.Mixed_6e
            ]
            self.blocks.append(fluid.dygraph.Sequential(*block2))

        if self.aux_logits:
            self.AuxLogits = InceptionAux(768, self.class_dim, name='AuxLogits')

        ### block3
        ### Mixed_7
        if self.last_needed_block >= 3:
            self.Mixed_7a = InceptionD(768, name='Mixed_7a')
            self.Mixed_7b = Fid_inceptionE_1(1280, name='Mixed_7b')
            self.Mixed_7c = Fid_inceptionE_2(2048, name='Mixed_7c')
            self.avgpool = Pool2D(global_pooling=True, pool_type='avg')

            block3 = [self.Mixed_7a, self.Mixed_7b, self.Mixed_7c, self.avgpool]
            self.blocks.append(fluid.dygraph.Sequential(*block3))

    def forward(self, x):
        """Run the network and collect the requested block outputs.

        Args:
            x: Input image batch.

        Returns:
            A tuple ``(out, aux)``: ``out`` is the list of feature maps of
            the blocks listed in ``output_blocks`` (ascending block order);
            ``aux`` is the auxiliary classifier output, or ``None`` when the
            classifier is disabled or block 2 is never reached.
        """
        out = []
        aux = None
        if self.resize_input:
            x = fluid.layers.resize_bilinear(x,
                                             out_shape=[299, 299],
                                             align_corners=False,
                                             align_mode=0)

        if self.normalize_input:
            x = x * 2 - 1

        for idx, block in enumerate(self.blocks):
            x = block(x)
            # The auxiliary head consumes the output of block 2.
            if self.aux_logits and (idx == 2):
                aux = self.AuxLogits(x)
            if idx in self.output_blocks:
                out.append(x)
            # Stop as soon as the deepest requested block has been run.
            if idx == self.last_needed_block:
                break

        return out, aux
class InceptionB(fluid.dygraph.Layer):
    """Inception block with three parallel stride-2 branches.

    The branches are a single 3x3 convolution, a stack of three convolutions
    and a max pooling; their outputs are concatenated along axis 1.

    Args:
        in_channels: Number of channels of the incoming feature map.
        name: Prefix used for the parameter names of the sub-layers.
    """
    def __init__(self, in_channels, name=None):
        super(InceptionB, self).__init__()

        # Branch 1: one strided 3x3 convolution.
        self.branch3x3 = ConvBNLayer(in_channels=in_channels,
                                     num_filters=384,
                                     filter_size=3,
                                     stride=2,
                                     name=name + '.branch3x3')

        # Branch 2: 1x1 -> padded 3x3 -> strided 3x3.
        self.branch3x3dbl_1 = ConvBNLayer(in_channels=in_channels,
                                          num_filters=64,
                                          filter_size=1,
                                          name=name + '.branch3x3dbl_1')
        self.branch3x3dbl_2 = ConvBNLayer(in_channels=64,
                                          num_filters=96,
                                          filter_size=3,
                                          padding=1,
                                          name=name + '.branch3x3dbl_2')
        self.branch3x3dbl_3 = ConvBNLayer(in_channels=96,
                                          num_filters=96,
                                          filter_size=3,
                                          stride=2,
                                          name=name + '.branch3x3dbl_3')

        # Branch 3: strided max pooling, no parameters.
        self.branch_pool = Pool2D(pool_size=3, pool_stride=2, pool_type='max')

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on axis 1."""
        single_conv = self.branch3x3(x)

        stacked = x
        for layer in (self.branch3x3dbl_1, self.branch3x3dbl_2,
                      self.branch3x3dbl_3):
            stacked = layer(stacked)

        pooled = self.branch_pool(x)

        return fluid.layers.concat([single_conv, stacked, pooled], axis=1)
class InceptionD(fluid.dygraph.Layer):
    """Inception block with a 3x3 branch, a 7x7x3 branch and a max pooling.

    All three branches end with a stride-2 operation and are concatenated
    along axis 1.

    Args:
        in_channels: Number of channels of the incoming feature map.
        name: Prefix used for the parameter names of the sub-layers.
    """
    def __init__(self, in_channels, name=None):
        super(InceptionD, self).__init__()

        # Branch 1: 1x1 -> strided 3x3.
        self.branch3x3_1 = ConvBNLayer(in_channels=in_channels,
                                       num_filters=192,
                                       filter_size=1,
                                       name=name + '.branch3x3_1')
        self.branch3x3_2 = ConvBNLayer(in_channels=192,
                                       num_filters=320,
                                       filter_size=3,
                                       stride=2,
                                       name=name + '.branch3x3_2')

        # Branch 2: 1x1 -> 1x7 -> 7x1 -> strided 3x3.
        self.branch7x7x3_1 = ConvBNLayer(in_channels=in_channels,
                                         num_filters=192,
                                         filter_size=1,
                                         name=name + '.branch7x7x3_1')
        self.branch7x7x3_2 = ConvBNLayer(in_channels=192,
                                         num_filters=192,
                                         filter_size=(1, 7),
                                         padding=(0, 3),
                                         name=name + '.branch7x7x3_2')
        self.branch7x7x3_3 = ConvBNLayer(in_channels=192,
                                         num_filters=192,
                                         filter_size=(7, 1),
                                         padding=(3, 0),
                                         name=name + '.branch7x7x3_3')
        self.branch7x7x3_4 = ConvBNLayer(in_channels=192,
                                         num_filters=192,
                                         filter_size=3,
                                         stride=2,
                                         name=name + '.branch7x7x3_4')

        # Branch 3: strided max pooling, no parameters.
        self.branch_pool = Pool2D(pool_size=3, pool_stride=2, pool_type='max')

    def forward(self, x):
        """Apply the three branches to ``x`` and concatenate on axis 1."""
        short_path = self.branch3x3_2(self.branch3x3_1(x))

        long_path = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            long_path = layer(long_path)

        pooled = self.branch_pool(x)

        return fluid.layers.concat([short_path, long_path, pooled], axis=1)
class InceptionAux(fluid.dygraph.Layer):
    """Auxiliary classifier head of Inception-v3.

    Pipeline: 5x5 average pooling (stride 3) -> 1x1 conv -> 5x5 conv ->
    global average pooling -> flatten -> fully connected layer.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Output size of the final fully connected layer.
        name: Prefix used for the parameter names of the sub-layers.
    """
    def __init__(self, in_channels, num_classes, name=None):
        super(InceptionAux, self).__init__()
        self.num_classes = num_classes
        self.pool0 = Pool2D(pool_size=5, pool_stride=3, pool_type='avg')
        self.conv0 = ConvBNLayer(in_channels, 128, 1, name=name + '.conv0')
        self.conv1 = ConvBNLayer(128, 768, 5, name=name + '.conv1')
        self.pool1 = Pool2D(global_pooling=True, pool_type='avg')
        # The classifier is a real sub-layer so that its weights are created
        # once and registered on this layer. The static-graph helper
        # ``fluid.layers.fc`` called inside ``forward`` does neither.
        # The input width is 768: conv1 emits 768 channels and pool1 reduces
        # the spatial size to 1x1 before flattening.
        # NOTE(review): the '.fc.weight' / '.fc.bias' names follow the naming
        # pattern of ConvBNLayer -- confirm against the pretrained checkpoint
        # if auxiliary logits are ever loaded from it.
        self.fc = Linear(768,
                         num_classes,
                         param_attr=ParamAttr(name=name + '.fc.weight'),
                         bias_attr=ParamAttr(name=name + '.fc.bias'))

    def forward(self, x):
        """Return the auxiliary class scores for feature map ``x``."""
        x = self.pool0(x)
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.pool1(x)
        x = fluid.layers.flatten(x, axis=1)
        x = self.fc(x)
        return x
1, + name=name + '.branch3x3dbl_1') + self.branch3x3dbl_2 = ConvBNLayer(64, + 96, + 3, + padding=1, + name=name + '.branch3x3dbl_2') + self.branch3x3dbl_3 = ConvBNLayer(96, + 96, + 3, + padding=1, + name=name + '.branch3x3dbl_3') + + self.branch_pool0 = Pool2D(pool_size=3, + pool_stride=1, + pool_padding=1, + exclusive=True, + pool_type='avg') + self.branch_pool = ConvBNLayer(in_channels, + pool_features, + 1, + name=name + '.branch_pool') + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = self.branch_pool0(x) + branch_pool = self.branch_pool(branch_pool) + return fluid.layers.concat( + [branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=1) + + +class Fid_inceptionC(fluid.dygraph.Layer): + """ FID block in inception v3 + """ + def __init__(self, in_channels, c7, name=None): + super(Fid_inceptionC, self).__init__() + self.branch1x1 = ConvBNLayer(in_channels, + 192, + 1, + name=name + '.branch1x1') + + self.branch7x7_1 = ConvBNLayer(in_channels, + c7, + 1, + name=name + '.branch7x7_1') + self.branch7x7_2 = ConvBNLayer(c7, + c7, (1, 7), + padding=(0, 3), + name=name + '.branch7x7_2') + self.branch7x7_3 = ConvBNLayer(c7, + 192, (7, 1), + padding=(3, 0), + name=name + '.branch7x7_3') + + self.branch7x7dbl_1 = ConvBNLayer(in_channels, + c7, + 1, + name=name + '.branch7x7dbl_1') + self.branch7x7dbl_2 = ConvBNLayer(c7, + c7, (7, 1), + padding=(3, 0), + name=name + '.branch7x7dbl_2') + self.branch7x7dbl_3 = ConvBNLayer(c7, + c7, (1, 7), + padding=(0, 3), + name=name + '.branch7x7dbl_3') + self.branch7x7dbl_4 = ConvBNLayer(c7, + c7, (7, 1), + padding=(3, 0), + name=name + '.branch7x7dbl_4') + self.branch7x7dbl_5 = ConvBNLayer(c7, + 192, (1, 7), + padding=(0, 3), + name=name + '.branch7x7dbl_5') + + self.branch_pool0 = 
class Fid_inceptionE_1(fluid.dygraph.Layer):
    """ FID block in inception v3

    Four parallel branches (1x1, split 3x3, split double 3x3, and average
    pooling followed by 1x1) whose outputs are concatenated along axis 1.
    The pooling branch uses an average pooling with ``exclusive=True``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        name: Prefix used for the parameter names of the sub-layers.
    """
    def __init__(self, in_channels, name=None):
        super(Fid_inceptionE_1, self).__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1x1 = ConvBNLayer(in_channels=in_channels,
                                     num_filters=320,
                                     filter_size=1,
                                     name=name + '.branch1x1')

        # Branch 2: 1x1, then a 1x3 and a 3x1 convolution side by side.
        self.branch3x3_1 = ConvBNLayer(in_channels=in_channels,
                                       num_filters=384,
                                       filter_size=1,
                                       name=name + '.branch3x3_1')
        self.branch3x3_2a = ConvBNLayer(in_channels=384,
                                        num_filters=384,
                                        filter_size=(1, 3),
                                        padding=(0, 1),
                                        name=name + '.branch3x3_2a')
        self.branch3x3_2b = ConvBNLayer(in_channels=384,
                                        num_filters=384,
                                        filter_size=(3, 1),
                                        padding=(1, 0),
                                        name=name + '.branch3x3_2b')

        # Branch 3: 1x1 -> 3x3, then a 1x3 and a 3x1 convolution side by side.
        self.branch3x3dbl_1 = ConvBNLayer(in_channels=in_channels,
                                          num_filters=448,
                                          filter_size=1,
                                          name=name + '.branch3x3dbl_1')
        self.branch3x3dbl_2 = ConvBNLayer(in_channels=448,
                                          num_filters=384,
                                          filter_size=3,
                                          padding=1,
                                          name=name + '.branch3x3dbl_2')
        self.branch3x3dbl_3a = ConvBNLayer(in_channels=384,
                                           num_filters=384,
                                           filter_size=(1, 3),
                                           padding=(0, 1),
                                           name=name + '.branch3x3dbl_3a')
        self.branch3x3dbl_3b = ConvBNLayer(in_channels=384,
                                           num_filters=384,
                                           filter_size=(3, 1),
                                           padding=(1, 0),
                                           name=name + '.branch3x3dbl_3b')

        # Branch 4: size-preserving average pooling followed by 1x1.
        self.branch_pool0 = Pool2D(pool_size=3,
                                   pool_stride=1,
                                   pool_padding=1,
                                   exclusive=True,
                                   pool_type='avg')
        self.branch_pool = ConvBNLayer(in_channels=in_channels,
                                       num_filters=192,
                                       filter_size=1,
                                       name=name + '.branch_pool')

    def forward(self, x):
        """Apply the four branches to ``x`` and concatenate on axis 1."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = fluid.layers.concat(
            [self.branch3x3_2a(stem_3x3),
             self.branch3x3_2b(stem_3x3)], axis=1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_dbl = fluid.layers.concat(
            [self.branch3x3dbl_3a(stem_dbl),
             self.branch3x3dbl_3b(stem_dbl)],
            axis=1)

        out_pool = self.branch_pool(self.branch_pool0(x))

        return fluid.layers.concat([out_1x1, out_3x3, out_dbl, out_pool],
                                   axis=1)
class ConvBNLayer(fluid.dygraph.Layer):
    """Bias-free 2-D convolution followed by batch normalisation.

    The activation is applied by the batch-norm layer, not by the
    convolution. All parameters get explicit names derived from ``name``.

    Args:
        in_channels: Number of input channels of the convolution.
        num_filters: Number of output channels.
        filter_size: Kernel size (int or pair).
        stride: Convolution stride.
        padding: Convolution padding (int or pair).
        groups: Number of convolution groups.
        act: Activation applied after batch normalisation.
        name: Prefix of every parameter name created by this layer.
    """
    def __init__(self,
                 in_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 padding=0,
                 groups=1,
                 act='relu',
                 name=None):
        super(ConvBNLayer, self).__init__()

        conv_weight = ParamAttr(name=name + ".conv.weight")
        self.conv = Conv2D(num_channels=in_channels,
                           num_filters=num_filters,
                           filter_size=filter_size,
                           stride=stride,
                           padding=padding,
                           groups=groups,
                           act=None,
                           param_attr=conv_weight,
                           bias_attr=False)

        bn_options = dict(
            act=act,
            epsilon=0.001,
            param_attr=ParamAttr(name=name + ".bn.weight"),
            bias_attr=ParamAttr(name=name + ".bn.bias"),
            moving_mean_name=name + '.bn.running_mean',
            moving_variance_name=name + '.bn.running_var',
        )
        self.bn = BatchNorm(num_filters, **bn_options)

    def forward(self, inputs):
        """Return ``bn(conv(inputs))``."""
        return self.bn(self.conv(inputs))
def _str2bool(value):
    """Convert a command-line token such as ``'False'`` to a real boolean."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as False, as it was with ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'boolean value expected, got %r' % value)


def parse_args():
    """Parse the command-line options of the FID test script.

    Returns:
        An ``argparse.Namespace`` with ``image_data_path1``,
        ``image_data_path2``, ``inference_model``, ``use_gpu`` and
        ``batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_data_path1',
                        type=str,
                        default='./real',
                        help='path of image data')
    parser.add_argument('--image_data_path2',
                        type=str,
                        default='./fake',
                        help='path of image data')
    parser.add_argument('--inference_model',
                        type=str,
                        default='./pretrained/params_inceptionV3',
                        help='path of inference_model.')
    # ``type=bool`` would turn every non-empty string (including 'False')
    # into True, so an explicit converter is used instead.
    parser.add_argument('--use_gpu',
                        type=_str2bool,
                        default=True,
                        help='default use gpu.')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1,
                        help='sample number in a batch for inference.')
    args = parser.parse_args()
    return args