[double grad] Using the grad interface, the grad op for a certain operator cannot be found when optimizing a generative model
Created by: wzzju
Below is a minimal Paddle reproduction:
```python
import paddle
import numpy as np

paddle.enable_imperative()


class Generator(paddle.fluid.dygraph.Layer):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = paddle.nn.Conv2D(3, 3, 3, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = paddle.fluid.layers.tanh(x)
        return x


class Discriminator(paddle.fluid.dygraph.Layer):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.convd = paddle.nn.Conv2D(6, 3, 1)

    def forward(self, x):
        x = self.convd(x)
        return x


def cal_gradient_penalty(netD, real_data, fake_data, edge_data=None, type='mixed', constant=1.0, lambda_gp=10.0):
    if lambda_gp > 0.0:
        if type == 'real':  # either use real images, fake images, or a linear interpolation of two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = paddle.rand((real_data.shape[0], 1))
            alpha = paddle.expand(alpha, [1, np.prod(real_data.shape) // real_data.shape[0]])
            alpha = paddle.reshape(alpha, real_data.shape)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        # interpolatesv.requires_grad_(True)
        interpolatesv.stop_gradient = False
        real_data.stop_gradient = True
        fake_AB = paddle.concat((real_data.detach(), interpolatesv), 1)
        disc_interpolates = netD(fake_AB)
        # FIXME: use paddle.ones
        outs = paddle.fill_constant(disc_interpolates.shape, disc_interpolates.dtype, 1.0)
        gradients = paddle.imperative.grad(outputs=disc_interpolates, inputs=fake_AB,
                                           grad_outputs=outs,  # paddle.ones(list(disc_interpolates.shape)),
                                           create_graph=True,
                                           retain_graph=True,
                                           only_inputs=True,
                                           # no_grad_vars=set(netD.parameters())
                                           )
        gradients = paddle.reshape(gradients[0], [real_data.shape[0], -1])  # flat the data
        gradient_penalty = paddle.reduce_mean((paddle.norm(gradients + 1e-16, 2, 1) - constant) ** 2) * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


g = Generator()
d = Discriminator()
optim_g = paddle.optimizer.Adam(parameter_list=g.parameters())
optim_d = paddle.optimizer.Adam(parameter_list=d.parameters())
gan_criterion = paddle.nn.MSELoss()
l1_criterion = paddle.nn.L1Loss()

A = np.random.rand(2, 3, 32, 32).astype('float32')
B = np.random.rand(2, 3, 32, 32).astype('float32')
realA = paddle.imperative.to_variable(A)
realB = paddle.imperative.to_variable(B)

fakeB = g(realA)

optim_d.clear_gradients()
fake_AB = paddle.concat((realA, fakeB), 1)
G_pred_fake = d(fake_AB.detach())
false_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 0.0)

use_gp = True
if use_gp:
    G_gradient_penalty, _ = cal_gradient_penalty(d, realA, fakeB, lambda_gp=10.0)
    loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
else:
    loss_d = gan_criterion(G_pred_fake, false_target)

loss_d.backward()
optim_d.minimize(loss_d)
print('discriminator loss:', loss_d.numpy())

optim_g.clear_gradients()
fake_AB = paddle.concat((realA, fakeB), 1)
G_pred_fake = d(fake_AB)
true_target = paddle.fill_constant(G_pred_fake.shape, 'float32', 1.0)
loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake, true_target)
print('generator loss:', loss_g.numpy())
loss_g.backward()
optim_g.minimize(loss_g)
```
Running this on the Paddle develop build reports an error (the grad op for a certain operator cannot be found). If `use_gp` in the code is changed to `False` (i.e. the `grad` interface is not used), no error is reported. Also, if the `tanh` call is removed, the error changes to a different op whose grad cannot be found.
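The failure appears to reduce to taking a first-order gradient through `tanh` with `create_graph=True` and then backpropagating through that gradient. The snippet below is only a minimal sketch under that assumption, reusing the same `paddle.imperative` API as the repro above; if the missing kernel is tanh's double-grad op, it should trigger the same error.

```python
# Minimal double-grad sketch (assumption: same paddle.imperative API/build as above).
import paddle
import numpy as np

paddle.enable_imperative()

x = paddle.imperative.to_variable(np.random.rand(2, 3).astype('float32'))
x.stop_gradient = False

y = paddle.fluid.layers.tanh(x)

# First-order gradient, kept in the graph so it can be differentiated again.
(dx,) = paddle.imperative.grad(outputs=y, inputs=x,
                               create_graph=True,
                               retain_graph=True)

# Second-order use of dx: this backward pass is where the "grad op not found"
# error is expected to surface if tanh has no double-grad kernel registered.
loss = paddle.reduce_mean(dx * dx)
loss.backward()
```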
The corresponding PyTorch code runs correctly, as shown below:
```python
import torch
import numpy as np
# paddle.enable_imperative()


class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3, 1, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.tanh(x)
        return x


class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.convd = torch.nn.Conv2d(6, 3, 1)

    def forward(self, x):
        x = self.convd(x)
        return x


def cal_gradient_penalty(netD, real_data, fake_data, edge_data=None, type='mixed', constant=1.0, lambda_gp=10.0):
    if lambda_gp > 0.0:
        if type == 'real':  # either use real images, fake images, or a linear interpolation of two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1).cuda()
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        # interpolatesv.stop_gradient = False
        # real_data.stop_gradient = True
        fake_AB = torch.cat((real_data.detach(), interpolatesv), 1)
        disc_interpolates = netD(fake_AB)
        # FIXME: use paddle.ones
        # outs = paddle.fill_constant(disc_interpolates.shape, disc_interpolates.dtype, 1.0)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=fake_AB,
                                        grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                        create_graph=True,
                                        retain_graph=True,
                                        only_inputs=True,
                                        # no_grad_vars=set(netD.parameters())
                                        )
        gradients = gradients[0].view(real_data.size(0), -1)  # flat the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None


g = Generator().cuda()
d = Discriminator().cuda()
optim_g = torch.optim.Adam(g.parameters(), lr=0.001)
optim_d = torch.optim.Adam(d.parameters(), lr=0.001)
gan_criterion = torch.nn.MSELoss()
l1_criterion = torch.nn.L1Loss()

# A = np.random.rand(2, 3, 32, 32).astype('float32')
# B = np.random.rand(2, 3, 32, 32).astype('float32')
# realA = paddle.imperative.to_variable(A)
# realB = paddle.imperative.to_variable(B)
realA = torch.randn(2, 3, 32, 32).cuda()
realB = torch.randn(2, 3, 32, 32).cuda()

fakeB = g(realA)

optim_d.zero_grad()
fake_AB = torch.cat((realA, fakeB), 1)
G_pred_fake = d(fake_AB.detach())
false_target = torch.zeros(G_pred_fake.shape).cuda()

use_gp = True
if use_gp:
    G_gradient_penalty, _ = cal_gradient_penalty(d, realA, fakeB, lambda_gp=10.0)
    loss_d = gan_criterion(G_pred_fake, false_target) + G_gradient_penalty
else:
    loss_d = gan_criterion(G_pred_fake, false_target)

loss_d.backward(retain_graph=True)
optim_d.step()
print('discriminator loss:', loss_d.cpu().detach().numpy())

optim_g.zero_grad()
fake_AB = torch.cat((realA, fakeB), 1)
G_pred_fake = d(fake_AB)
true_target = torch.ones(G_pred_fake.shape).cuda()
loss_g = l1_criterion(fakeB, realB) + gan_criterion(G_pred_fake, true_target)
print('generator loss:', loss_g.cpu().detach().numpy())
loss_g.backward()
optim_g.step()
```