diff --git a/configs/starganv2_celeba_hq.yaml b/configs/starganv2_celeba_hq.yaml index c5e00e130695772a610affecfe8bd4dc4c2571dd..8ea231c3c7279b0569e6a58d1fdf9cb096d029f4 100644 --- a/configs/starganv2_celeba_hq.yaml +++ b/configs/starganv2_celeba_hq.yaml @@ -24,7 +24,7 @@ model: num_domains: *NUM_DOMAINS fan: name: FAN - fname_pretrained: models/stargan-v2/wing.pdparams + fname_pretrained: None discriminator: name: StarGANv2Discriminator img_size: *IMAGE_SIZE diff --git a/ppgan/models/generators/generator_starganv2.py b/ppgan/models/generators/generator_starganv2.py index bed8c01ac25019b7d4625d22a4792976e3211f77..ad1aedbb42f12b3d0aa11d509278bac3f37033aa 100755 --- a/ppgan/models/generators/generator_starganv2.py +++ b/ppgan/models/generators/generator_starganv2.py @@ -9,6 +9,9 @@ import math from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess +from ppgan.utils.download import get_path_from_url + +FAN_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/wing.pdparams" class AvgPool2D(nn.Layer): """ @@ -298,6 +301,9 @@ class FAN(nn.Layer): if fname_pretrained is not None: self.load_pretrained_weights(fname_pretrained) + else: + weight_path = get_path_from_url(FAN_WEIGHT_URL) + self.load_pretrained_weights(weight_path) def load_pretrained_weights(self, fname): import pickle