提交 81f6b55a 编写于 作者: V Varuna Jayasiri

📚 knn-lm index

上级 2f937806
......@@ -21,4 +21,9 @@ So to run $k$NN-LM we need to:
* [Build an index](build_index.html) of $\big(f(c_i), w_i\big)$
* [Evaluate $k$NN-LM](eval_knn.html) using $k$NN search on $\big(f(c_i), w_i\big)$
with $f(c_t)$
This experiment uses a small dataset so that we can run this without using up a few hundred gigabytes
of disk space for the index.
The official implementation of $k$NN-LM can be found [here](https://github.com/urvashik/knnlm).
"""
"""
# Index $\big(f(c_i), w_i\big)$
We store $f(c_i)$ and $w_i$ in memory mapped numpy arrays.
We find $f(c_i)$ nearest to $f(c_t)$ using [FAISS](https://github.com/facebookresearch/faiss).
FAISS indexes $\big(f(c_i), i\big)$ and we query it with $f(c_t)$.
"""
from typing import Optional
import faiss
......@@ -10,71 +18,130 @@ from labml_nn.transformers.knn.train_model import Configs
def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
    """
    Restore a previously saved experiment from [train model](train_model.html).

    `run_uuid` identifies the run to load and `checkpoint` is the checkpoint
    passed on to `experiment.load`.
    Returns the `Configs` object after the experiment has been started.
    """

    # Fresh configurations object to be populated from the saved run
    configs = Configs()
    # Configuration values that were used by the saved run
    saved_overrides = experiment.load_configs(run_uuid)
    # Ask the model to keep the inputs to the feed forward layer, $f(c_i)$
    saved_overrides['is_save_ff_input'] = True

    # Evaluation only; i.e. nothing is tracked or saved
    experiment.evaluate()
    # Apply the saved values on top of the fresh configurations
    experiment.configs(configs, saved_overrides, 'run')
    # Register the models so that they can be saved/loaded
    experiment.add_pytorch_models(get_modules(configs))
    # Point the experiment at the run (and checkpoint) to load from
    experiment.load(run_uuid, checkpoint)
    # Starting the experiment is what actually loads the models
    experiment.start()

    return configs
def gather_keys(conf: Configs):
    """
    ## Gather $\big(f(c_i), w_i\big)$ and save them in numpy arrays

    Runs the model over the training data loader and writes the feed forward
    layer inputs $f(c_i)$ to `keys.npy` and the target labels $w_i$ to
    `vals.npy`, both as memory mapped numpy arrays in the lab data path.

    *Note that these numpy arrays will take up a lot of space (even a few hundred gigabytes)
    depending on the size of your dataset*.
    """

    # Dimensions of $f(c_i)$
    d_model = conf.transformer.d_model
    # Training data loader
    data_loader = conf.trainer.data_loader
    # Number of contexts; i.e. number of tokens in the training data minus one.
    # $\big(f(c_i), w_i\big)$ for $i \in [2, T]$
    n_keys = data_loader.data.shape[0] * data_loader.data.shape[1] - 1
    # Numpy array for $f(c_i)$
    keys_store = np.memmap(str(lab.get_data_path() / 'keys.npy'), dtype=np.float32, mode='w+', shape=(n_keys, d_model))
    # Numpy array for $w_i$.
    # The `np.int` alias was removed from NumPy; the builtin `int` is the type it
    # aliased, so the on-disk dtype stays the same.
    vals_store = np.memmap(str(lab.get_data_path() / 'vals.npy'), dtype=int, mode='w+', shape=(n_keys, 1))

    # Number of keys $f(c_i)$ collected
    added = 0
    with torch.no_grad():
        # Loop through data
        for _, batch in monit.enum("Collect data", data_loader, is_children_silent=True):
            # $w_i$ the target labels
            vals = batch[1].view(-1, 1)
            # Input data moved to the device of the model
            data = batch[0].to(conf.device)
            # Run the model; the output is not needed, only the saved feed forward input
            conf.model(data)
            # Get $f(c_i)$
            keys = conf.model.ff_input.view(-1, d_model)
            # End offset of this batch in the stores
            end = added + keys.shape[0]
            # Save keys, $f(c_i)$ in the memory mapped numpy array
            keys_store[added:end] = keys.cpu()
            # Save values, $w_i$ in the memory mapped numpy array
            vals_store[added:end] = vals
            # Increment the number of collected keys
            added = end

    # Write pending changes to disk before the arrays are read back elsewhere
    keys_store.flush()
    vals_store.flush()
def build_index(conf: Configs, n_centeroids: int = 2048, code_size: int = 64, n_probe: int = 8, n_train: int = 200_000):
    """
    ## Build FAISS index

    The [getting started](https://github.com/facebookresearch/faiss/wiki/Getting-started),
    [faster search](https://github.com/facebookresearch/faiss/wiki/Faster-search),
    and [lower memory footprint](https://github.com/facebookresearch/faiss/wiki/Lower-memory-footprint)
    tutorials on FAISS will help you learn more about FAISS usage.

    The index is trained on a random sample of at most `n_train` keys read from
    `keys.npy`, filled with every key under its position $i$, and written to
    `faiss.index` in the lab data path.
    """

    # Number of keys added to the index per call
    chunk_size = 1024

    # Training data loader
    loader = conf.trainer.data_loader
    # Dimensions of $f(c_i)$
    d_model = conf.transformer.d_model
    # Number of contexts; i.e. number of tokens in the training data minus one.
    # $\big(f(c_i), w_i\big)$ for $i \in [2, T]$
    n_keys = loader.data.shape[0] * loader.data.shape[1] - 1

    # Voronoi cell based index for faster search, with compression so that
    # full vectors are not stored.
    coarse_quantizer = faiss.IndexFlatL2(d_model)
    index = faiss.IndexIVFPQ(coarse_quantizer, d_model, n_centeroids, code_size, 8)
    index.nprobe = n_probe

    # Memory mapped numpy array of keys, opened read-only
    keys_path = str(lab.get_data_path() / 'keys.npy')
    keys_store = np.memmap(keys_path, dtype=np.float32, mode='r', shape=(n_keys, d_model))

    # Random sample (without replacement) of key positions to train the index with
    sample_size = min(n_train, n_keys)
    train_ids = np.random.choice(np.arange(n_keys), size=[sample_size], replace=False)

    with monit.section('Train index'):
        # Train the index to store the keys
        index.train(keys_store[train_ids])

    # Add keys to the index in chunks; $\big(f(c_i), i\big)$
    for start in monit.iterate('Index', range(0, n_keys, chunk_size)):
        stop = min(start + chunk_size, n_keys)
        # $f(c_i)$ together with their positions $i$
        index.add_with_ids(keys_store[start:stop], np.arange(start, stop))

    with monit.section('Save'):
        # Save the index
        faiss.write_index(index, str(lab.get_data_path() / 'faiss.index'))
def main():
    """
    Load the trained experiment, collect $\big(f(c_i), w_i\big)$ and build the FAISS index.
    """
    # Restore the saved run
    configs = load_experiment('4984b85c20bf11eb877a69c1a03717cd')
    # Put the model in evaluation mode
    configs.model.eval()

    # Collect $\big(f(c_i), w_i\big)$ into the memory mapped arrays
    gather_keys(configs)
    # Index the collected keys for fast search
    build_index(configs)
......
......@@ -251,6 +251,8 @@ def character():
def tiny_shakespeare(c: Configs):
"""
Initialize/load tiny shakespeare dataset
This dataset is from Andrej Karpathy's [char-rnn](https://github.com/karpathy/char-rnn) project.
"""
return TextFileDataset(
lab.get_data_path() / 'tiny_shakespeare.txt', c.tokenizer,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册