Laurence001 / d2l-zh
Commit 0792f783
Author: Aston Zhang
Authored on: July 18, 2021
Parent: 72c58f54

Commit message:

    sync d2l

Showing 3 changed files with 152 additions and 77 deletions (+152, -77)
d2l/mxnet.py        +61  -35
d2l/tensorflow.py   +32  -10
d2l/torch.py        +59  -32
d2l/mxnet.py
@@ -520,17 +520,18 @@ class Vocab:
             reserved_tokens = []
         # Sort according to frequencies
         counter = count_corpus(tokens)
-        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
-                                  reverse=True)
+        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
+                                   reverse=True)
         # The index for the unknown token is 0
-        self.unk, uniq_tokens = 0, ['<unk>'] + reserved_tokens
-        uniq_tokens += [token for token, freq in self.token_freqs
-                        if freq >= min_freq and token not in uniq_tokens]
-        self.idx_to_token, self.token_to_idx = [], dict()
-        for token in uniq_tokens:
-            self.idx_to_token.append(token)
-            self.token_to_idx[token] = len(self.idx_to_token) - 1
+        self.idx_to_token = ['<unk>'] + reserved_tokens
+        self.token_to_idx = {token: idx
+                             for idx, token in enumerate(self.idx_to_token)}
+        for token, freq in self._token_freqs:
+            if freq < min_freq:
+                break
+            if token not in self.token_to_idx:
+                self.idx_to_token.append(token)
+                self.token_to_idx[token] = len(self.idx_to_token) - 1

     def __len__(self):
         return len(self.idx_to_token)
@@ -545,6 +546,14 @@ class Vocab:
             return self.idx_to_token[indices]
         return [self.idx_to_token[index] for index in indices]

+    @property
+    def unk(self):  # Index for the unknown token
+        return 0
+
+    @property
+    def token_freqs(self):  # Index for the unknown token
+        return self._token_freqs
+

 def count_corpus(tokens):
     """Count token frequencies."""
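Editorial note: after this hunk, `unk` and `token_freqs` become read-only properties backed by the new `self._token_freqs`, and tokens below `min_freq` are dropped by breaking out of the frequency-sorted loop rather than by building a `uniq_tokens` list. A minimal sketch of the resulting behavior, assuming the d2l package at or after this commit is installed (the sample tokens below are invented for illustration):

from d2l import mxnet as d2l

# Hypothetical toy corpus, not taken from the commit itself.
tokens = [['the', 'time', 'machine'], ['the', 'the', 'traveller']]
vocab = d2l.Vocab(tokens, min_freq=1, reserved_tokens=['<pad>'])
print(vocab.unk)                # 0; now a property instead of an instance attribute
print(vocab.token_freqs[:2])    # highest-frequency (token, count) pairs, e.g. ('the', 3)
print(vocab.to_tokens([0, 1]))  # ['<unk>', '<pad>']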
@@ -800,6 +809,19 @@ def tokenize_nmt(text, num_examples=None):
     return source, target


+# Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
+def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
+    """Plot the histogram for list length pairs."""
+    d2l.set_figsize()
+    _, _, patches = d2l.plt.hist(
+        [[len(l) for l in xlist], [len(l) for l in ylist]])
+    d2l.plt.xlabel(xlabel)
+    d2l.plt.ylabel(ylabel)
+    for patch in patches[1].patches:
+        patch.set_hatch('/')
+    d2l.plt.legend(legend)
+
+
 # Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
 def truncate_pad(line, num_steps, padding_token):
     """Truncate or pad sequences."""
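`show_list_len_pair_hist` is newly promoted into the saved library here; it overlays two histograms of sequence lengths. A hedged usage sketch with invented source/target token lists (plotting goes through `d2l.plt`, i.e. matplotlib):

from d2l import mxnet as d2l

# Toy tokenized sentence pairs, for illustration only.
source = [['go', '.'], ['i', 'lost', '.'], ['he', 'is', 'calm', '.']]
target = [['va', '!'], ['je', 'suis', 'perdu', '.'], ['il', 'est', 'calme', '.']]
d2l.show_list_len_pair_hist(['source', 'target'], '# tokens per sequence',
                            'count', source, target)
d2l.plt.show()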
@@ -1547,15 +1569,15 @@ def multibox_prior(data, sizes, ratios):
 # Defined in file: ./chapter_computer-vision/anchor.md
 def show_bboxes(axes, bboxes, labels=None, colors=None):
     """Show bounding boxes."""
-    def _make_list(obj, default_values=None):
+    def make_list(obj, default_values=None):
         if obj is None:
             obj = default_values
         elif not isinstance(obj, (list, tuple)):
             obj = [obj]
         return obj

-    labels = _make_list(labels)
-    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
+    labels = make_list(labels)
+    colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])
     for i, bbox in enumerate(bboxes):
         color = colors[i % len(colors)]
         rect = d2l.bbox_to_rect(d2l.numpy(bbox), color)
@@ -1955,44 +1977,44 @@ d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
 def read_ptb():
     data_dir = d2l.download_extract('ptb')
     # Read the training set.
     with open(os.path.join(data_dir, 'ptb.train.txt')) as f:
         raw_text = f.read()
     return [line.split() for line in raw_text.split('\n')]


 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
-def subsampling(sentences, vocab):
-    # Map low frequency words into <unk>
-    sentences = [[vocab.idx_to_token[vocab[tk]] for tk in line]
+def subsample(sentences, vocab):
+    # Exclude unknown tokens '<unk>'
+    sentences = [[token for token in line if vocab[token] != vocab.unk]
                  for line in sentences]
     # Count the frequency for each word
     counter = d2l.count_corpus(sentences)
     num_tokens = sum(counter.values())

-    # Return True if to keep this token during subsampling
+    # Return True if `token` is kept during subsampling
     def keep(token):
         return (random.uniform(0, 1) <
                 math.sqrt(1e-4 / counter[token] * num_tokens))

-    # Now do the subsampling
-    return [[tk for tk in line if keep(tk)] for line in sentences]
+    return ([[token for token in line if keep(token)] for line in sentences],
+            counter)


 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
 def get_centers_and_contexts(corpus, max_window_size):
     centers, contexts = [], []
     for line in corpus:
-        # Each sentence needs at least 2 words to form a "central target word
-        # - context word" pair
+        # To form a "center word--context word" pair, each sentence needs to
+        # have at least 2 words
         if len(line) < 2:
             continue
         centers += line
-        for i in range(len(line)):  # Context window centered at i
+        for i in range(len(line)):  # Context window centered at `i`
             window_size = random.randint(1, max_window_size)
             indices = list(range(max(0, i - window_size),
                                  min(len(line), i + 1 + window_size)))
-            # Exclude the central target word from the context words
+            # Exclude the center word from the context words
             indices.remove(i)
             contexts.append([line[idx] for idx in indices])
     return centers, contexts
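The renamed `subsample` now returns the subsampled sentences together with the token `counter`, which callers forward to `get_negatives`, so call sites unpack two values instead of one. A rough sketch of the updated call pattern, assuming the post-commit d2l.mxnet module (running it downloads the PTB data):

from d2l import mxnet as d2l

sentences = d2l.read_ptb()                             # list of tokenized lines
vocab = d2l.Vocab(sentences, min_freq=10)
subsampled, counter = d2l.subsample(sentences, vocab)  # was: subsampled = subsampling(...)
corpus = [vocab[line] for line in subsampled]
all_centers, all_contexts = d2l.get_centers_and_contexts(corpus, max_window_size=5)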
@@ -2000,15 +2022,17 @@ def get_centers_and_contexts(corpus, max_window_size):
 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
 class RandomGenerator:
-    """Draw a random int in [0, n] according to n sampling weights."""
+    """Randomly draw among {1, ..., n} according to n sampling weights."""
     def __init__(self, sampling_weights):
-        self.population = list(range(len(sampling_weights)))
+        # Exclude
+        self.population = list(range(1, len(sampling_weights) + 1))
         self.sampling_weights = sampling_weights
         self.candidates = []
         self.i = 0

     def draw(self):
         if self.i == len(self.candidates):
             # Cache `k` random sampling results
             self.candidates = random.choices(
                 self.population, self.sampling_weights, k=10000)
             self.i = 0
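With the population shifted to {1, ..., n}, draws from `RandomGenerator` line up with vocabulary indices 1..n, so index 0 (`<unk>`) can never be sampled. A toy sketch of the class in isolation, with made-up weights:

from d2l import mxnet as d2l

generator = d2l.RandomGenerator([2, 3, 4])    # weights for indices 1, 2, 3
print([generator.draw() for _ in range(10)])  # values fall in {1, 2, 3}, never 0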
@@ -2017,9 +2041,11 @@ class RandomGenerator:
 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
-def get_negatives(all_contexts, corpus, K):
-    counter = d2l.count_corpus(corpus)
-    sampling_weights = [count**0.75 for count in counter.values()]
+def get_negatives(all_contexts, vocab, counter, K):
+    # Sampling weights for words with indices 1, 2, ... (index 0 is the
+    # excluded unknown token) in the vocabulary
+    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
+                        for i in range(1, len(vocab))]
     all_negatives, generator = [], RandomGenerator(sampling_weights)
     for contexts in all_contexts:
         negatives = []
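The negative sampler now reuses the `counter` produced by `subsample` instead of re-counting the corpus, and it builds one weight per non-`<unk>` vocabulary index so the weights align with `RandomGenerator`'s 1-based population. The weight construction alone behaves roughly like this self-contained snippet (vocabulary and counts are invented):

import collections

idx_to_token = ['<unk>', 'the', 'time', 'machine']  # toy vocabulary, index 0 is <unk>
counter = collections.Counter({'the': 30, 'time': 10, 'machine': 5})
sampling_weights = [counter[idx_to_token[i]]**0.75
                    for i in range(1, len(idx_to_token))]
print(sampling_weights)  # one weight per vocabulary index 1, 2, 3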
@@ -2049,19 +2075,19 @@ def batchify(data):
 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
 def load_data_ptb(batch_size, max_window_size, num_noise_words):
-    num_workers = d2l.get_dataloader_workers()
     sentences = read_ptb()
     vocab = d2l.Vocab(sentences, min_freq=10)
-    subsampled = subsampling(sentences, vocab)
+    subsampled, counter = subsample(sentences, vocab)
     corpus = [vocab[line] for line in subsampled]
     all_centers, all_contexts = get_centers_and_contexts(
         corpus, max_window_size)
-    all_negatives = get_negatives(all_contexts, corpus, num_noise_words)
+    all_negatives = get_negatives(all_contexts, vocab, counter,
+                                  num_noise_words)
     dataset = gluon.data.ArrayDataset(
         all_centers, all_contexts, all_negatives)
-    data_iter = gluon.data.DataLoader(
-        dataset, batch_size, shuffle=True, batchify_fn=batchify,
-        num_workers=num_workers)
+    data_iter = gluon.data.DataLoader(
+        dataset, batch_size, shuffle=True, batchify_fn=batchify,
+        num_workers=d2l.get_dataloader_workers())
     return data_iter, vocab
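The public entry point keeps its signature; only its internals change to thread `counter` through. A hedged end-to-end sketch (hyperparameters follow the book's usual defaults rather than anything mandated by this commit, and running it requires MXNet plus a download of the PTB data):

from d2l import mxnet as d2l

data_iter, vocab = d2l.load_data_ptb(batch_size=512, max_window_size=5,
                                     num_noise_words=5)
for batch in data_iter:
    centers, contexts_negatives, masks, labels = batch  # as produced by batchify
    print(centers.shape, contexts_negatives.shape)
    break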
d2l/tensorflow.py
@@ -544,17 +544,18 @@ class Vocab:
             reserved_tokens = []
         # Sort according to frequencies
         counter = count_corpus(tokens)
-        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
-                                  reverse=True)
+        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
+                                   reverse=True)
         # The index for the unknown token is 0
-        self.unk, uniq_tokens = 0, ['<unk>'] + reserved_tokens
-        uniq_tokens += [token for token, freq in self.token_freqs
-                        if freq >= min_freq and token not in uniq_tokens]
-        self.idx_to_token, self.token_to_idx = [], dict()
-        for token in uniq_tokens:
-            self.idx_to_token.append(token)
-            self.token_to_idx[token] = len(self.idx_to_token) - 1
+        self.idx_to_token = ['<unk>'] + reserved_tokens
+        self.token_to_idx = {token: idx
+                             for idx, token in enumerate(self.idx_to_token)}
+        for token, freq in self._token_freqs:
+            if freq < min_freq:
+                break
+            if token not in self.token_to_idx:
+                self.idx_to_token.append(token)
+                self.token_to_idx[token] = len(self.idx_to_token) - 1

     def __len__(self):
         return len(self.idx_to_token)
@@ -569,6 +570,14 @@ class Vocab:
             return self.idx_to_token[indices]
         return [self.idx_to_token[index] for index in indices]

+    @property
+    def unk(self):  # Index for the unknown token
+        return 0
+
+    @property
+    def token_freqs(self):  # Index for the unknown token
+        return self._token_freqs
+

 def count_corpus(tokens):
     """Count token frequencies."""
@@ -827,6 +836,19 @@ def tokenize_nmt(text, num_examples=None):
     return source, target


+# Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
+def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
+    """Plot the histogram for list length pairs."""
+    d2l.set_figsize()
+    _, _, patches = d2l.plt.hist(
+        [[len(l) for l in xlist], [len(l) for l in ylist]])
+    d2l.plt.xlabel(xlabel)
+    d2l.plt.ylabel(ylabel)
+    for patch in patches[1].patches:
+        patch.set_hatch('/')
+    d2l.plt.legend(legend)
+
+
 # Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
 def truncate_pad(line, num_steps, padding_token):
     """Truncate or pad sequences."""
d2l/torch.py
@@ -451,7 +451,7 @@ def corr2d(X, K):
 def evaluate_accuracy_gpu(net, data_iter, device=None):
     """Compute the accuracy for a model on a dataset using a GPU."""
-    if isinstance(net, torch.nn.Module):
+    if isinstance(net, nn.Module):
         net.eval()  # Set the model to evaluation mode
         if not device:
             device = next(iter(net.parameters())).device
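This torch-only change just swaps the fully qualified `torch.nn.Module` check for the module-level `nn` alias; behavior is unchanged. A minimal sketch of calling it, assuming PyTorch and the post-commit d2l.torch module (model and data are throwaway toys):

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(nn.Flatten(), nn.Linear(4, 2))  # toy model
X = torch.randn(8, 2, 2)
y = torch.randint(0, 2, (8,))
acc = d2l.evaluate_accuracy_gpu(net, [(X, y)])      # device inferred from net's parameters
print(acc)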
@@ -573,17 +573,18 @@ class Vocab:
             reserved_tokens = []
         # Sort according to frequencies
         counter = count_corpus(tokens)
-        self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
-                                  reverse=True)
+        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
+                                   reverse=True)
         # The index for the unknown token is 0
-        self.unk, uniq_tokens = 0, ['<unk>'] + reserved_tokens
-        uniq_tokens += [token for token, freq in self.token_freqs
-                        if freq >= min_freq and token not in uniq_tokens]
-        self.idx_to_token, self.token_to_idx = [], dict()
-        for token in uniq_tokens:
-            self.idx_to_token.append(token)
-            self.token_to_idx[token] = len(self.idx_to_token) - 1
+        self.idx_to_token = ['<unk>'] + reserved_tokens
+        self.token_to_idx = {token: idx
+                             for idx, token in enumerate(self.idx_to_token)}
+        for token, freq in self._token_freqs:
+            if freq < min_freq:
+                break
+            if token not in self.token_to_idx:
+                self.idx_to_token.append(token)
+                self.token_to_idx[token] = len(self.idx_to_token) - 1

     def __len__(self):
         return len(self.idx_to_token)
@@ -598,6 +599,14 @@ class Vocab:
             return self.idx_to_token[indices]
         return [self.idx_to_token[index] for index in indices]

+    @property
+    def unk(self):  # Index for the unknown token
+        return 0
+
+    @property
+    def token_freqs(self):  # Index for the unknown token
+        return self._token_freqs
+

 def count_corpus(tokens):
     """Count token frequencies."""
@@ -883,6 +892,19 @@ def tokenize_nmt(text, num_examples=None):
     return source, target


+# Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
+def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
+    """Plot the histogram for list length pairs."""
+    d2l.set_figsize()
+    _, _, patches = d2l.plt.hist(
+        [[len(l) for l in xlist], [len(l) for l in ylist]])
+    d2l.plt.xlabel(xlabel)
+    d2l.plt.ylabel(ylabel)
+    for patch in patches[1].patches:
+        patch.set_hatch('/')
+    d2l.plt.legend(legend)
+
+
 # Defined in file: ./chapter_recurrent-modern/machine-translation-and-dataset.md
 def truncate_pad(line, num_steps, padding_token):
     """Truncate or pad sequences."""
@@ -1660,15 +1682,15 @@ def multibox_prior(data, sizes, ratios):
 # Defined in file: ./chapter_computer-vision/anchor.md
 def show_bboxes(axes, bboxes, labels=None, colors=None):
     """Show bounding boxes."""
-    def _make_list(obj, default_values=None):
+    def make_list(obj, default_values=None):
         if obj is None:
             obj = default_values
         elif not isinstance(obj, (list, tuple)):
             obj = [obj]
         return obj

-    labels = _make_list(labels)
-    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
+    labels = make_list(labels)
+    colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])
     for i, bbox in enumerate(bboxes):
         color = colors[i % len(colors)]
         rect = d2l.bbox_to_rect(d2l.numpy(bbox), color)
@@ -2072,44 +2094,44 @@ d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
 def read_ptb():
     data_dir = d2l.download_extract('ptb')
     # Read the training set.
     with open(os.path.join(data_dir, 'ptb.train.txt')) as f:
         raw_text = f.read()
     return [line.split() for line in raw_text.split('\n')]


 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
-def subsampling(sentences, vocab):
-    # Map low frequency words into <unk>
-    sentences = [[vocab.idx_to_token[vocab[tk]] for tk in line]
+def subsample(sentences, vocab):
+    # Exclude unknown tokens '<unk>'
+    sentences = [[token for token in line if vocab[token] != vocab.unk]
                  for line in sentences]
     # Count the frequency for each word
     counter = d2l.count_corpus(sentences)
     num_tokens = sum(counter.values())

-    # Return True if to keep this token during subsampling
+    # Return True if `token` is kept during subsampling
     def keep(token):
         return (random.uniform(0, 1) <
                 math.sqrt(1e-4 / counter[token] * num_tokens))

-    # Now do the subsampling
-    return [[tk for tk in line if keep(tk)] for line in sentences]
+    return ([[token for token in line if keep(token)] for line in sentences],
+            counter)


 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
 def get_centers_and_contexts(corpus, max_window_size):
     centers, contexts = [], []
     for line in corpus:
-        # Each sentence needs at least 2 words to form a "central target word
-        # - context word" pair
+        # To form a "center word--context word" pair, each sentence needs to
+        # have at least 2 words
         if len(line) < 2:
             continue
         centers += line
-        for i in range(len(line)):  # Context window centered at i
+        for i in range(len(line)):  # Context window centered at `i`
             window_size = random.randint(1, max_window_size)
             indices = list(range(max(0, i - window_size),
                                  min(len(line), i + 1 + window_size)))
-            # Exclude the central target word from the context words
+            # Exclude the center word from the context words
             indices.remove(i)
             contexts.append([line[idx] for idx in indices])
     return centers, contexts
@@ -2117,15 +2139,17 @@ def get_centers_and_contexts(corpus, max_window_size):
 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
 class RandomGenerator:
-    """Draw a random int in [0, n] according to n sampling weights."""
+    """Randomly draw among {1, ..., n} according to n sampling weights."""
     def __init__(self, sampling_weights):
-        self.population = list(range(len(sampling_weights)))
+        # Exclude
+        self.population = list(range(1, len(sampling_weights) + 1))
         self.sampling_weights = sampling_weights
         self.candidates = []
         self.i = 0

     def draw(self):
         if self.i == len(self.candidates):
             # Cache `k` random sampling results
             self.candidates = random.choices(
                 self.population, self.sampling_weights, k=10000)
             self.i = 0
@@ -2134,9 +2158,11 @@ class RandomGenerator:
 # Defined in file: ./chapter_natural-language-processing-pretraining/word-embedding-dataset.md
-def get_negatives(all_contexts, corpus, K):
-    counter = d2l.count_corpus(corpus)
-    sampling_weights = [count**0.75 for count in counter.values()]
+def get_negatives(all_contexts, vocab, counter, K):
+    # Sampling weights for words with indices 1, 2, ... (index 0 is the
+    # excluded unknown token) in the vocabulary
+    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
+                        for i in range(1, len(vocab))]
     all_negatives, generator = [], RandomGenerator(sampling_weights)
     for contexts in all_contexts:
         negatives = []
@@ -2169,11 +2195,12 @@ def load_data_ptb(batch_size, max_window_size, num_noise_words):
     num_workers = d2l.get_dataloader_workers()
     sentences = read_ptb()
     vocab = d2l.Vocab(sentences, min_freq=10)
-    subsampled = subsampling(sentences, vocab)
+    subsampled, counter = subsample(sentences, vocab)
     corpus = [vocab[line] for line in subsampled]
     all_centers, all_contexts = get_centers_and_contexts(
         corpus, max_window_size)
-    all_negatives = get_negatives(all_contexts, corpus, num_noise_words)
+    all_negatives = get_negatives(all_contexts, vocab, counter,
+                                  num_noise_words)

     class PTBDataset(torch.utils.data.Dataset):
         def __init__(self, centers, contexts, negatives):