未验证 提交 7b592f5c 编写于 作者: L LielinJiang 提交者: GitHub

fix attributed error (#305)

上级 fccd4d47
...@@ -83,7 +83,11 @@ class GANModel(BaseModel): ...@@ -83,7 +83,11 @@ class GANModel(BaseModel):
input = {'img': input} input = {'img': input}
self.D_real_inputs = [paddle.to_tensor(input['img'])] self.D_real_inputs = [paddle.to_tensor(input['img'])]
if 'class_id' in input: # n class input if 'class_id' in input: # n class input
self.n_class = self.nets['netG'].n_class if isinstance(self.nets['netG'], paddle.DataParallel):
self.n_class = self.nets['netG']._layers.n_class
else:
self.n_class = self.nets['netG'].n_class
self.D_real_inputs += [ self.D_real_inputs += [
paddle.to_tensor(input['class_id'], dtype='int64') paddle.to_tensor(input['class_id'], dtype='int64')
] ]
......
...@@ -156,10 +156,13 @@ def compute_g_loss(nets, ...@@ -156,10 +156,13 @@ def compute_g_loss(nets,
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2)) loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
# cycle-consistency loss # cycle-consistency loss
if isinstance(nets['fan'], paddle.DataParallel): if w_hpf > 0:
masks = nets['fan']._layers.get_heatmap(x_fake) if w_hpf > 0 else None if isinstance(nets['fan'], paddle.DataParallel):
masks = nets['fan']._layers.get_heatmap(x_fake)
else:
masks = nets['fan'].get_heatmap(x_fake)
else: else:
masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None masks = None
s_org = nets['style_encoder'](x_real, y_org) s_org = nets['style_encoder'](x_real, y_org)
x_rec = nets['generator'](x_fake, s_org, masks=masks) x_rec = nets['generator'](x_fake, s_org, masks=masks)
...@@ -261,12 +264,13 @@ class StarGANv2Model(BaseModel): ...@@ -261,12 +264,13 @@ class StarGANv2Model(BaseModel):
'ref2'], self.input['ref_cls'] 'ref2'], self.input['ref_cls']
z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2'] z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']
if isinstance(self.nets['fan'], paddle.DataParallel): if self.w_hpf > 0:
masks = self.nets['fan']._layers.get_heatmap( if isinstance(self.nets['fan'], paddle.DataParallel):
x_real) if self.w_hpf > 0 else None masks = self.nets['fan']._layers.get_heatmap(x_real)
else:
masks = self.nets['fan'].get_heatmap(x_real)
else: else:
masks = self.nets['fan'].get_heatmap( masks = None
x_real) if self.w_hpf > 0 else None
# train the discriminator # train the discriminator
d_loss, d_losses_latent = compute_d_loss(self.nets, d_loss, d_losses_latent = compute_d_loss(self.nets,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册