未验证 提交 7b592f5c 编写于 作者: L LielinJiang 提交者: GitHub

fix attributed error (#305)

上级 fccd4d47
......@@ -83,7 +83,11 @@ class GANModel(BaseModel):
input = {'img': input}
self.D_real_inputs = [paddle.to_tensor(input['img'])]
if 'class_id' in input: # n class input
self.n_class = self.nets['netG'].n_class
if isinstance(self.nets['netG'], paddle.DataParallel):
self.n_class = self.nets['netG']._layers.n_class
else:
self.n_class = self.nets['netG'].n_class
self.D_real_inputs += [
paddle.to_tensor(input['class_id'], dtype='int64')
]
......
......@@ -156,10 +156,13 @@ def compute_g_loss(nets,
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
# cycle-consistency loss
if isinstance(nets['fan'], paddle.DataParallel):
masks = nets['fan']._layers.get_heatmap(x_fake) if w_hpf > 0 else None
if w_hpf > 0:
if isinstance(nets['fan'], paddle.DataParallel):
masks = nets['fan']._layers.get_heatmap(x_fake)
else:
masks = nets['fan'].get_heatmap(x_fake)
else:
masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
masks = None
s_org = nets['style_encoder'](x_real, y_org)
x_rec = nets['generator'](x_fake, s_org, masks=masks)
......@@ -261,12 +264,13 @@ class StarGANv2Model(BaseModel):
'ref2'], self.input['ref_cls']
z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']
if isinstance(self.nets['fan'], paddle.DataParallel):
masks = self.nets['fan']._layers.get_heatmap(
x_real) if self.w_hpf > 0 else None
if self.w_hpf > 0:
if isinstance(self.nets['fan'], paddle.DataParallel):
masks = self.nets['fan']._layers.get_heatmap(x_real)
else:
masks = self.nets['fan'].get_heatmap(x_real)
else:
masks = self.nets['fan'].get_heatmap(
x_real) if self.w_hpf > 0 else None
masks = None
# train the discriminator
d_loss, d_losses_latent = compute_d_loss(self.nets,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册