From f59c9b74fc27ed5f53f488cdca58ccd042dfb1ca Mon Sep 17 00:00:00 2001 From: superjom Date: Mon, 25 Dec 2017 18:47:40 +0800 Subject: [PATCH] add image --- visualdl/logic/sdk.cc | 65 ++++++++++++++++++++++++++++++++++++++ visualdl/logic/sdk.h | 35 ++++++++++++++++++++ visualdl/logic/sdk_test.cc | 25 +++++++++++++++ visualdl/storage/entry.cc | 17 ++++++++++ visualdl/storage/entry.h | 8 +++-- visualdl/storage/record.h | 8 +++++ 6 files changed, 156 insertions(+), 2 deletions(-) diff --git a/visualdl/logic/sdk.cc b/visualdl/logic/sdk.cc index 62c46ea6..e8a801ed 100644 --- a/visualdl/logic/sdk.cc +++ b/visualdl/logic/sdk.cc @@ -47,6 +47,71 @@ template class ScalarReader; template class ScalarReader; template class ScalarReader; +void Image::StartSampling() { + step_ = writer_.AddRecord(); + num_records_ = 0; +} + +int Image::IsSampleTaken() { + num_records_++; + if (num_records_ <= num_samples_) { + return num_records_ - 1; + } + float prob = float(num_samples_) / num_records_; + float thre = (float)rand() / RAND_MAX; + if (prob < thre) { + // take this sample + int index = rand() % num_samples_; + return index; + } + return -1; +} + +void Image::FinishSampling() { + // TODO(ChunweiYan) much optimizement here. + writer_.parent()->PersistToDisk(); +} + +template +struct is_same_type { + static const bool value = false; +}; +template +struct is_same_type { + static const bool value = true; +}; + +void Image::SetSample(int index, + const std::vector& shape, + const std::vector& data) { + // production + int size = std::accumulate( + shape.begin(), shape.end(), 1., [](float a, float b) { return a * b; }); + CHECK_EQ(size, data.size()) << "image's shape not match data"; + CHECK_LT(index, num_samples_); + CHECK_LE(index, num_records_); + + // set data + Entry entry; + if (index == num_records_) { + // add one entry + entry = step_.AddData(); + } else { + entry = step_.MutableData(index); + } + entry.SetMulti(data); + + static_assert( + !is_same_type::value, + "value_t should not use int64_t field, this type is used to store shape"); + + // set meta with hack + Entry meta; + meta.set_parent(entry.parent()); + meta.entry = entry.entry; + meta.SetMulti(shape); +} + } // namespace components } // namespace visualdl diff --git a/visualdl/logic/sdk.h b/visualdl/logic/sdk.h index 66135d23..1c8cad6c 100644 --- a/visualdl/logic/sdk.h +++ b/visualdl/logic/sdk.h @@ -140,6 +140,41 @@ private: TabletReader reader_; }; +/* + * Image component writer. + */ +struct Image { + using value_t = float; + + Image(Tablet tablet, int num_samples) : writer_(tablet) { + writer_.SetType(Tablet::Type::kImage); + writer_.SetNumSamples(num_samples); + num_samples_ = num_samples; + } + /* + * Start a sample period. + */ + void StartSampling(); + /* + * Will this sample will be taken. + */ + int IsSampleTaken(); + /* + * End a sample period. + */ + void FinishSampling(); + + void SetSample(int index, + const std::vector& shape, + const std::vector& data); + +private: + Tablet writer_; + Record step_; + int num_records_{0}; + int num_samples_{0}; +}; + } // namespace components } // namespace visualdl diff --git a/visualdl/logic/sdk_test.cc b/visualdl/logic/sdk_test.cc index 3a2b2f5e..8529ffd6 100644 --- a/visualdl/logic/sdk_test.cc +++ b/visualdl/logic/sdk_test.cc @@ -2,6 +2,8 @@ #include +using namespace std; + namespace visualdl { TEST(Scalar, write) { @@ -40,4 +42,27 @@ TEST(Scalar, write) { ASSERT_EQ(scalar_reader1.caption(), "customized caption"); } +TEST(Image, test) { + const auto dir = "./tmp/sdk_test.image"; + Writer writer__(dir, 1); + auto writer = writer__.AsMode("train"); + + auto tablet = writer.AddTablet("image0"); + components::Image image(tablet, 3); + + image.StartSampling(); + for (int i = 0; i < 100; i++) { + vector shape({3, 5, 5}); + vector data; + for (int j = 0; j < 3 * 5 * 5; j++) { + data.push_back(float(rand()) / RAND_MAX); + } + int index = image.IsSampleTaken(); + if (index != -1) { + image.SetSample(index, shape, data); + } + } + image.FinishSampling(); +} + } // namespace visualdl diff --git a/visualdl/storage/entry.cc b/visualdl/storage/entry.cc index 0d4fc8dd..4e1a29c7 100644 --- a/visualdl/storage/entry.cc +++ b/visualdl/storage/entry.cc @@ -10,6 +10,17 @@ namespace visualdl { WRITE_GUARD \ } +#define IMPL_ENTRY_SETMUL(ctype__, dtype__, field__) \ + template <> \ + void Entry::SetMulti(const std::vector& vs) { \ + entry->set_dtype(storage::DataType::dtype__); \ + entry->clear_##field__(); \ + for (auto v : vs) { \ + entry->add_##field__(v); \ + } \ + WRITE_GUARD \ + } + IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32); IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64); IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b); @@ -22,6 +33,12 @@ IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds); IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss); IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs); +IMPL_ENTRY_SETMUL(int, kInt32, i32s); +IMPL_ENTRY_SETMUL(int64_t, kInt64, i64s); +IMPL_ENTRY_SETMUL(float, kFloat, fs); +IMPL_ENTRY_SETMUL(double, kDouble, ds); +IMPL_ENTRY_SETMUL(bool, kBool, bs); + #define IMPL_ENTRY_GET(T, fieldname__) \ template <> \ T EntryReader::Get() const { \ diff --git a/visualdl/storage/entry.h b/visualdl/storage/entry.h index 060b0382..ba9542de 100644 --- a/visualdl/storage/entry.h +++ b/visualdl/storage/entry.h @@ -19,8 +19,9 @@ struct Entry { storage::Entry* entry{nullptr}; Entry() {} - explicit Entry(storage::Entry* entry, Storage* parent) - : entry(entry), x_(parent) {} + Entry(storage::Entry* entry, Storage* parent) : entry(entry), x_(parent) {} + Entry(const Entry& other) : entry(other.entry), x_(other.x_) {} + void operator()(storage::Entry* entry, Storage* parent) { this->entry = entry; x_ = parent; @@ -32,7 +33,10 @@ struct Entry { // Add a value to repeated message field. void Add(T v); + void SetMulti(const std::vector& v); + Storage* parent() { return x_; } + void set_parent(Storage* x) { x_ = x; } private: Storage* x_; diff --git a/visualdl/storage/record.h b/visualdl/storage/record.h index 4e5fc7fd..31fae1bb 100644 --- a/visualdl/storage/record.h +++ b/visualdl/storage/record.h @@ -30,7 +30,9 @@ struct Record { DECL_GUARD(Record) + Record() {} Record(storage::Record* x, Storage* parent) : data_(x), x_(parent) {} + Record(const Record& other) : data_(other.data_), x_(other.x_) {} // write operations void SetTimeStamp(int64_t x) { @@ -59,6 +61,12 @@ struct Record { return Entry(data_->add_data(), parent()); } + template + Entry MutableData(int i) { + WRITE_GUARD + return Entry(data_->mutable_data(i), parent()); + } + Storage* parent() { return x_; } private: -- GitLab