Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Annotated Deep Learning Paper Implementations
提交
a6502de6
A
Annotated Deep Learning Paper Implementations
项目概览
Greenplum
/
Annotated Deep Learning Paper Implementations
大约 1 年 前同步成功
通知
6
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
Annotated Deep Learning Paper Implementations
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
a6502de6
编写于
12月 26, 2020
作者:
V
Varuna Jayasiri
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
shuffle data
上级
716dda5f
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
11 deletion
+24
-11
labml_nn/hypernetworks/experiment.py
labml_nn/hypernetworks/experiment.py
+24
-11
未找到文件。
labml_nn/hypernetworks/experiment.py
浏览文件 @
a6502de6
...
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
...
@@ -6,11 +6,12 @@ from labml import lab, experiment, monit, tracker, logger
from
labml.configs
import
option
from
labml.configs
import
option
from
labml.logger
import
Text
from
labml.logger
import
Text
from
labml.utils.pytorch
import
get_modules
from
labml.utils.pytorch
import
get_modules
from
labml_helpers.datasets.text
import
TextDataset
,
SequentialDataLoader
,
TextFile
Dataset
from
labml_helpers.datasets.text
import
TextDataset
,
TextFileDataset
,
SequentialUnBatched
Dataset
from
labml_helpers.metrics.accuracy
import
Accuracy
from
labml_helpers.metrics.accuracy
import
Accuracy
from
labml_helpers.module
import
Module
from
labml_helpers.module
import
Module
from
labml_helpers.optimizer
import
OptimizerConfigs
from
labml_helpers.optimizer
import
OptimizerConfigs
from
labml_helpers.train_valid
import
SimpleTrainValidConfigs
,
BatchIndex
from
labml_helpers.train_valid
import
SimpleTrainValidConfigs
,
BatchIndex
from
torch.utils.data
import
DataLoader
from
labml_nn.hypernetworks.hyper_lstm
import
HyperLSTM
from
labml_nn.hypernetworks.hyper_lstm
import
HyperLSTM
...
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
...
@@ -48,6 +49,14 @@ class CrossEntropyLoss(Module):
return
self
.
loss
(
outputs
.
view
(
-
1
,
outputs
.
shape
[
-
1
]),
targets
.
view
(
-
1
))
return
self
.
loss
(
outputs
.
view
(
-
1
,
outputs
.
shape
[
-
1
]),
targets
.
view
(
-
1
))
def
transpose_batch
(
batch
):
transposed_data
=
list
(
zip
(
*
batch
))
src
=
torch
.
stack
(
transposed_data
[
0
],
1
)
tgt
=
torch
.
stack
(
transposed_data
[
1
],
1
)
return
src
,
tgt
class
Configs
(
SimpleTrainValidConfigs
):
class
Configs
(
SimpleTrainValidConfigs
):
"""
"""
## Configurations
## Configurations
...
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
...
@@ -78,16 +87,20 @@ class Configs(SimpleTrainValidConfigs):
self
.
optimizer
=
optimizer
self
.
optimizer
=
optimizer
# Create a sequential data loader for training
# Create a sequential data loader for training
self
.
train_loader
=
SequentialDataLoader
(
text
=
self
.
text
.
train
,
self
.
train_loader
=
DataLoader
(
SequentialUnBatchedDataset
(
text
=
self
.
text
.
train
,
dataset
=
self
.
text
,
dataset
=
self
.
text
,
batch_size
=
self
.
batch_size
,
seq_len
=
self
.
seq_len
),
seq_len
=
self
.
seq_len
)
batch_size
=
self
.
batch_size
,
collate_fn
=
transpose_batch
,
shuffle
=
True
)
# Create a sequential data loader for validation
# Create a sequential data loader for validation
self
.
valid_loader
=
SequentialDataLoader
(
text
=
self
.
text
.
valid
,
self
.
valid_loader
=
DataLoader
(
SequentialUnBatchedDataset
(
text
=
self
.
text
.
valid
,
dataset
=
self
.
text
,
dataset
=
self
.
text
,
batch_size
=
self
.
batch_size
,
seq_len
=
self
.
seq_len
),
seq_len
=
self
.
seq_len
)
batch_size
=
self
.
batch_size
,
collate_fn
=
transpose_batch
,
shuffle
=
True
)
self
.
state_modules
=
[
self
.
accuracy
]
self
.
state_modules
=
[
self
.
accuracy
]
...
@@ -186,12 +199,12 @@ def main():
...
@@ -186,12 +199,12 @@ def main():
# A dictionary of configurations to override
# A dictionary of configurations to override
{
'tokenizer'
:
'character'
,
{
'tokenizer'
:
'character'
,
'text'
:
'tiny_shakespeare'
,
'text'
:
'tiny_shakespeare'
,
'optimizer.learning_rate'
:
1
e-4
,
'optimizer.learning_rate'
:
2.5
e-4
,
'seq_len'
:
512
,
'seq_len'
:
512
,
'epochs'
:
128
,
'epochs'
:
128
,
'batch_size'
:
2
,
'batch_size'
:
2
,
'inner_iterations'
:
10
})
'inner_iterations'
:
25
})
# This is needed to initialize models
# This is needed to initialize models
conf
.
n_tokens
=
conf
.
text
.
n_tokens
conf
.
n_tokens
=
conf
.
text
.
n_tokens
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录