Commit 90815b9c authored by Varuna Jayasiri

🚧 gan

Parent b5b55093
@@ -12,7 +12,7 @@ class DiscriminatorLogitsLoss(Module):
         self.loss_true = nn.BCEWithLogitsLoss()
         self.loss_false = nn.BCEWithLogitsLoss()
         self.register_buffer('labels_true', torch.ones(256, 1, requires_grad=False), False)
-        self.register_buffer('labels_false', torch.ones(256, 1, requires_grad=False), False)
+        self.register_buffer('labels_false', torch.zeros(256, 1, requires_grad=False), False)

     def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         if len(logits_true) > len(self.labels_true):
@@ -20,12 +20,10 @@ class DiscriminatorLogitsLoss(Module):
                                  self.labels_true.new_ones(len(logits_true), 1, requires_grad=False), False)
         if len(logits_false) > len(self.labels_false):
             self.register_buffer("labels_false",
-                                 self.labels_false.new_ones(len(logits_false), 1, requires_grad=False), False)
+                                 self.labels_false.new_zeros(len(logits_false), 1, requires_grad=False), False)

-        loss = (self.loss_true(logits_true, self.labels_true[:len(logits_true)]) +
-                self.loss_false(logits_false, self.labels_false[:len(logits_false)]))
-        return loss
+        return self.loss_true(logits_true, self.labels_true[:len(logits_true)]), \
+               self.loss_false(logits_false, self.labels_false[:len(logits_false)])


 class GeneratorLogitsLoss(Module):
......
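These two hunks fix the label buffers (real "true" logits are now trained toward 1 and generated "false" logits toward 0) and return the two loss terms separately instead of pre-summed, so the training step can log each part. A minimal stand-alone sketch of the intended semantics, using plain PyTorch rather than this repo's class:

```python
import torch
from torch import nn

bce = nn.BCEWithLogitsLoss()
logits_real = torch.randn(8, 1)  # discriminator outputs on real images
logits_fake = torch.randn(8, 1)  # discriminator outputs on generated images

# Real samples are pushed toward label 1, generated samples toward label 0.
loss_true = bce(logits_real, torch.ones(8, 1))
loss_false = bce(logits_fake, torch.zeros(8, 1))

# The training step (see the discriminator hunk below) now sums the two
# parts itself, which lets it track each term separately.
loss = loss_true + loss_false
```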
@@ -18,7 +18,7 @@ plt.rcParams['image.interpolation'] = 'nearest'
 plt.rcParams['image.cmap'] = 'gray'


-class Generator(nn.Module):
+class Generator(Module):
     def __init__(self):
         super(Generator, self).__init__()
@@ -38,7 +38,7 @@ class Generator(nn.Module):
         return x


-class Discriminator(nn.Module):
+class Discriminator(Module):
     def __init__(self):
         super(Discriminator, self).__init__()
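The base-class swap from `nn.Module` to `Module` presumably pulls in labml's helper module class; the layer definitions themselves are elided by the diff. Purely as a hypothetical illustration (`ToyGenerator`/`ToyDiscriminator` and their layer sizes are inventions, not this repo's architecture), an MNIST GAN pair in this style might look like:

```python
import torch
from torch import nn

class ToyGenerator(nn.Module):  # hypothetical stand-in, not the repo's Generator
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 784), nn.Tanh())

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Map a 100-d latent vector to a 1x28x28 image in [-1, 1].
        return self.layers(z).view(-1, 1, 28, 28)

class ToyDiscriminator(nn.Module):  # hypothetical stand-in, not the repo's Discriminator
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1))  # a single logit; pairs with BCEWithLogitsLoss above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x.view(x.shape[0], -1))
```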
@@ -72,6 +72,15 @@ class GANBatchStep(BatchStepProtocol):
         self.discriminator_optimizer = discriminator_optimizer
         tracker.set_scalar("loss.generator.*", True)
         tracker.set_scalar("loss.discriminator.*", True)
+        tracker.set_image("generated", True)
+
+    def prepare_for_iteration(self):
+        if MODE_STATE.is_train:
+            self.generator.train()
+            self.discriminator.train()
+        else:
+            self.generator.eval()
+            self.discriminator.eval()

     def process(self, batch: any, state: any):
         device = self.discriminator.device
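The new `prepare_for_iteration` hook switches both networks between training and evaluation mode every iteration. This matters because layers such as `BatchNorm` and `Dropout` behave differently in the two modes; a small self-contained demonstration:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

bn.train()
y_train = bn(x)  # normalizes with this batch's statistics, updates running stats

bn.eval()
y_eval = bn(x)   # normalizes with the accumulated running statistics

print(torch.allclose(y_train, y_eval))  # typically False
```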
@@ -79,10 +88,12 @@ class GANBatchStep(BatchStepProtocol):
         data, target = data.to(device), target.to(device)

         with monit.section("generator"):
-            latent = torch.normal(0, 1, (data.shape[0], 100), device=device)
+            latent = torch.randn(data.shape[0], 100, device=device)
             if MODE_STATE.is_train:
                 self.generator_optimizer.zero_grad()
-            logits = self.discriminator(self.generator(latent))
+            generated_images = self.generator(latent)
+            # tracker.add('generated', generated_images[0:1])
+            logits = self.discriminator(generated_images)
             loss = self.generator_loss(logits)
             tracker.add("loss.generator.", loss)
             if MODE_STATE.is_train:
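The change from `torch.normal(0, 1, shape)` to `torch.randn(shape)` is cosmetic: both draw the latent vectors from a standard normal distribution, `randn` just says so more directly. For example:

```python
import torch

# Both sample latent vectors from N(0, 1); only the spelling differs.
latent_a = torch.normal(0, 1, (64, 100))
latent_b = torch.randn(64, 100)
print(latent_a.shape == latent_b.shape)  # True: both are (64, 100)
```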
@@ -90,18 +101,21 @@ class GANBatchStep(BatchStepProtocol):
                 self.generator_optimizer.step()

         with monit.section("discriminator"):
-            latent = torch.normal(0, 1, (data.shape[0], 100), device=device)
+            latent = torch.randn(data.shape[0], 100, device=device)
             if MODE_STATE.is_train:
                 self.discriminator_optimizer.zero_grad()
-            logits_false = self.discriminator(self.generator(latent).detach())
             logits_true = self.discriminator(data)
-            loss = self.discriminator_loss(logits_true, logits_false)
-            tracker.add("loss.generator.", loss)
+            logits_false = self.discriminator(self.generator(latent).detach())
+            loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
+            loss = loss_true + loss_false
+            tracker.add("loss.discriminator.true.", loss_true)
+            tracker.add("loss.discriminator.false.", loss_false)
+            tracker.add("loss.discriminator.", loss)
             if MODE_STATE.is_train:
                 loss.backward()
                 self.discriminator_optimizer.step()

-        return {}, None
+        return {'samples': len(data)}, None


 class Configs(MNISTConfigs, TrainValidConfigs):
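Three details in this hunk are worth noting: the old code mislabeled the discriminator loss as `"loss.generator."`, which the new per-term trackers fix; the step now reports `{'samples': len(data)}` so the caller can count processed samples; and `self.generator(latent).detach()` cuts the graph so the discriminator step cannot push gradients back into the generator. A minimal illustration of what `.detach()` buys here:

```python
import torch

w_g = torch.randn(1, requires_grad=True)  # stands in for generator parameters
w_d = torch.randn(1, requires_grad=True)  # stands in for discriminator parameters

fake = w_g * 2.0               # "generator output"
score = w_d * fake.detach()    # discriminator sees the fake, but the graph stops here
score.sum().backward()

print(w_g.grad)  # None: the generator receives no gradient from this step
print(w_d.grad)  # populated: only the discriminator is updated
```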
@@ -154,7 +168,7 @@ def main():
                             'generator_optimizer.optimizer': 'Adam',
                             'discriminator_optimizer.learning_rate': 2.5e-4,
                             'discriminator_optimizer.optimizer': 'Adam'},
-                           ['set_seed', 'main'])
+                           'run')

     with experiment.start():
         conf.run()
......
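For context, the override dict and the `'run'` argument above plug into labml's experiment setup; the final argument names the config to compute and run, replacing the earlier `['set_seed', 'main']` order. A hedged reconstruction of the surrounding `main()` (the diff elides the enclosing call and the experiment name; `experiment.calc_configs` and the name below are assumptions based on labml's API of that era):

```python
from labml import experiment

def main():
    conf = Configs()
    experiment.create(name='mnist_gan')  # assumed name; the diff does not show it
    # Override the optimizer settings, then compute configs and run 'run'.
    experiment.calc_configs(conf,
                            {'generator_optimizer.learning_rate': 2.5e-4,
                             'generator_optimizer.optimizer': 'Adam',
                             'discriminator_optimizer.learning_rate': 2.5e-4,
                             'discriminator_optimizer.optimizer': 'Adam'},
                            'run')
    with experiment.start():
        conf.run()
```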