Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Annotated Deep Learning Paper Implementations
提交
90815b9c
A
Annotated Deep Learning Paper Implementations
项目概览
Greenplum
/
Annotated Deep Learning Paper Implementations
10 个月 前同步成功
通知
6
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
Annotated Deep Learning Paper Implementations
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
90815b9c
编写于
9月 26, 2020
作者:
V
Varuna Jayasiri
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
🚧
gan
上级
b5b55093
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
28 addition
and
16 deletion
+28
-16
labml_nn/gan/__init__.py
labml_nn/gan/__init__.py
+4
-6
labml_nn/gan/mnist.py
labml_nn/gan/mnist.py
+24
-10
未找到文件。
labml_nn/gan/__init__.py
浏览文件 @
90815b9c
...
...
@@ -12,7 +12,7 @@ class DiscriminatorLogitsLoss(Module):
self
.
loss_true
=
nn
.
BCEWithLogitsLoss
()
self
.
loss_false
=
nn
.
BCEWithLogitsLoss
()
self
.
register_buffer
(
'labels_true'
,
torch
.
ones
(
256
,
1
,
requires_grad
=
False
),
False
)
self
.
register_buffer
(
'labels_false'
,
torch
.
one
s
(
256
,
1
,
requires_grad
=
False
),
False
)
self
.
register_buffer
(
'labels_false'
,
torch
.
zero
s
(
256
,
1
,
requires_grad
=
False
),
False
)
def
__call__
(
self
,
logits_true
:
torch
.
Tensor
,
logits_false
:
torch
.
Tensor
):
if
len
(
logits_true
)
>
len
(
self
.
labels_true
):
...
...
@@ -20,12 +20,10 @@ class DiscriminatorLogitsLoss(Module):
self
.
labels_true
.
new_ones
(
len
(
logits_true
),
1
,
requires_grad
=
False
),
False
)
if
len
(
logits_false
)
>
len
(
self
.
labels_false
):
self
.
register_buffer
(
"labels_false"
,
self
.
labels_false
.
new_
one
s
(
len
(
logits_false
),
1
,
requires_grad
=
False
),
False
)
self
.
labels_false
.
new_
zero
s
(
len
(
logits_false
),
1
,
requires_grad
=
False
),
False
)
loss
=
(
self
.
loss_true
(
logits_true
,
self
.
labels_true
[:
len
(
logits_true
)])
+
self
.
loss_false
(
logits_false
,
self
.
labels_false
[:
len
(
logits_false
)]))
return
loss
return
self
.
loss_true
(
logits_true
,
self
.
labels_true
[:
len
(
logits_true
)]),
\
self
.
loss_false
(
logits_false
,
self
.
labels_false
[:
len
(
logits_false
)])
class
GeneratorLogitsLoss
(
Module
):
...
...
labml_nn/gan/mnist.py
浏览文件 @
90815b9c
...
...
@@ -18,7 +18,7 @@ plt.rcParams['image.interpolation'] = 'nearest'
plt
.
rcParams
[
'image.cmap'
]
=
'gray'
class
Generator
(
nn
.
Module
):
class
Generator
(
Module
):
def
__init__
(
self
):
super
(
Generator
,
self
).
__init__
()
...
...
@@ -38,7 +38,7 @@ class Generator(nn.Module):
return
x
class
Discriminator
(
nn
.
Module
):
class
Discriminator
(
Module
):
def
__init__
(
self
):
super
(
Discriminator
,
self
).
__init__
()
...
...
@@ -72,6 +72,15 @@ class GANBatchStep(BatchStepProtocol):
self
.
discriminator_optimizer
=
discriminator_optimizer
tracker
.
set_scalar
(
"loss.generator.*"
,
True
)
tracker
.
set_scalar
(
"loss.discriminator.*"
,
True
)
tracker
.
set_image
(
"generated"
,
True
)
def
prepare_for_iteration
(
self
):
if
MODE_STATE
.
is_train
:
self
.
generator
.
train
()
self
.
discriminator
.
train
()
else
:
self
.
generator
.
eval
()
self
.
discriminator
.
eval
()
def
process
(
self
,
batch
:
any
,
state
:
any
):
device
=
self
.
discriminator
.
device
...
...
@@ -79,10 +88,12 @@ class GANBatchStep(BatchStepProtocol):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
with
monit
.
section
(
"generator"
):
latent
=
torch
.
normal
(
0
,
1
,
(
data
.
shape
[
0
],
100
)
,
device
=
device
)
latent
=
torch
.
randn
(
data
.
shape
[
0
],
100
,
device
=
device
)
if
MODE_STATE
.
is_train
:
self
.
generator_optimizer
.
zero_grad
()
logits
=
self
.
discriminator
(
self
.
generator
(
latent
))
generated_images
=
self
.
generator
(
latent
)
# tracker.add('generated', generated_images[0:1])
logits
=
self
.
discriminator
(
generated_images
)
loss
=
self
.
generator_loss
(
logits
)
tracker
.
add
(
"loss.generator."
,
loss
)
if
MODE_STATE
.
is_train
:
...
...
@@ -90,18 +101,21 @@ class GANBatchStep(BatchStepProtocol):
self
.
generator_optimizer
.
step
()
with
monit
.
section
(
"discriminator"
):
latent
=
torch
.
normal
(
0
,
1
,
(
data
.
shape
[
0
],
100
)
,
device
=
device
)
latent
=
torch
.
randn
(
data
.
shape
[
0
],
100
,
device
=
device
)
if
MODE_STATE
.
is_train
:
self
.
discriminator_optimizer
.
zero_grad
()
logits_false
=
self
.
discriminator
(
self
.
generator
(
latent
).
detach
())
logits_true
=
self
.
discriminator
(
data
)
loss
=
self
.
discriminator_loss
(
logits_true
,
logits_false
)
tracker
.
add
(
"loss.generator."
,
loss
)
logits_false
=
self
.
discriminator
(
self
.
generator
(
latent
).
detach
())
loss_true
,
loss_false
=
self
.
discriminator_loss
(
logits_true
,
logits_false
)
loss
=
loss_true
+
loss_false
tracker
.
add
(
"loss.discriminator.true."
,
loss_true
)
tracker
.
add
(
"loss.discriminator.false."
,
loss_false
)
tracker
.
add
(
"loss.discriminator."
,
loss
)
if
MODE_STATE
.
is_train
:
loss
.
backward
()
self
.
discriminator_optimizer
.
step
()
return
{},
None
return
{
'samples'
:
len
(
data
)
},
None
class
Configs
(
MNISTConfigs
,
TrainValidConfigs
):
...
...
@@ -154,7 +168,7 @@ def main():
'generator_optimizer.optimizer'
:
'Adam'
,
'discriminator_optimizer.learning_rate'
:
2.5e-4
,
'discriminator_optimizer.optimizer'
:
'Adam'
},
[
'set_seed'
,
'main'
]
)
'run'
)
with
experiment
.
start
():
conf
.
run
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录