From ca9ff7e60551ee4691482298ff5929acf3fc7b6c Mon Sep 17 00:00:00 2001 From: lijianshe02 Date: Thu, 20 Aug 2020 06:13:47 +0000 Subject: [PATCH] add dygraph fid score computation --- ppgan/metric/compute_fid.py | 11 +------ ppgan/metric/test_fid_score.py | 60 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 ppgan/metric/test_fid_score.py diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index 3642a27..c8fc805 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -198,10 +198,10 @@ def _compute_statistics_of_path(path, model, batch_size, dims, use_gpu, def calculate_fid_given_paths(paths, + premodel_path, batch_size, use_gpu, dims, - premodel_path, model=None): assert os.path.exists( premodel_path @@ -222,12 +222,3 @@ def calculate_fid_given_paths(paths, fid_value = _calculate_frechet_distance(m1, s1, m2, s2) return fid_value - - -if __name__ == '__main__': - with fluid.dygraph.guard(): - fid_value = calculate_fid_given_paths( - ('/workspace/color/fid_test/real', - '/workspace/color/fid_test/fake'), 1, True, 2048, - 'pretrained/params_inceptionV3/compare.pdparams') - print('FID: ', fid_value) diff --git a/ppgan/metric/test_fid_score.py b/ppgan/metric/test_fid_score.py new file mode 100644 index 0000000..6e484f5 --- /dev/null +++ b/ppgan/metric/test_fid_score.py @@ -0,0 +1,60 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import argparse +from compute_fid import * + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--image_data_path1', + type=str, + default='./real', + help='path of image data') + parser.add_argument('--image_data_path2', + type=str, + default='./fake', + help='path of image data') + parser.add_argument('--inference_model', + type=str, + default='./pretrained/params_inceptionV3', + help='path of inference_model.') + parser.add_argument('--use_gpu', + type=bool, + default=True, + help='default use gpu.') + parser.add_argument('--batch_size', + type=int, + default=1, + help='sample number in a batch for inference.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + path1 = args.image_data_path1 + path2 = args.image_data_path2 + paths = (path1, path2) + inference_model_path = args.inference_model + batch_size = args.batch_size + + with fluid.dygraph.guard(): + fid_value = calculate_fid_given_paths(paths, inference_model_path, 1, + True, 2048) + print('FID: ', fid_value) + + +if __name__ == "__main__": + main() -- GitLab