Unverified commit fccd4d47, authored by LielinJiang, committed by GitHub

fix attribute error for DataParallel (#303)

Parent fa53795c
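Note on the error being fixed: paddle.DataParallel wraps a Layer for multi-GPU training and only proxies forward(), so model-specific helpers such as random_inputs() or get_heatmap() are not reachable on the wrapped object and raise AttributeError. The change below reaches through the wrapper via ._layers wherever such helpers are called. A minimal sketch of the pattern (the toy Generator here is illustrative, not code from this repo):

import paddle


class Generator(paddle.nn.Layer):
    # Toy layer with a custom helper besides forward().
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

    def random_inputs(self, batch_size):
        return paddle.randn((batch_size, 8))


net = Generator()
# Under distributed training the net would be wrapped, e.g.
#     net = paddle.DataParallel(net)
# after which net.random_inputs(4) raises AttributeError, because
# DataParallel does not forward custom methods. The guard below is
# the same unwrap this commit adds throughout:
layer = net._layers if isinstance(net, paddle.DataParallel) else net
inputs = layer.random_inputs(4)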
@@ -91,7 +91,12 @@ class GANModel(BaseModel):
         self.n_class = 0
         batch_size = self.D_real_inputs[0].shape[0]
-        self.G_inputs = self.nets['netG'].random_inputs(batch_size)
+        if isinstance(self.nets['netG'], paddle.DataParallel):
+            self.G_inputs = self.nets['netG']._layers.random_inputs(batch_size)
+        else:
+            self.G_inputs = self.nets['netG'].random_inputs(batch_size)
         if not isinstance(self.G_inputs, (list, tuple)):
             self.G_inputs = [self.G_inputs]
...
@@ -25,20 +25,29 @@ def translate_using_reference(nets, w_hpf, x_src, x_ref, y_ref):
     for _ in range(N):
         s_ref_lists.append(s_ref_list)
     s_ref_list = paddle.stack(s_ref_lists, axis=1)
-    s_ref_list = paddle.reshape(s_ref_list, (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
+    s_ref_list = paddle.reshape(
+        s_ref_list,
+        (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
     x_concat = [x_src_with_wb]
     for i, s_ref in enumerate(s_ref_list):
         x_fake = nets['generator'](x_src, s_ref, masks=masks)
-        x_fake_with_ref = paddle.concat([x_ref[i:i+1], x_fake], axis=0)
+        x_fake_with_ref = paddle.concat([x_ref[i:i + 1], x_fake], axis=0)
         x_concat += [x_fake_with_ref]

     x_concat = paddle.concat(x_concat, axis=0)
-    img = tensor2img(make_grid(x_concat, nrow=N+1, range=(0, 1)))
+    img = tensor2img(make_grid(x_concat, nrow=N + 1, range=(0, 1)))
     del x_concat
     return img


-def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
+def compute_d_loss(nets,
+                   lambda_reg,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trg=None,
+                   x_ref=None,
+                   masks=None):
     assert (z_trg is None) != (x_ref is None)

     # with real images
     x_real.stop_gradient = False
@@ -58,9 +67,11 @@ def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
     loss_fake = adv_loss(out, 0)

     loss = loss_real + loss_fake + lambda_reg * loss_reg
-    return loss, {'real': loss_real.numpy(),
-                  'fake': loss_fake.numpy(),
-                  'reg': loss_reg.numpy()}
+    return loss, {
+        'real': loss_real.numpy(),
+        'fake': loss_fake.numpy(),
+        'reg': loss_reg.numpy()
+    }


 def adv_loss(logits, target):
@@ -73,21 +84,29 @@ def adv_loss(logits, target):

 def r1_reg(d_out, x_in):
     # zero-centered gradient penalty for real images
     batch_size = x_in.shape[0]
-    grad_dout = paddle.grad(
-        outputs=d_out.sum(), inputs=x_in,
-        create_graph=True, retain_graph=True, only_inputs=True
-    )[0]
+    grad_dout = paddle.grad(outputs=d_out.sum(),
+                            inputs=x_in,
+                            create_graph=True,
+                            retain_graph=True,
+                            only_inputs=True)[0]
     grad_dout2 = grad_dout.pow(2)
-    assert(grad_dout2.shape == x_in.shape)
+    assert (grad_dout2.shape == x_in.shape)
     reg = 0.5 * paddle.reshape(grad_dout2, (batch_size, -1)).sum(1).mean(0)
     return reg

 def soft_update(source, target, beta=1.0):
     assert 0.0 <= beta <= 1.0

+    if isinstance(source, paddle.DataParallel):
+        source = source._layers
+
     target_model_map = dict(target.named_parameters())
     for param_name, source_param in source.named_parameters():
         target_param = target_model_map[param_name]
-        target_param.set_value(beta * source_param + (1.0 - beta) * target_param)
+        target_param.set_value(beta * source_param +
+                               (1.0 - beta) * target_param)


 def dump_model(model):
     params = {}
@@ -97,7 +116,17 @@ def dump_model(model):
     return params

-def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
+def compute_g_loss(nets,
+                   w_hpf,
+                   lambda_sty,
+                   lambda_ds,
+                   lambda_cyc,
+                   x_real,
+                   y_org,
+                   y_trg,
+                   z_trgs=None,
+                   x_refs=None,
+                   masks=None):
     assert (z_trgs is None) != (x_refs is None)
     if z_trgs is not None:
         z_trg, z_trg2 = z_trgs
@@ -127,17 +156,23 @@ def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
     loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))

     # cycle-consistency loss
-    masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
+    if isinstance(nets['fan'], paddle.DataParallel):
+        masks = nets['fan']._layers.get_heatmap(x_fake) if w_hpf > 0 else None
+    else:
+        masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
     s_org = nets['style_encoder'](x_real, y_org)
     x_rec = nets['generator'](x_fake, s_org, masks=masks)
     loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))

     loss = loss_adv + lambda_sty * loss_sty \
         - lambda_ds * loss_ds + lambda_cyc * loss_cyc
-    return loss, {'adv': loss_adv.numpy(),
-                  'sty': loss_sty.numpy(),
-                  'ds:': loss_ds.numpy(),
-                  'cyc': loss_cyc.numpy()}
+    return loss, {
+        'adv': loss_adv.numpy(),
+        'sty': loss_sty.numpy(),
+        'ds:': loss_ds.numpy(),
+        'cyc': loss_cyc.numpy()
+    }

 def he_init(module):
@@ -154,7 +189,7 @@ def he_init(module):

 @MODELS.register()
 class StarGANv2Model(BaseModel):
     def __init__(
         self,
         generator,
         style=None,
         mapping=None,
@@ -195,7 +230,7 @@ class StarGANv2Model(BaseModel):
         # remember the initial value of ds weight
         self.initial_lambda_ds = self.lambda_ds

     def setup_input(self, input):
         """Unpack input data from the dataloader and perform necessary pre-processing steps.
@@ -206,8 +241,10 @@ class StarGANv2Model(BaseModel):
         """
         pass
         self.input = input
-        self.input['z_trg'] = paddle.randn((input['src'].shape[0], self.latent_dim))
-        self.input['z_trg2'] = paddle.randn((input['src'].shape[0], self.latent_dim))
+        self.input['z_trg'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))
+        self.input['z_trg2'] = paddle.randn(
+            (input['src'].shape[0], self.latent_dim))

     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@@ -220,50 +257,89 @@ class StarGANv2Model(BaseModel):
     def train_iter(self, optimizers=None):
         #TODO
         x_real, y_org = self.input['src'], self.input['src_cls']
-        x_ref, x_ref2, y_trg = self.input['ref'], self.input['ref2'], self.input['ref_cls']
+        x_ref, x_ref2, y_trg = self.input['ref'], self.input[
+            'ref2'], self.input['ref_cls']
         z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']

-        masks = self.nets['fan'].get_heatmap(x_real) if self.w_hpf > 0 else None
+        if isinstance(self.nets['fan'], paddle.DataParallel):
+            masks = self.nets['fan']._layers.get_heatmap(
+                x_real) if self.w_hpf > 0 else None
+        else:
+            masks = self.nets['fan'].get_heatmap(
+                x_real) if self.w_hpf > 0 else None

         # train the discriminator
-        d_loss, d_losses_latent = compute_d_loss(
-            self.nets, self.lambda_reg, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
+        d_loss, d_losses_latent = compute_d_loss(self.nets,
+                                                 self.lambda_reg,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trg=z_trg,
+                                                 masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].minimize(d_loss)

-        d_loss, d_losses_ref = compute_d_loss(
-            self.nets, self.lambda_reg, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
+        d_loss, d_losses_ref = compute_d_loss(self.nets,
+                                              self.lambda_reg,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_ref=x_ref,
+                                              masks=masks)
         self._reset_grad(optimizers)
         d_loss.backward()
         optimizers['discriminator'].step()

         # train the generator
-        g_loss, g_losses_latent = compute_g_loss(
-            self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
+        g_loss, g_losses_latent = compute_g_loss(self.nets,
+                                                 self.w_hpf,
+                                                 self.lambda_sty,
+                                                 self.lambda_ds,
+                                                 self.lambda_cyc,
+                                                 x_real,
+                                                 y_org,
+                                                 y_trg,
+                                                 z_trgs=[z_trg, z_trg2],
+                                                 masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()
         optimizers['mapping_network'].step()
         optimizers['style_encoder'].step()

-        g_loss, g_losses_ref = compute_g_loss(
-            self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
+        g_loss, g_losses_ref = compute_g_loss(self.nets,
+                                              self.w_hpf,
+                                              self.lambda_sty,
+                                              self.lambda_ds,
+                                              self.lambda_cyc,
+                                              x_real,
+                                              y_org,
+                                              y_trg,
+                                              x_refs=[x_ref, x_ref2],
+                                              masks=masks)
         self._reset_grad(optimizers)
         g_loss.backward()
         optimizers['generator'].step()

         # compute moving average of network parameters
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)

         # decay weight for diversity sensitive loss
         if self.lambda_ds > 0:
             self.lambda_ds -= (self.initial_lambda_ds / self.total_iter)

-        for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
-                                ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
+        for loss, prefix in zip(
+            [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
+            ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
             for key, value in loss.items():
                 self.losses[prefix + key] = value
         self.losses['G/lambda_ds'] = self.lambda_ds
@@ -273,17 +349,24 @@ class StarGANv2Model(BaseModel):
         #TODO
         self.nets_ema['generator'].eval()
         self.nets_ema['style_encoder'].eval()
-        soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
-        soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
-        soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
+        soft_update(self.nets['generator'],
+                    self.nets_ema['generator'],
+                    beta=0.999)
+        soft_update(self.nets['mapping_network'],
+                    self.nets_ema['mapping_network'],
+                    beta=0.999)
+        soft_update(self.nets['style_encoder'],
+                    self.nets_ema['style_encoder'],
+                    beta=0.999)
         src_img = self.input['src']
         ref_img = self.input['ref']
         ref_label = self.input['ref_cls']
         with paddle.no_grad():
-            img = translate_using_reference(self.nets_ema, self.w_hpf,
-                paddle.to_tensor(src_img).astype('float32'),
-                paddle.to_tensor(ref_img).astype('float32'),
-                paddle.to_tensor(ref_label).astype('float32'))
+            img = translate_using_reference(
+                self.nets_ema, self.w_hpf,
+                paddle.to_tensor(src_img).astype('float32'),
+                paddle.to_tensor(ref_img).astype('float32'),
+                paddle.to_tensor(ref_label).astype('float32'))
             self.visual_items['reference'] = img
         self.nets_ema['generator'].train()
         self.nets_ema['style_encoder'].train()
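For context, soft_update() above is the exponential-moving-average step for the nets_ema copies; with the new isinstance guard it also accepts a source still wrapped in paddle.DataParallel, unwrapping it first so named_parameters() is taken from the inner layer. A self-contained sketch of the same update rule, using toy Linear layers standing in for the real networks (an assumption, not repo code):

import paddle

# Toy stand-ins for nets['generator'] and nets_ema['generator']:
# identical architecture, independent weights.
source = paddle.nn.Linear(4, 4)
target = paddle.nn.Linear(4, 4)

# Same rule as soft_update() in the diff:
#     target <- beta * source + (1 - beta) * target
beta = 0.999
target_map = dict(target.named_parameters())
for name, src_param in source.named_parameters():
    tgt_param = target_map[name]
    tgt_param.set_value(beta * src_param + (1.0 - beta) * tgt_param)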