diff --git a/ppgan/metric/compute_fid.py b/ppgan/metric/compute_fid.py index c7f8c0a0d4c5b2bb9dc07d769f706e073b6c3452..3e1d013ed2dfcc75ec034e30c886a6f3c2619efd 100644 --- a/ppgan/metric/compute_fid.py +++ b/ppgan/metric/compute_fid.py @@ -250,9 +250,10 @@ def calculate_fid_given_paths(paths, if not os.path.exists(p): raise RuntimeError('Invalid path: %s' % p) - if model is None: - block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] - model = InceptionV3([block_idx], class_dim=1008) + if model is None and style != 'stargan': + with fluid.dygraph.guard(): + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + model = InceptionV3([block_idx], class_dim=1008) m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, use_gpu, premodel_path, style)