From fccd4d478171f3e096a5003f4fbc1d1626ece84c Mon Sep 17 00:00:00 2001
From: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Date: Thu, 6 May 2021 13:10:21 +0800
Subject: [PATCH] fix attributed error for DataParallel (#303)

---
 ppgan/models/gan_model.py       |   7 +-
 ppgan/models/starganv2_model.py | 173 +++++++++++++++++++++++---------
 2 files changed, 134 insertions(+), 46 deletions(-)

diff --git a/ppgan/models/gan_model.py b/ppgan/models/gan_model.py
index 6c82488..3637878 100644
--- a/ppgan/models/gan_model.py
+++ b/ppgan/models/gan_model.py
@@ -91,7 +91,12 @@ class GANModel(BaseModel):
             self.n_class = 0

         batch_size = self.D_real_inputs[0].shape[0]
-        self.G_inputs = self.nets['netG'].random_inputs(batch_size)
+
+        if isinstance(self.nets['netG'], paddle.DataParallel):
+            self.G_inputs = self.nets['netG']._layers.random_inputs(batch_size)
+        else:
+            self.G_inputs = self.nets['netG'].random_inputs(batch_size)
+
         if not isinstance(self.G_inputs, (list, tuple)):
             self.G_inputs = [self.G_inputs]

diff --git a/ppgan/models/starganv2_model.py b/ppgan/models/starganv2_model.py
index d386e68..46622ef 100755
--- a/ppgan/models/starganv2_model.py
+++ b/ppgan/models/starganv2_model.py
@@ -25,20 +25,29 @@ def translate_using_reference(nets, w_hpf, x_src, x_ref, y_ref):
     for _ in range(N):
         s_ref_lists.append(s_ref_list)
     s_ref_list = paddle.stack(s_ref_lists, axis=1)
-    s_ref_list = paddle.reshape(s_ref_list, (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
+    s_ref_list = paddle.reshape(
+        s_ref_list,
+        (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
     x_concat = [x_src_with_wb]
     for i, s_ref in enumerate(s_ref_list):
         x_fake = nets['generator'](x_src, s_ref, masks=masks)
-        x_fake_with_ref = paddle.concat([x_ref[i:i+1], x_fake], axis=0)
+        x_fake_with_ref = paddle.concat([x_ref[i:i + 1], x_fake], axis=0)
         x_concat += [x_fake_with_ref]

     x_concat = paddle.concat(x_concat, axis=0)
-    img = tensor2img(make_grid(x_concat, nrow=N+1, range=(0, 1)))
+    img = tensor2img(make_grid(x_concat, nrow=N + 1, range=(0, 1)))
     del x_concat
     return img


-def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
+def compute_d_loss(nets,
+                   lambda_reg,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trg=None,
+                   x_ref=None,
+                   masks=None):
     assert (z_trg is None) != (x_ref is None)
     # with real images
     x_real.stop_gradient = False
@@ -58,9 +67,11 @@ def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=Non
     loss_fake = adv_loss(out, 0)

     loss = loss_real + loss_fake + lambda_reg * loss_reg
-    return loss, {'real': loss_real.numpy(),
-                  'fake': loss_fake.numpy(),
-                  'reg': loss_reg.numpy()}
+    return loss, {
+        'real': loss_real.numpy(),
+        'fake': loss_fake.numpy(),
+        'reg': loss_reg.numpy()
+    }


 def adv_loss(logits, target):
@@ -73,21 +84,29 @@ def adv_loss(logits, target):
 def r1_reg(d_out, x_in):
     # zero-centered gradient penalty for real images
     batch_size = x_in.shape[0]
-    grad_dout = paddle.grad(
-        outputs=d_out.sum(), inputs=x_in,
-        create_graph=True, retain_graph=True, only_inputs=True
-    )[0]
+    grad_dout = paddle.grad(outputs=d_out.sum(),
+                            inputs=x_in,
+                            create_graph=True,
+                            retain_graph=True,
+                            only_inputs=True)[0]
     grad_dout2 = grad_dout.pow(2)
-    assert(grad_dout2.shape == x_in.shape)
+    assert (grad_dout2.shape == x_in.shape)
     reg = 0.5 * paddle.reshape(grad_dout2, (batch_size, -1)).sum(1).mean(0)
     return reg

+
 def soft_update(source, target, beta=1.0):
     assert 0.0 <= beta <= 1.0
+
+    if isinstance(source, paddle.DataParallel):
+        source = source._layers
+
     target_model_map = dict(target.named_parameters())
     for param_name, source_param in source.named_parameters():
         target_param = target_model_map[param_name]
-        target_param.set_value(beta * source_param + (1.0 - beta) * target_param)
+        target_param.set_value(beta * source_param +
+                               (1.0 - beta) * target_param)
+

 def dump_model(model):
     params = {}
@@ -97,7 +116,17 @@ def dump_model(model):
     return params


-def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
+def compute_g_loss(nets,
+                   w_hpf,
+                   lambda_sty,
+                   lambda_ds,
+                   lambda_cyc,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trgs=None,
+                   x_refs=None,
+                   masks=None):
     assert (z_trgs is None) != (x_refs is None)
     if z_trgs is not None:
         z_trg, z_trg2 = z_trgs
@@ -127,17 +156,23 @@ def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org
     loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))

     # cycle-consistency loss
-    masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
+    if isinstance(nets['fan'], paddle.DataParallel):
+        masks = nets['fan']._layers.get_heatmap(x_fake) if w_hpf > 0 else None
+    else:
+        masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
+
     s_org = nets['style_encoder'](x_real, y_org)
     x_rec = nets['generator'](x_fake, s_org, masks=masks)
     loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))

     loss = loss_adv + lambda_sty * loss_sty \
         - lambda_ds * loss_ds + lambda_cyc * loss_cyc
-    return loss, {'adv': loss_adv.numpy(),
-                  'sty': loss_sty.numpy(),
-                  'ds:': loss_ds.numpy(),
-                  'cyc': loss_cyc.numpy()}
+    return loss, {
+        'adv': loss_adv.numpy(),
+        'sty': loss_sty.numpy(),
+        'ds:': loss_ds.numpy(),
+        'cyc': loss_cyc.numpy()
+    }


 def he_init(module):
@@ -154,7 +189,7 @@ def he_init(module):
 @MODELS.register()
 class StarGANv2Model(BaseModel):
     def __init__(
-        self, 
+        self,
         generator,
         style=None,
         mapping=None,
@@ -195,7 +230,7 @@ class StarGANv2Model(BaseModel):

         # remember the initial value of ds weight
         self.initial_lambda_ds = self.lambda_ds
-        
+
     def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.

@@ -206,8 +241,10 @@ class StarGANv2Model(BaseModel):
         """
         pass
         self.input = input
-        self.input['z_trg'] = paddle.randn((input['src'].shape[0], self.latent_dim))
-        self.input['z_trg2'] = paddle.randn((input['src'].shape[0], self.latent_dim))
+        self.input['z_trg'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))
+        self.input['z_trg2'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))

     def forward(self):
         """Run forward pass; called by both functions and ."""
@@ -220,50 +257,89 @@ class StarGANv2Model(BaseModel):
     def train_iter(self, optimizers=None):
         #TODO
         x_real, y_org = self.input['src'], self.input['src_cls']
-        x_ref, x_ref2, y_trg = self.input['ref'], self.input['ref2'], self.input['ref_cls']
+        x_ref, x_ref2, y_trg = self.input['ref'], self.input[
+            'ref2'], self.input['ref_cls']
         z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']

-        masks = self.nets['fan'].get_heatmap(x_real) if self.w_hpf > 0 else None
+        if isinstance(self.nets['fan'], paddle.DataParallel):
+            masks = self.nets['fan']._layers.get_heatmap(
+                x_real) if self.w_hpf > 0 else None
+        else:
+            masks = self.nets['fan'].get_heatmap(
+                x_real) if self.w_hpf > 0 else None

         # train the discriminator
-        d_loss, d_losses_latent = compute_d_loss(
-            self.nets, self.lambda_reg, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
+        d_loss, d_losses_latent = compute_d_loss(self.nets,
+                                                 self.lambda_reg,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trg=z_trg,
+                                                 masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].minimize(d_loss)

-        d_loss, d_losses_ref = compute_d_loss(
-            self.nets, self.lambda_reg, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
+        d_loss, d_losses_ref = compute_d_loss(self.nets,
+                                              self.lambda_reg,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_ref=x_ref,
+                                              masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].step()

         # train the generator
-        g_loss, g_losses_latent = compute_g_loss(
-            self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
+        g_loss, g_losses_latent = compute_g_loss(self.nets,
+                                                 self.w_hpf,
+                                                 self.lambda_sty,
+                                                 self.lambda_ds,
+                                                 self.lambda_cyc,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trgs=[z_trg, z_trg2],
+                                                 masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()
         optimizers['mapping_network'].step()
         optimizers['style_encoder'].step()

-        g_loss, g_losses_ref = compute_g_loss(
-            self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
+        g_loss, g_losses_ref = compute_g_loss(self.nets,
+                                              self.w_hpf,
+                                              self.lambda_sty,
+                                              self.lambda_ds,
+                                              self.lambda_cyc,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_refs=[x_ref, x_ref2],
+                                              masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()

         # compute moving average of network parameters
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)

         # decay weight for diversity sensitive loss
         if self.lambda_ds > 0:
             self.lambda_ds -= (self.initial_lambda_ds / self.total_iter)

-        for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
-                                ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
+        for loss, prefix in zip(
+            [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
+            ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
             for key, value in loss.items():
                 self.losses[prefix + key] = value
         self.losses['G/lambda_ds'] = self.lambda_ds
@@ -273,17 +349,24 @@ class StarGANv2Model(BaseModel):
         #TODO
         self.nets_ema['generator'].eval()
         self.nets_ema['style_encoder'].eval()
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)
         src_img = self.input['src']
         ref_img = self.input['ref']
         ref_label = self.input['ref_cls']
         with paddle.no_grad():
-            img = translate_using_reference(self.nets_ema, self.w_hpf,
-                                            paddle.to_tensor(src_img).astype('float32'),
-                                            paddle.to_tensor(ref_img).astype('float32'),
-                                            paddle.to_tensor(ref_label).astype('float32'))
+            img = translate_using_reference(
+                self.nets_ema, self.w_hpf,
+                paddle.to_tensor(src_img).astype('float32'),
+                paddle.to_tensor(ref_img).astype('float32'),
+                paddle.to_tensor(ref_label).astype('float32'))
         self.visual_items['reference'] = img
         self.nets_ema['generator'].train()
         self.nets_ema['style_encoder'].train()
-- 
GitLab
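
Note (not part of the patch, and ignored by git am since it follows the signature line): every hunk above applies the same pattern, unwrapping paddle.DataParallel via its _layers attribute before calling a model-specific method. A minimal standalone sketch of that pattern, assuming only what the patch itself relies on (DataParallel stores the wrapped Layer in ._layers); unwrap_model and ToyGenerator are illustrative names, not ppgan or Paddle APIs:

import paddle
import paddle.nn as nn


def unwrap_model(net):
    # paddle.DataParallel proxies forward(), but custom methods such as
    # random_inputs() or get_heatmap() live on the wrapped Layer in ._layers.
    return net._layers if isinstance(net, paddle.DataParallel) else net


class ToyGenerator(nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def random_inputs(self, batch_size):
        # custom helper that DataParallel does not expose directly
        return paddle.randn([batch_size, 8])

    def forward(self, x):
        return self.fc(x)


net = ToyGenerator()
# In a multi-GPU run the model would be wrapped, e.g.:
# net = paddle.DataParallel(net)  # requires an initialized parallel env
z = unwrap_model(net).random_inputs(4)  # works whether or not net is wrapped
print(z.shape)

The patch itself inlines the isinstance check at each call site instead of introducing such a helper, which keeps the change local to the affected models.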