未验证 提交 f9300835 编写于 作者: L lzzyzlbb 提交者: GitHub

add stargan pretrain model (#366)

上级 36616a74
...@@ -24,7 +24,7 @@ model: ...@@ -24,7 +24,7 @@ model:
num_domains: *NUM_DOMAINS num_domains: *NUM_DOMAINS
fan: fan:
name: FAN name: FAN
fname_pretrained: models/stargan-v2/wing.pdparams fname_pretrained: None
discriminator: discriminator:
name: StarGANv2Discriminator name: StarGANv2Discriminator
img_size: *IMAGE_SIZE img_size: *IMAGE_SIZE
......
...@@ -9,6 +9,9 @@ import math ...@@ -9,6 +9,9 @@ import math
from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess
from ppgan.utils.download import get_path_from_url
FAN_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/wing.pdparams"
class AvgPool2D(nn.Layer): class AvgPool2D(nn.Layer):
""" """
...@@ -298,6 +301,9 @@ class FAN(nn.Layer): ...@@ -298,6 +301,9 @@ class FAN(nn.Layer):
if fname_pretrained is not None: if fname_pretrained is not None:
self.load_pretrained_weights(fname_pretrained) self.load_pretrained_weights(fname_pretrained)
else:
weight_path = get_path_from_url(FAN_WEIGHT_URL)
self.load_pretrained_weights(weight_path)
def load_pretrained_weights(self, fname): def load_pretrained_weights(self, fname):
import pickle import pickle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册