From f93008355af6837518383bd1867efa95635ce3b0 Mon Sep 17 00:00:00 2001 From: lzzyzlbb <287246233@qq.com> Date: Tue, 20 Jul 2021 12:55:34 +0800 Subject: [PATCH] add stargan pretrain model (#366) --- configs/starganv2_celeba_hq.yaml | 2 +- ppgan/models/generators/generator_starganv2.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/configs/starganv2_celeba_hq.yaml b/configs/starganv2_celeba_hq.yaml index c5e00e1..8ea231c 100644 --- a/configs/starganv2_celeba_hq.yaml +++ b/configs/starganv2_celeba_hq.yaml @@ -24,7 +24,7 @@ model: num_domains: *NUM_DOMAINS fan: name: FAN - fname_pretrained: models/stargan-v2/wing.pdparams + fname_pretrained: None discriminator: name: StarGANv2Discriminator img_size: *IMAGE_SIZE diff --git a/ppgan/models/generators/generator_starganv2.py b/ppgan/models/generators/generator_starganv2.py index bed8c01..ad1aedb 100755 --- a/ppgan/models/generators/generator_starganv2.py +++ b/ppgan/models/generators/generator_starganv2.py @@ -9,6 +9,9 @@ import math from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess +from ppgan.utils.download import get_path_from_url + +FAN_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/wing.pdparams" class AvgPool2D(nn.Layer): """ @@ -298,6 +301,9 @@ class FAN(nn.Layer): if fname_pretrained is not None: self.load_pretrained_weights(fname_pretrained) + else: + weight_path = get_path_from_url(FAN_WEIGHT_URL) + self.load_pretrained_weights(weight_path) def load_pretrained_weights(self, fname): import pickle -- GitLab