Unverified commit 26ea0c36 authored by Jeff Wang, committed by GitHub

Example doc update (#437)

* Add a new PyTorch example to show embeddings.
  Update the pybind documentation.

* Change the embedding function to take a word dictionary.
Parent f04cc550
# http://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html?highlight=embedding
# The following tutorial is from the PyTorch site.
# =======================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Import VisualDL
from visualdl import LogWriter
torch.manual_seed(1)
CONTEXT_SIZE = 2
EMBEDDING_DIM = 10
# We will use Shakespeare Sonnet 2
test_sentence = """When forty winters shall besiege thy brow,
And dig deep trenches in thy beauty's field,
Thy youth's proud livery so gazed on now,
Will be a totter'd weed of small worth held:
Then being asked, where all thy beauty lies,
Where all the treasure of thy lusty days;
To say, within thine own deep sunken eyes,
Were an all-eating shame, and thriftless praise.
How much more praise deserv'd thy beauty's use,
If thou couldst answer 'This fair child of mine
Shall sum my count, and make my old excuse,'
Proving his beauty by succession thine!
This were to be new made when thou art old,
And see thy blood warm when thou feel'st it cold.""".split()
# we should tokenize the input, but we will ignore that for now
# build a list of tuples. Each tuple is ([ word_i-2, word_i-1 ], target word)
trigrams = [([test_sentence[i], test_sentence[i + 1]], test_sentence[i + 2])
            for i in range(len(test_sentence) - 2)]
# print the first 3, just so you can see what they look like
print(trigrams[:3])
vocab = set(test_sentence)
word_to_ix = {word: i for i, word in enumerate(vocab)}
class NGramLanguageModeler(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(NGramLanguageModeler, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, 128)
        self.linear2 = nn.Linear(128, vocab_size)

    def forward(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        out = F.relu(self.linear1(embeds))
        out = self.linear2(out)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs
losses = []
loss_function = nn.NLLLoss()
model = NGramLanguageModeler(len(vocab), EMBEDDING_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)
for epoch in range(10):
    total_loss = torch.Tensor([0])
    for context, target in trigrams:
        # Step 1. Prepare the inputs to be passed to the model (i.e., turn the
        # words into integer indices and wrap them in tensors)
        context_idxs = torch.tensor(
            [word_to_ix[w] for w in context], dtype=torch.long)

        # Step 2. Recall that torch *accumulates* gradients. Before passing in
        # a new instance, you need to zero out the gradients from the old
        # instance
        model.zero_grad()

        # Step 3. Run the forward pass, getting log probabilities over next
        # words
        log_probs = model(context_idxs)

        # Step 4. Compute your loss function. (Again, Torch wants the target
        # word wrapped in a tensor)
        loss = loss_function(
            log_probs, torch.tensor([word_to_ix[target]], dtype=torch.long))

        # Step 5. Do the backward pass and update the gradient
        loss.backward()
        optimizer.step()

        # Get the Python number from a 1-element Tensor by calling tensor.item()
        total_loss += loss.item()
    losses.append(total_loss)
print(losses) # The loss decreased every iteration over the training data!
# VisualDL setup
logw = LogWriter("./embedding_log", sync_cycle=10000)
with logw.mode('train') as logger:
    embedding = logger.embedding()

embeddings_list = model.embeddings.weight.data.numpy()  # convert to a numpy array

# The VisualDL embedding log writer takes two parameters:
# the first is the embedding list, whose type is list[list[float]];
# the second is word_dict, a dictionary mapping each word (string) to its row index (int).
embedding.add_embeddings_with_word_dict(embeddings_list, word_to_ix)
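An optional sanity check, written as a continuation of the script above (not part of the original example): `word_to_ix` and `embeddings_list` come from the code above, and "beauty" is just an arbitrary token from the sonnet.

# Optional sanity check (continuation of the script above): the row of
# embeddings_list at word_to_ix[w] is the learned vector that
# add_embeddings_with_word_dict associates with the word w.
word = "beauty"            # any token from the sonnet works here
vector = embeddings_list[word_to_ix[word]]
print(word, vector.shape)  # -> beauty (10,), i.e. one EMBEDDING_DIM-sized row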
@@ -42,6 +42,12 @@ PYBIND11_MODULE(core, m) {
         .. autoclass:: ImageWriter
            :members:
+        .. autoclass:: TextWriter
+           :members:
+        .. autoclass:: AudioWriter
+           :members:
   )pbdoc";
   py::class_<vs::LogReader>(m, "LogReader")
@@ -240,7 +246,7 @@ PYBIND11_MODULE(core, m) {
           Add a record with the step and text value.
           :param step: Current step value
-          :type index: integer
+          :type step: integer
           :param text: Text record
           :type text: basestring
           )pbdoc");
@@ -257,15 +263,25 @@ PYBIND11_MODULE(core, m) {
           PyBind class. Must instantiate through the LogWriter.
           )pbdoc")
       .def("set_caption", &cp::Embedding::SetCaption)
-      .def(
-          "add_embeddings_with_word_list"
-          R"pbdoc(
-          Add embedding record. Each run can only store one embedding data.
-          :param embedding: hot vector of embedding words
-          :type embedding: list
-          )pbdoc",
-          &cp::Embedding::AddEmbeddingsWithWordList);
+      .def("add_embeddings_with_word_dict",
+           &cp::Embedding::AddEmbeddingsWithWordDict,
+           R"pbdoc(
+           Add the embedding record. Each run can only store one set of embedding data.
+           **embeddings** and **word_dict** should have the same length; **word_dict** gives the
+           index of each word's vector in **embeddings**::
+
+               embeddings = [[-1.5246837, -0.7505612, -0.65406495, -1.610278],
+                             [-0.781105, -0.24952792, -0.22178008, 1.6906816]]
+               word_dict = {"Apple" : 0, "Orange": 1}
+
+           This says that ``"Apple"`` is embedded to ``[-1.5246837, -0.7505612, -0.65406495, -1.610278]`` and
+           ``"Orange"`` is embedded to ``[-0.781105, -0.24952792, -0.22178008, 1.6906816]``.
+
+           :param embeddings: list of word embeddings
+           :type embeddings: list
+           :param word_dict: the mapping from words to indices
+           :type word_dict: dictionary
+           )pbdoc");
   py::class_<cp::EmbeddingReader>(m, "EmbeddingReader")
       .def("get_all_labels", &cp::EmbeddingReader::get_all_labels)
@@ -350,17 +350,19 @@ size_t TextReader::size() const { return reader_.total_records(); }
 /*
  * Embedding functions
  */
-void Embedding::AddEmbeddingsWithWordList(
+void Embedding::AddEmbeddingsWithWordDict(
     const std::vector<std::vector<float>>& word_embeddings,
-    std::vector<std::string>& labels) {
-  for (int i = 0; i < word_embeddings.size(); i++) {
-    AddEmbedding(i, word_embeddings[i], labels[i]);
+    std::map<std::string, int>& word_dict) {
+  for (auto& word_index_pair : word_dict) {
+    AddEmbedding(word_index_pair.second,
+                 word_embeddings[word_index_pair.second],
+                 word_index_pair.first);
   }
 }
 
 void Embedding::AddEmbedding(int item_id,
                              const std::vector<float>& one_hot_vector,
-                             std::string& label) {
+                             const std::string& label) {
   auto record = tablet_.AddRecord();
   record.SetId(item_id);
   time_t time = std::time(nullptr);
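The new implementation walks the dictionary and emits one (id, vector, label) record per entry. A plain-Python illustration of that behavior (the callback and function names here are hypothetical, used only to show the pairing):

# Plain-Python illustration of what AddEmbeddingsWithWordDict does above:
# every (word, index) pair in word_dict becomes one (index, vector, word) record.
def add_embeddings_with_word_dict(word_embeddings, word_dict, add_embedding):
    for word, index in word_dict.items():
        add_embedding(index, word_embeddings[index], word)

records = []
add_embeddings_with_word_dict(
    [[-1.5, -0.7], [-0.8, 1.7]],
    {"Apple": 0, "Orange": 1},
    lambda item_id, vec, label: records.append((item_id, vec, label)),
)
print(records)  # [(0, [-1.5, -0.7], 'Apple'), (1, [-0.8, 1.7], 'Orange')]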
@@ -337,19 +337,14 @@ struct Embedding {
   void SetCaption(const std::string cap) {
     tablet_.SetCaptions(std::vector<std::string>({cap}));
   }
 
-  // Add all word vectors along with all labels
-  // The index of labels should match with the index of word_embeddings
-  // EX: ["Apple", "Orange"] means the first item in word_embeddings represents
-  // "Apple"
-  void AddEmbeddingsWithWordList(
+  void AddEmbeddingsWithWordDict(
       const std::vector<std::vector<float>>& word_embeddings,
-      std::vector<std::string>& labels);
-  // TODO: Create another function that takes 'word_embeddings' and 'word_dict'
+      std::map<std::string, int>& word_dict);
 
 private:
   void AddEmbedding(int item_id,
                     const std::vector<float>& one_hot_vector,
-                    std::string& label);
+                    const std::string& label);
   Tablet tablet_;
 };