【论文复现】wgan-gp实现后,梯度直线上升,检查了无数遍,不知道是哪一步出错了
Created by: wszzzx
`# WGAN-GP
class G(fluid.dygraph.Layer):
def init(self, name_scope):
super(G, self).init(name_scope)
name_scope = self.full_name()
self.fc1 = Linear(input_dim=100, output_dim=25622,act='relu')
self.convtrans1 = Conv2DTranspose(num_channels=256,num_filters=128,filter_size=3,output_size=6,stride=2)
self.bn2 = fluid.dygraph.BatchNorm(num_channels=128, act='relu')
self.convtrans2 = Conv2DTranspose(num_channels=128,num_filters=64,filter_size=3,output_size=13,stride=2)
self.bn3 = fluid.dygraph.BatchNorm(num_channels=64, act='relu')
self.convtrans3 = Conv2DTranspose(num_channels=64,num_filters=1,filter_size=3,output_size=28,stride=2,act='tanh')
def forward(self, z):
z = fluid.layers.reshape(z, shape=[-1, 100])
y = self.fc1(z)
y = fluid.layers.reshape(y, shape=[-1,256,2,2])
y = self.convtrans1(y)
y = self.convtrans2(y)
y = self.convtrans3(y)
return y
class D(fluid.dygraph.Layer):
def init(self, name_scope):
super(D, self).init(name_scope)
name_scope = self.full_name()
self.conv1 = Conv2D(num_channels=1, num_filters=64, filter_size=3,stride=2,)
self.conv2 = Conv2D(num_channels=64, num_filters=128, filter_size=3,stride=2)
self.conv3 = Conv2D(num_channels=128, num_filters=256, filter_size=3,stride=2)
self.fc1 = Linear(input_dim=25622, output_dim=1024)
self.fc2 = Linear(input_dim=1024, output_dim=1)
def forward(self, img):
z = fluid.layers.reshape(img, shape=[-1, 1, 28, 28])
y = self.conv1(z)
y = self.conv2(y)
y = self.conv3(y)
y = fluid.layers.reshape(y, shape=[-1, 256*2*2])
y = self.fc1(y)
y = self.fc2(y)
return y`
`lambda_gp = 10 def compute_gradient_penalty(D, real_samples, fake_samples): """Calculates the gradient penalty loss for WGAN GP""" # Random weight term for interpolation between real and fake samples alpha = fluid.dygraph.to_variable(np.random.uniform(size=(real_samples.shape[0],1,1,1))).astype('float32') interpolates = fluid.layers.elementwise_mul(real_samples,alpha) + fluid.layers.elementwise_mul(fake_samples,1-alpha) d_interpolates = D(interpolates) gradients = paddle.fluid.dygraph.grad(d_interpolates,inputs=interpolates)[0]
gradients = fluid.layers.reshape(gradients, shape=[gradients.shape[0],-1])
gradient = np.mean((np.linalg.norm(gradients.numpy(),2,axis=1)-1)**2)
return float(gradient)
def train(mnist_generator, epoch_num=1, batch_size=128, use_gpu=True, load_model=False): place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() with fluid.dygraph.guard(place): # 模型存储路径 model_path = './output/' d = D('D') d.train() g = G('G') g.train() # 创建优化方法 d_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4, parameter_list=d.parameters()) g_optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4, parameter_list=g.parameters())
# 读取上次保存的模型
if load_model == True:
g_para, g_opt = fluid.load_dygraph(model_path+'g')
d_para, d_r_opt = fluid.load_dygraph(model_path+'d_o_r')
# 上面判别器的参数已经读取到d_para了,此处无需再次读取
_, d_f_opt = fluid.load_dygraph(model_path+'d_o_f')
g.load_dict(g_para)
g_optimizer.set_dict(g_opt)
d.load_dict(d_para)
real_d_optimizer.set_dict(d_r_opt)
fake_d_optimizer.set_dict(d_f_opt)
iteration_num = 0
d_loss_list = []
g_loss_list = []
for epoch in range(epoch_num):
for i, real_image in enumerate(mnist_generator()):
# 丢弃不满整个batch_size的数据
if(len(real_image) != BATCH_SIZE):
continue
iteration_num += 1
z = next(z_generator())
z = fluid.dygraph.to_variable(np.array(z))
fake_image = g(z)
real_image = fluid.dygraph.to_variable(np.array(real_image))
gradient_penalty = compute_gradient_penalty(d, real_image, fake_image)
p_d_fake = d(fake_image)
p_d_real = d(real_image)
# d_loss = -fluid.layers.mean(p_d_real) + fluid.layers.mean(p_d_fake) + lambda_gp * gradient_penalty
# print('loss',-fluid.layers.mean(p_d_real).numpy(),fluid.layers.mean(p_d_fake).numpy(),lambda_gp * gradient_penalty)
d_loss = -fluid.layers.mean(d(real_image)) + fluid.layers.mean(d(fake_image))
d_loss.backward()
d_optimizer.minimize(d_loss)
d.clear_gradients()
# 生成器用输入的高斯噪声z生成假图片
# fake = g(z)
# 计算判别器d判断生成器g生成的假图片的概率
fake_image = g(z)
p_d_fake = d(fake_image)
g_loss = -fluid.layers.mean(p_d_fake)
# 反向传播更新生成器g的参数
g_loss.backward()
g_optimizer.minimize(g_loss)
g.clear_gradients()
# 打印输出
if(iteration_num % 200 == 0):
print('epoch =', epoch, ', batch =', i, ', d_loss =', d_loss.numpy(),'g_loss =', g_loss.numpy())
show_image_grid(fake_image.numpy(), BATCH_SIZE, epoch)
d_loss_list.append(d_loss.numpy())
g_loss_list.append(g_loss.numpy()) `