提交 a8cadfb7 编写于 作者: J Jeff Wang 提交者: daminglu

Embedding visualization (#351)

上级 4e6e1949
......@@ -81,14 +81,6 @@ struct HistogramBuilder {
T right_boundary{std::numeric_limits<T>::min()};
std::vector<int> buckets;
void Get(size_t n, T* left, T* right, int* frequency) {
CHECK(!buckets.empty()) << "need to CreateBuckets first.";
CHECK_LT(n, num_buckets_) << "n out of range.";
*left = left_boundary + span_ * n;
*right = *left + span_;
*frequency = buckets[n];
}
private:
// Get the left and right boundaries.
void UpdateBoundary(const std::vector<T>& data) {
......@@ -106,9 +98,11 @@ private:
(float)left_boundary / num_buckets_;
buckets.resize(num_buckets_);
// Go through the data, increase the item count in a bucket.
for (auto v : data) {
int offset = std::min(int((v - left_boundary) / span_), num_buckets_ - 1);
buckets[offset]++;
int bucket_group_index =
std::min(int((v - left_boundary) / span_), num_buckets_ - 1);
buckets[bucket_group_index]++;
}
}
......
......@@ -84,9 +84,14 @@ PYBIND11_MODULE(core, m) {
auto tablet = self.tablet(tag);
return vs::components::TextReader(tablet);
})
.def("get_audio", [](vs::LogReader& self, const std::string& tag) {
.def("get_audio",
[](vs::LogReader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::AudioReader(self.mode(), tablet);
})
.def("get_embedding", [](vs::LogReader& self, const std::string& tag) {
auto tablet = self.tablet(tag);
return vs::components::AudioReader(self.mode(), tablet);
return vs::components::EmbeddingReader(tablet);
});
// clang-format on
......@@ -136,7 +141,11 @@ PYBIND11_MODULE(core, m) {
int step_cycle) {
auto tablet = self.AddTablet(tag);
return vs::components::Audio(tablet, num_samples, step_cycle);
});
})
.def("new_embedding", [](vs::LogWriter& self, const std::string& tag) {
auto tablet = self.AddTablet(tag);
return vs::components::Embedding(tablet);
});
//------------------- components --------------------
#define ADD_SCALAR_READER(T) \
......@@ -233,6 +242,20 @@ PYBIND11_MODULE(core, m) {
.def("total_records", &cp::TextReader::total_records)
.def("size", &cp::TextReader::size);
// Expose cp::Embedding to Python under the class name "EmbeddingWriter".
// Python-visible methods: set_caption -> SetCaption,
// add_embeddings_with_word_list -> AddEmbeddingsWithWordList.
py::class_<cp::Embedding>(m, "EmbeddingWriter")
.def("set_caption", &cp::Embedding::SetCaption)
.def("add_embeddings_with_word_list",
&cp::Embedding::AddEmbeddingsWithWordList);
// Expose cp::EmbeddingReader to Python under the class name "EmbeddingReader".
// Every Python method below forwards to the C++ member of the same name.
py::class_<cp::EmbeddingReader>(m, "EmbeddingReader")
.def("get_all_labels", &cp::EmbeddingReader::get_all_labels)
.def("get_all_embeddings", &cp::EmbeddingReader::get_all_embeddings)
.def("ids", &cp::EmbeddingReader::ids)
.def("timestamps", &cp::EmbeddingReader::timestamps)
.def("caption", &cp::EmbeddingReader::caption)
.def("total_records", &cp::EmbeddingReader::total_records)
.def("size", &cp::EmbeddingReader::size);
py::class_<cp::Audio>(m, "AudioWriter", R"pbdoc(
PyBind class. Must instantiate through the LogWriter.
)pbdoc")
......
......@@ -347,6 +347,79 @@ std::string TextReader::caption() const {
size_t TextReader::size() const { return reader_.total_records(); }
/*
* Embedding functions
*/
// Adds one record per embedding vector, pairing word_embeddings[i] with
// labels[i] and using the index i as the record id.
//
// labels must hold at least as many entries as word_embeddings. If it is
// shorter, std::out_of_range is thrown before any record is added; the
// previous implementation read past the end of labels in that case.
// Labels beyond word_embeddings.size() are ignored, as before.
void Embedding::AddEmbeddingsWithWordList(
    const std::vector<std::vector<float>>& word_embeddings,
    std::vector<std::string>& labels) {
  const size_t count = word_embeddings.size();
  if (count == 0) return;

  // Validate up front so a short label list cannot leave the tablet holding
  // only part of the embeddings: at() throws std::out_of_range when the last
  // required index is missing.
  (void)labels.at(count - 1);

  for (size_t i = 0; i < count; ++i) {
    AddEmbedding(static_cast<int>(i), word_embeddings[i], labels[i]);
  }
}
// Appends a single record to the tablet: the record carries item_id as its
// id, the current time as its timestamp, and one data entry holding both the
// float vector and the label.
void Embedding::AddEmbedding(int item_id,
                             const std::vector<float>& one_hot_vector,
                             std::string& label) {
  auto rec = tablet_.AddRecord();
  rec.SetId(item_id);

  // Stamp the record with the wall-clock time at which it is written.
  time_t now = std::time(nullptr);
  rec.SetTimeStamp(now);

  // The vector and its label share the same data entry.
  auto data = rec.AddData();
  data.SetMulti<float>(one_hot_vector);
  data.SetRaw(label);
}
/*
* EmbeddingReader functions
*/
// Collects the raw payload of the first data entry of every record, in
// record order.
std::vector<std::string> EmbeddingReader::get_all_labels() const {
  std::vector<std::string> labels;
  for (int idx = 0; idx < total_records(); ++idx) {
    labels.push_back(reader_.record(idx).data(0).GetRaw());
  }
  return labels;
}
// Collects the float vector stored in the first data entry of every record,
// in record order.
std::vector<std::vector<float>> EmbeddingReader::get_all_embeddings() const {
  const size_t count = total_records();
  std::vector<std::vector<float>> result;
  // The final size is known, so allocate the outer vector once.
  result.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    auto record = reader_.record(static_cast<int>(i));
    auto entry = record.data(0);
    // Push the call result directly so a by-value return is moved into the
    // container instead of being copied through a named local.
    result.push_back(entry.GetMulti<float>());
  }
  return result;
}
// Returns the id of every record, in record order.
std::vector<int> EmbeddingReader::ids() const {
  std::vector<int> id_list;
  for (int idx = 0; idx < reader_.total_records(); ++idx) {
    auto rec = reader_.record(idx);
    id_list.push_back(rec.id());
  }
  return id_list;
}
// Returns the timestamp of every record, in record order.
std::vector<time_t> EmbeddingReader::timestamps() const {
  std::vector<time_t> stamps;
  for (int idx = 0; idx < reader_.total_records(); ++idx) {
    auto rec = reader_.record(idx);
    stamps.push_back(rec.timestamp());
  }
  return stamps;
}
// Returns the first caption of the underlying tablet reader.
// CHECK-fails with "no caption" when the reader has no captions.
std::string EmbeddingReader::caption() const {
  const auto& all_captions = reader_.captions();
  CHECK(!all_captions.empty()) << "no caption";
  return all_captions.front();
}
// Number of records held by the underlying tablet reader.
size_t EmbeddingReader::size() const {
  const size_t record_count = reader_.total_records();
  return record_count;
}
void Audio::StartSampling() {
if (!ToSampleThisStep()) return;
......
......@@ -327,6 +327,52 @@ private:
TabletReader reader_;
};
/*
 * Writer half of the embedding component. Wraps a Tablet and stores each
 * embedding vector, together with its label, as one record.
 */
struct Embedding {
  // Wraps `tablet` and marks it as an embedding tablet.
  Embedding(Tablet tablet) : tablet_(tablet) {
    tablet_.SetType(Tablet::Type::kEmbedding);
  }

  // Replaces the tablet's captions with the single caption `cap`.
  void SetCaption(const std::string cap) {
    tablet_.SetCaptions(std::vector<std::string>(1, cap));
  }

  // Adds every word vector together with its label.
  // Labels and vectors are matched by position: labels[i] names
  // word_embeddings[i]. For example, with labels ["Apple", "Orange"] the
  // first vector in word_embeddings represents "Apple".
  void AddEmbeddingsWithWordList(
      const std::vector<std::vector<float>>& word_embeddings,
      std::vector<std::string>& labels);

  // TODO: Create another function that takes 'word_embeddings' and 'word_dict'

private:
  // Writes a single (id, vector, label) record into the tablet.
  void AddEmbedding(int item_id,
                    const std::vector<float>& one_hot_vector,
                    std::string& label);

  Tablet tablet_;
};
/*
 * Reader half of the embedding component. Wraps a TabletReader and exposes
 * the stored records as plain vectors.
 */
struct EmbeddingReader {
  EmbeddingReader(TabletReader reader) : reader_(reader) {}

  // Per-record accessors; each returns one element per record.
  std::vector<int> ids() const;
  std::vector<std::string> get_all_labels() const;
  std::vector<std::vector<float>> get_all_embeddings() const;
  std::vector<time_t> timestamps() const;

  // Caption of the underlying tablet.
  std::string caption() const;

  // Record count as reported by the underlying reader.
  size_t total_records() const { return reader_.total_records(); }
  size_t size() const;

private:
  TabletReader reader_;
};
/*
* Image component writer.
*/
......
......@@ -119,6 +119,10 @@ class LogReader(object):
check_tag_name_valid(tag)
return self.reader.get_text(tag)
def embedding(self, tag):
    """
    Get an embedding reader with tag

    :param tag: The tag of the embedding data; validated with
        check_tag_name_valid before the reader is requested.
    :return: The object returned by the underlying reader's get_embedding.
    """
    check_tag_name_valid(tag)
    embedding_reader = self.reader.get_embedding(tag)
    return embedding_reader
def audio(self, tag):
"""
Get an audio reader with tag
......@@ -256,6 +260,10 @@ class LogWriter(object):
check_tag_name_valid(tag)
return self.writer.new_text(tag)
def embedding(self, tag):
    """
    Create an embedding writer with tag

    :param tag: The tag of the embedding data; validated with
        check_tag_name_valid before the writer is created.
    :return: The object returned by the underlying writer's new_embedding.
    """
    check_tag_name_valid(tag)
    embedding_writer = self.writer.new_embedding(tag)
    return embedding_writer
def save(self):
self.writer.save()
......
......@@ -109,6 +109,7 @@ message Tablet {
kImage = 2;
kText = 3;
kAudio = 4;
kEmbedding = 5;
}
// The unique identification for this `Tablet`. VisualDL has no concept of a
// FileWriter like TensorBoard. It will store all the tablets in a single
......
......@@ -35,6 +35,7 @@ struct Tablet {
kImage = 2,
kText = 3,
kAudio = 4,
kEmbedding = 5,
kUnknown = -1
};
......@@ -59,6 +60,9 @@ struct Tablet {
if (name == "audio") {
return kAudio;
}
if (name == "embedding") {
return kEmbedding;
}
LOG(ERROR) << "unknown component: " << name;
return kUnknown;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册