Commit 6ab3fe84
Authored on Apr 07, 2019 by ShusenTang

update d2lzh_pytorch

Parent: ac21562b
1 changed file with 96 additions and 1 deletion

code/d2lzh_pytorch/utils.py (+96, -1)
@@ -348,4 +348,99 @@ def data_iter_consecutive(corpus_indices, batch_size, num_steps, device=None):
        i = i * num_steps
        X = indices[:, i: i + num_steps]
        Y = indices[:, i + 1: i + num_steps + 1]
        yield X, Y
# ###################################### 6.4 ######################################
def one_hot(x, n_class, dtype=torch.float32):
    # X shape: (batch), output shape: (batch, n_class)
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res
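`one_hot` can be checked in isolation. A minimal sketch, with the helper repeated so the snippet runs on its own:

```python
import torch

def one_hot(x, n_class, dtype=torch.float32):
    # Row i gets a 1 in column x[i] and zeros elsewhere
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

out = one_hot(torch.tensor([0, 2]), 4)
print(out)
# tensor([[1., 0., 0., 0.],
#         [0., 0., 1., 0.]])
```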
def to_onehot(X, n_class):
    # This function is saved in the d2lzh package for later use
    # X shape: (batch, seq_len), output: seq_len elements of (batch, n_class)
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]
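A quick shape check for `to_onehot`, again with the helpers repeated so the snippet is self-contained:

```python
import torch

def one_hot(x, n_class, dtype=torch.float32):
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

def to_onehot(X, n_class):
    # One (batch, n_class) tensor per time step
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

X = torch.arange(10).view(2, 5)       # batch_size=2, seq_len=5
inputs = to_onehot(X, 10)
print(len(inputs), inputs[0].shape)   # 5 torch.Size([2, 10])
```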
def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state, num_hiddens,
                vocab_size, device, idx_to_char, char_to_idx):
    state = init_rnn_state(1, num_hiddens, device)
    output = [char_to_idx[prefix[0]]]
    for t in range(num_chars + len(prefix) - 1):
        # Use the output of the previous time step as the current input
        X = to_onehot(torch.tensor([[output[-1]]], device=device), vocab_size)
        # Compute the output and update the hidden state
        (Y, state) = rnn(X, state, params)
        # The next input is either the next character in the prefix
        # or the current best prediction
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])
        else:
            output.append(int(Y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])
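The decoding loop can be exercised without trained parameters by plugging in a stand-in model. In this sketch `toy_rnn` and `init_state` are made-up stand-ins (not part of the book's code) that deterministically predict the next character of a 3-letter vocabulary; `predict_rnn` and its helpers are repeated verbatim:

```python
import torch

def one_hot(x, n_class, dtype=torch.float32):
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1, 1), 1)
    return res

def to_onehot(X, n_class):
    return [one_hot(X[:, i], n_class) for i in range(X.shape[1])]

def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state, num_hiddens,
                vocab_size, device, idx_to_char, char_to_idx):
    state = init_rnn_state(1, num_hiddens, device)
    output = [char_to_idx[prefix[0]]]
    for t in range(num_chars + len(prefix) - 1):
        X = to_onehot(torch.tensor([[output[-1]]], device=device), vocab_size)
        (Y, state) = rnn(X, state, params)
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])
        else:
            output.append(int(Y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])

vocab = ['a', 'b', 'c']
char_to_idx = {c: i for i, c in enumerate(vocab)}
idx_to_char = dict(enumerate(vocab))

def init_state(batch_size, num_hiddens, device):
    # Hypothetical state initializer: a single zero tensor
    return (torch.zeros(batch_size, num_hiddens, device=device),)

def toy_rnn(inputs, state, params):
    # Hypothetical "model": always predict (current index + 1) mod vocab_size
    X = inputs[0]
    nxt = (int(X.argmax(dim=1).item()) + 1) % X.shape[1]
    Y = torch.zeros_like(X)
    Y[0, nxt] = 1.0
    return [Y], state

result = predict_rnn('a', 3, toy_rnn, None, init_state, 2,
                     len(vocab), 'cpu', idx_to_char, char_to_idx)
print(result)  # abca
```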
def grad_clipping(params, theta, device):
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)
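The effect of `grad_clipping` can be seen on a single hand-built gradient. The function is repeated so the snippet runs on its own; the tensor values are made up for illustration:

```python
import torch

def grad_clipping(params, theta, device):
    # Rescale all gradients so their global L2 norm is at most theta
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)

p = torch.zeros(4, requires_grad=True)
p.grad = torch.full((4,), 3.0)        # global grad norm = sqrt(4 * 9) = 6
grad_clipping([p], 3.0, 'cpu')
print(p.grad)  # tensor([1.5000, 1.5000, 1.5000, 1.5000])
```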
# This function is saved in the d2lzh package for later use
def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, is_random_iter, num_epochs, num_steps,
                          lr, clipping_theta, batch_size, pred_period,
                          pred_len, prefixes):
    if is_random_iter:
        data_iter_fn = data_iter_random
    else:
        data_iter_fn = data_iter_consecutive
    params = get_params()
    loss = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        if not is_random_iter:
            # With consecutive sampling, initialize the hidden state
            # at the beginning of each epoch
            state = init_rnn_state(batch_size, num_hiddens, device)
        l_sum, n, start = 0.0, 0, time.time()
        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)
        for X, Y in data_iter:
            if is_random_iter:
                # With random sampling, initialize the hidden state
                # before every mini-batch update
                state = init_rnn_state(batch_size, num_hiddens, device)
            else:
                # Otherwise, detach the hidden state from the computation graph
                for s in state:
                    s.detach_()

            inputs = to_onehot(X, vocab_size)
            # outputs is num_steps matrices of shape (batch_size, vocab_size)
            (outputs, state) = rnn(inputs, state, params)
            # After concatenation the shape is (num_steps * batch_size, vocab_size)
            outputs = torch.cat(outputs, dim=0)
            # Y has shape (batch_size, num_steps); transpose it and flatten it
            # into a vector of length batch_size * num_steps so that it lines up
            # with the rows of outputs
            y = torch.transpose(Y, 0, 1).contiguous().view(-1)
            # Use the cross-entropy loss to compute the average classification error
            l = loss(outputs, y.long())

            # Zero the gradients
            if params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward()
            grad_clipping(params, clipping_theta, device)  # clip the gradients
            # The loss is already a mean, so the gradients need no further averaging
            sgd(params, lr, 1)
            l_sum += l.item() * y.shape[0]
            n += y.shape[0]

        if (epoch + 1) % pred_period == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn(prefix, pred_len, rnn, params,
                                        init_rnn_state, num_hiddens, vocab_size,
                                        device, idx_to_char, char_to_idx))
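The transpose-then-flatten step in the training loop deserves a small check: `torch.cat(outputs, dim=0)` stacks all batch rows of time step 0 first, then time step 1, and so on, so the labels must be reordered the same way before computing the loss. A sketch with made-up label values:

```python
import torch

Y = torch.tensor([[10, 11],
                  [20, 21]])          # (batch_size=2, num_steps=2)
# Flattening Y directly would give sample-major order: [10, 11, 20, 21].
# Transposing first yields time-major order, matching torch.cat(outputs, dim=0):
y = torch.transpose(Y, 0, 1).contiguous().view(-1)
print(y)  # tensor([10, 20, 11, 21])
```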