未验证 提交 f9300835 编写于 作者: L lzzyzlbb 提交者: GitHub

add stargan pretrain model (#366)

上级 36616a74
......@@ -24,7 +24,7 @@ model:
num_domains: *NUM_DOMAINS
fan:
name: FAN
fname_pretrained: models/stargan-v2/wing.pdparams
fname_pretrained: None
discriminator:
name: StarGANv2Discriminator
img_size: *IMAGE_SIZE
......
......@@ -9,6 +9,9 @@ import math
from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess
from ppgan.utils.download import get_path_from_url
FAN_WEIGHT_URL = "https://paddlegan.bj.bcebos.com/models/wing.pdparams"
class AvgPool2D(nn.Layer):
"""
......@@ -298,6 +301,9 @@ class FAN(nn.Layer):
if fname_pretrained is not None:
self.load_pretrained_weights(fname_pretrained)
else:
weight_path = get_path_from_url(FAN_WEIGHT_URL)
self.load_pretrained_weights(weight_path)
def load_pretrained_weights(self, fname):
import pickle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册