未验证 提交 1d8cd182 编写于 作者: F FNRE 提交者: GitHub

fix fom error (#319)

* 1.fix error with 4 channels of image of fom predictor. 2.fix error of fom evaluate 3.fix lapstyle vgg network
上级 6094e441
......@@ -115,7 +115,7 @@ log_config:
visiual_interval: 10
validate:
interval: 10
interval: 3000
save_img: false
snapshot_config:
......
......@@ -103,6 +103,16 @@ class FirstOrderPredictor(BasePredictor):
self.cfg, self.weight_path)
self.multi_person = multi_person
def read_img(self, path):
img = imageio.imread(path)
img = img.astype(np.float32)
if img.ndim == 2:
img = np.expand_dims(img, axis=2)
# som images have 4 channels
if img.shape[2] > 3:
img = img[:,:,:3]
return img
def run(self, source_image, driving_video):
def get_prediction(face_image):
if self.find_best_frame or self.best_frame is not None:
......@@ -138,7 +148,7 @@ class FirstOrderPredictor(BasePredictor):
adapt_movement_scale=self.adapt_scale)
return predictions
source_image = imageio.imread(source_image)
source_image = self.read_img(source_image)
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
......
......@@ -251,7 +251,8 @@ class FramesDataset(Dataset):
out['driving'] = out['source']
out['source'] = buf
else:
video = np.stack(video_array, axis=0) / 255.0
video = np.stack(video_array, axis=0).astype(
np.float32) / 255.0
out['video'] = video.transpose(3, 0, 1, 2)
out['name'] = video_name
return out
......
......@@ -87,17 +87,19 @@ class FirstOrderModel(BaseModel):
"dis_lr": self.dis_lr
}
def setup_optimizers(self, lr_cfg, optimizer):
def setup_net_parallel(self):
if isinstance(self.nets['Gen_Full'], paddle.DataParallel):
self.nets['kp_detector'] = self.nets[
'Gen_Full']._layers.kp_extractor
self.nets['generator'] = self.nets['Gen_Full']._layers.generator
self.nets['discriminator'] = self.nets['Dis']._layers.discriminator
else:
self.nets['kp_detector'] = self.nets['Gen_Full'].kp_extractor
self.nets['generator'] = self.nets['Gen_Full'].generator
self.nets['discriminator'] = self.nets['Dis'].discriminator
def setup_optimizers(self, lr_cfg, optimizer):
self.setup_net_parallel()
# init params
init_weight(self.nets['kp_detector'])
init_weight(self.nets['generator'])
......@@ -163,6 +165,7 @@ class FirstOrderModel(BaseModel):
self.optimizers['optimizer_Dis'].step()
def test_iter(self, metrics=None):
self.setup_net_parallel()
self.nets['kp_detector'].eval()
self.nets['generator'].eval()
loss_list = []
......
......@@ -167,7 +167,18 @@ class DecoderNet(nn.Layer):
return out
vgg = nn.Sequential(
@GENERATORS.register()
class Encoder(nn.Layer):
"""Encoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def __init__(self):
super(Encoder, self).__init__()
vgg_net = nn.Sequential(
nn.Conv2D(3, 3, (1, 1)),
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(3, 64, (3, 3)),
......@@ -221,19 +232,8 @@ vgg = nn.Sequential(
nn.Pad2D([1, 1, 1, 1], mode='reflect'),
nn.Conv2D(512, 512, (3, 3)),
nn.ReLU() # relu5-4
)
)
@GENERATORS.register()
class Encoder(nn.Layer):
"""Encoder of Drafting module.
Paper:
Drafting and Revision: Laplacian Pyramid Network for Fast High-Quality
Artistic Style Transfer.
"""
def __init__(self):
super(Encoder, self).__init__()
vgg_net = vgg
weight_path = get_path_from_url(
'https://paddlegan.bj.bcebos.com/models/vgg_normalised.pdparams')
vgg_net.set_dict(paddle.load(weight_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册