Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Annotated Deep Learning Paper Implementations
提交
423f5c62
A
Annotated Deep Learning Paper Implementations
项目概览
Greenplum
/
Annotated Deep Learning Paper Implementations
10 个月 前同步成功
通知
6
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
A
Annotated Deep Learning Paper Implementations
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
423f5c62
编写于
11月 08, 2020
作者:
V
Varuna Jayasiri
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
✨
knn
上级
b7b281a3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
168 additions
and
0 deletions
+168
-0
labml_nn/transformers/knn/build_index.py
labml_nn/transformers/knn/build_index.py
+82
-0
labml_nn/transformers/knn/eval_knn.py
labml_nn/transformers/knn/eval_knn.py
+86
-0
未找到文件。
labml_nn/transformers/knn/build_index.py
0 → 100644
浏览文件 @
423f5c62
from
typing
import
Optional
import
faiss
import
numpy
as
np
import
torch
from
labml
import
experiment
,
monit
,
lab
from
labml.utils.pytorch
import
get_modules
from
labml_nn.transformers.knn.train_model
import
Configs
def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
    """
    Load a saved training run in evaluation mode and return its configurations.

    :param run_uuid: UUID of the training run to load
    :param checkpoint: checkpoint index to restore; the latest when ``None``
    :return: the populated ``Configs`` object with models loaded
    """
    configs = Configs()

    # Start from the configuration values saved with the training run,
    # but turn on saving of feed-forward inputs so keys can be collected later.
    saved_configs = experiment.load_configs(run_uuid)
    saved_configs['is_save_ff_input'] = True

    # Set up the experiment in evaluation mode and restore model weights
    experiment.evaluate()
    experiment.configs(configs, saved_configs, 'run')
    experiment.add_pytorch_models(get_modules(configs))
    experiment.load(run_uuid, checkpoint)
    experiment.start()

    return configs
def gather_keys(conf: Configs):
    """
    Run the model over the training data and store every feed-forward input
    ("key") together with the token that followed it ("value") in
    memory-mapped files ``keys.npy`` / ``vals.npy`` under the lab data path.

    :param conf: loaded experiment configurations (model, data loader, device)
    """
    d_model = conf.transformer.d_model
    data_loader = conf.trainer.data_loader
    # One key per token position across the whole dataset, minus the final
    # position (which has no following token to act as the value).
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

    # Memory-mapped stores: the key matrix is too large to hold in RAM.
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32,
                           mode='w+', shape=(n_keys, d_model))
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # alias it used to stand for, so the on-disk layout is unchanged.
    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=int,
                           mode='w+', shape=(n_keys, 1))

    # Number of rows written so far
    added = 0
    with torch.no_grad():
        for _, batch in monit.enum("Collect data", data_loader, is_children_silent=True):
            # Target tokens (the "values") as a column vector
            vals = batch[1].view(-1, 1)
            data = batch[0].to(conf.device)
            # Forward pass only for the side effect of populating `ff_input`
            # (enabled via `is_save_ff_input` in `load_experiment`)
            _ = conf.model(data)
            # Saved feed-forward layer inputs, flattened to (tokens, d_model).
            # NOTE: keys are stored unnormalized; normalization (if any) is
            # applied at query time instead, e.g.
            #   keys / torch.sqrt((keys ** 2).sum(-1, keepdims=True) + 1e-10)
            keys = conf.model.ff_input.view(-1, d_model)
            keys_store[added: added + keys.shape[0]] = keys.cpu()
            vals_store[added: added + keys.shape[0]] = vals
            added += keys.shape[0]
def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64,
                n_probe: int = 8, n_train: int = 200_000):
    """
    Build a FAISS IVF-PQ index over the keys collected by ``gather_keys``
    and save it to ``faiss.index`` in the lab data path.

    :param conf: loaded experiment configurations
    :param n_centeroids: number of coarse (inverted-file) clusters
    :param code_size: product-quantization code size in bytes
    :param n_probe: number of clusters probed at query time
    :param n_train: number of keys sampled to train the quantizers
    """
    d_model = conf.transformer.d_model
    data_loader = conf.trainer.data_loader
    # Total number of stored keys (one per token position, except the last)
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

    # Coarse L2 quantizer for the inverted file, plus product quantization
    # with 8 bits per sub-vector code
    quantizer = faiss.IndexFlatL2(d_model)
    index = faiss.IndexIVFPQ(quantizer, d_model, n_centeroids, code_size, 8)
    index.nprobe = n_probe

    # Re-open the key store read-only
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32,
                           mode='r', shape=(n_keys, d_model))

    # Train the quantizers on a random subset of the keys
    train_sample = np.random.choice(np.arange(n_keys),
                                    size=[min(n_train, n_keys)], replace=False)
    with monit.section('Train index'):
        index.train(keys_store[train_sample])

    # Add all keys in chunks, keeping each key's row number as its FAISS id
    chunk_size = 1024
    for start in monit.iterate('Index', range(0, n_keys, chunk_size)):
        end = min(start + chunk_size, n_keys)
        index.add_with_ids(keys_store[start:end], np.arange(start, end))

    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
def main(run_uuid: str = '4984b85c20bf11eb877a69c1a03717cd'):
    """
    Load a trained experiment, collect keys, and build the FAISS index.

    :param run_uuid: training run to index; defaults to the original run,
        generalized so other runs can be indexed without editing this file.
    """
    conf = load_experiment(run_uuid)
    # Inference only: disable dropout etc.
    conf.model.eval()
    gather_keys(conf)
    build_index(conf)


if __name__ == '__main__':
    main()
labml_nn/transformers/knn/eval_knn.py
0 → 100644
浏览文件 @
423f5c62
from
typing
import
Optional
,
List
import
faiss
import
numpy
as
np
import
torch
from
torch.nn
import
functional
as
F
from
labml
import
monit
,
tracker
,
lab
from
labml.logger
import
inspect
from
labml_nn.transformers.knn.train_model
import
Configs
def knn(queries: torch.Tensor, index: faiss.IndexFlatL2, keys_store: np.ndarray,
        vals_store: np.ndarray, n_tokens: int, n_neighbors: int = 10):
    """
    Compute kNN-based logits: look up the nearest stored keys for each query
    and accumulate normalized dot-product similarities into per-token logits.

    :param queries: query vectors; assumed shape (seq, batch, d_model) —
        the final reshape restores the first two dimensions (TODO confirm)
    :param index: FAISS index over the stored keys
    :param keys_store: stored key vectors, row-indexable by FAISS ids
    :param vals_store: stored values (following tokens), shape (n_keys, 1)
    :param n_tokens: vocabulary size (width of the returned logits)
    :param n_neighbors: neighbors retrieved per query
        (generalizes the previously hard-coded ``10``)
    :return: logits of shape (seq, batch, n_tokens)
    """
    queries_shape = queries.shape
    # Flatten to (n_queries, d_model) for the FAISS search
    queries = queries.view(-1, queries_shape[-1])

    # Nearest-neighbor search; `idx` indexes rows of the key/value stores
    distance, idx = index.search(queries.numpy(), n_neighbors)

    # Gather the matched keys and their following tokens
    keys_found = queries.new_tensor(keys_store[idx])
    vals_found = torch.tensor(vals_store[idx]).squeeze(-1)

    # Cosine-style similarity: normalize both sides (eps avoids divide-by-zero)
    keys_found_n = keys_found / torch.sqrt((keys_found ** 2).sum(-1, keepdims=True) + 1e-10)
    queries_n = queries / torch.sqrt((queries ** 2).sum(-1, keepdims=True) + 1e-10)
    dot_prod = (keys_found_n * queries_n.unsqueeze(1)).sum(-1)

    # Accumulate each neighbor's similarity into the logit of its token
    logits_token = dot_prod.new_zeros(queries.shape[0], n_tokens)
    _ = logits_token.scatter_(dim=1, index=vals_found, src=dot_prod, reduce='add')

    # Restore the leading (seq, batch) dimensions
    return logits_token.reshape(queries_shape[0], queries_shape[1], -1)
def validation_loss(coef: List[float], last_n: Optional[int], conf: Configs,
                    index: faiss.IndexFlatL2, keys_store: np.ndarray, vals_store: np.ndarray):
    """
    Evaluate validation loss for a range of interpolation coefficients,
    where the prediction is ``c * knn_logits + (1 - c) * model_logits``.

    :param coef: interpolation coefficients to sweep
    :param last_n: when given, score only the last ``last_n`` positions
    :param conf: loaded experiment configurations
    :param index: FAISS index over stored keys
    :param keys_store: stored key vectors
    :param vals_store: stored values (following tokens)
    :return: per-coefficient lists of (loss * sample count), and the
        per-batch sample counts
    """
    validator = conf.validator
    # One list of weighted losses per coefficient
    losses = [[] for _ in coef]
    # Number of scored samples in each batch
    samples = []

    with torch.no_grad(), tracker.namespace('valid'):
        for _, batch in monit.enum("Validation", validator.data_loader, is_children_silent=True):
            data = batch[0].to(conf.device)
            tgt = batch[1].to(conf.device)

            # Model logits; the forward pass also records `ff_input`,
            # which is what the kNN lookup queries
            res = conf.model(data)
            res_knn = knn(conf.model.ff_input.cpu(), index, keys_store,
                          vals_store, conf.n_tokens).to(conf.device)

            # Optionally score only the tail of each sequence
            if last_n:
                res = res[-last_n:]
                res_knn = res_knn[-last_n:]
                tgt = tgt[-last_n:]

            n_samples = res.shape[0] * data.shape[1]
            samples.append(n_samples)

            for coef_idx, c in enumerate(coef):
                # Interpolate kNN and model predictions, weight by batch size
                loss = conf.loss_func(res_knn * c + (1 - c) * res, tgt)
                losses[coef_idx].append(loss * n_samples)

    # Report the weighted mean loss per coefficient
    inspect({c: np.sum(losses[coef_idx]) / np.sum(samples)
             for coef_idx, c in enumerate(coef)})

    return losses, samples
def load_index(conf: Configs, n_probe: int = 8):
    """
    Load the saved FAISS index and re-open the key/value memmaps read-only.

    :param conf: loaded experiment configurations
    :param n_probe: number of clusters probed at query time
    :return: ``(index, keys_store, vals_store)``
    """
    d_model = conf.transformer.d_model
    data_loader = conf.trainer.data_loader
    # Must match the shape used when the stores were written by `gather_keys`
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1

    with monit.section('Load index'):
        index = faiss.read_index(str(lab.get_data_path() / 'faiss.index'))
    index.nprobe = n_probe

    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32,
                           mode='r', shape=(n_keys, d_model))
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # alias it stood for, so the on-disk layout read here is unchanged.
    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=int,
                           mode='r', shape=(n_keys, 1))

    return index, keys_store, vals_store
def main(run_uuid: str = '4984b85c20bf11eb877a69c1a03717cd'):
    """
    Evaluate kNN-augmented validation loss for a saved run.

    :param run_uuid: training run to evaluate; defaults to the original run,
        generalized so other runs can be evaluated without editing this file.
    """
    from labml_nn.transformers.knn.build_index import load_experiment

    conf = load_experiment(run_uuid)
    # Inference only: disable dropout etc.
    conf.model.eval()

    index, keys_store, vals_store = load_index(conf)
    # Sweep interpolation coefficients 0.0, 0.05, ..., 0.45
    validation_loss([i / 20 for i in range(10)], None, conf,
                    index, keys_store, vals_store)


if __name__ == '__main__':
    main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录