提交 f59c9b74 编写于 作者: S superjom

add image

上级 193df892
...@@ -47,6 +47,71 @@ template class ScalarReader<int64_t>; ...@@ -47,6 +47,71 @@ template class ScalarReader<int64_t>;
template class ScalarReader<float>; template class ScalarReader<float>;
template class ScalarReader<double>; template class ScalarReader<double>;
void Image::StartSampling() {
step_ = writer_.AddRecord();
num_records_ = 0;
}
int Image::IsSampleTaken() {
num_records_++;
if (num_records_ <= num_samples_) {
return num_records_ - 1;
}
float prob = float(num_samples_) / num_records_;
float thre = (float)rand() / RAND_MAX;
if (prob < thre) {
// take this sample
int index = rand() % num_samples_;
return index;
}
return -1;
}
void Image::FinishSampling() {
// TODO(ChunweiYan) much optimizement here.
writer_.parent()->PersistToDisk();
}
template <typename T, typename U>
struct is_same_type {
static const bool value = false;
};
template <typename T>
struct is_same_type<T, T> {
static const bool value = true;
};
void Image::SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<value_t>& data) {
// production
int size = std::accumulate(
shape.begin(), shape.end(), 1., [](float a, float b) { return a * b; });
CHECK_EQ(size, data.size()) << "image's shape not match data";
CHECK_LT(index, num_samples_);
CHECK_LE(index, num_records_);
// set data
Entry<value_t> entry;
if (index == num_records_) {
// add one entry
entry = step_.AddData<value_t>();
} else {
entry = step_.MutableData<value_t>(index);
}
entry.SetMulti(data);
static_assert(
!is_same_type<value_t, int64_t>::value,
"value_t should not use int64_t field, this type is used to store shape");
// set meta with hack
Entry<int64_t> meta;
meta.set_parent(entry.parent());
meta.entry = entry.entry;
meta.SetMulti(shape);
}
} // namespace components } // namespace components
} // namespace visualdl } // namespace visualdl
...@@ -140,6 +140,41 @@ private: ...@@ -140,6 +140,41 @@ private:
TabletReader reader_; TabletReader reader_;
}; };
/*
* Image component writer.
*/
struct Image {
using value_t = float;
Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage);
writer_.SetNumSamples(num_samples);
num_samples_ = num_samples;
}
/*
* Start a sample period.
*/
void StartSampling();
/*
* Will this sample will be taken.
*/
int IsSampleTaken();
/*
* End a sample period.
*/
void FinishSampling();
void SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<value_t>& data);
private:
Tablet writer_;
Record step_;
int num_records_{0};
int num_samples_{0};
};
} // namespace components } // namespace components
} // namespace visualdl } // namespace visualdl
......
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
using namespace std;
namespace visualdl { namespace visualdl {
TEST(Scalar, write) { TEST(Scalar, write) {
...@@ -40,4 +42,27 @@ TEST(Scalar, write) { ...@@ -40,4 +42,27 @@ TEST(Scalar, write) {
ASSERT_EQ(scalar_reader1.caption(), "customized caption"); ASSERT_EQ(scalar_reader1.caption(), "customized caption");
} }
TEST(Image, test) {
const auto dir = "./tmp/sdk_test.image";
Writer writer__(dir, 1);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3);
image.StartSampling();
for (int i = 0; i < 100; i++) {
vector<int64_t> shape({3, 5, 5});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
int index = image.IsSampleTaken();
if (index != -1) {
image.SetSample(index, shape, data);
}
}
image.FinishSampling();
}
} // namespace visualdl } // namespace visualdl
...@@ -10,6 +10,17 @@ namespace visualdl { ...@@ -10,6 +10,17 @@ namespace visualdl {
WRITE_GUARD \ WRITE_GUARD \
} }
#define IMPL_ENTRY_SETMUL(ctype__, dtype__, field__) \
template <> \
void Entry<ctype__>::SetMulti(const std::vector<ctype__>& vs) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->clear_##field__(); \
for (auto v : vs) { \
entry->add_##field__(v); \
} \
WRITE_GUARD \
}
IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32); IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32);
IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64); IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64);
IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b); IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b);
...@@ -22,6 +33,12 @@ IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds); ...@@ -22,6 +33,12 @@ IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss); IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs); IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
IMPL_ENTRY_SETMUL(int, kInt32, i32s);
IMPL_ENTRY_SETMUL(int64_t, kInt64, i64s);
IMPL_ENTRY_SETMUL(float, kFloat, fs);
IMPL_ENTRY_SETMUL(double, kDouble, ds);
IMPL_ENTRY_SETMUL(bool, kBool, bs);
#define IMPL_ENTRY_GET(T, fieldname__) \ #define IMPL_ENTRY_GET(T, fieldname__) \
template <> \ template <> \
T EntryReader<T>::Get() const { \ T EntryReader<T>::Get() const { \
......
...@@ -19,8 +19,9 @@ struct Entry { ...@@ -19,8 +19,9 @@ struct Entry {
storage::Entry* entry{nullptr}; storage::Entry* entry{nullptr};
Entry() {} Entry() {}
explicit Entry(storage::Entry* entry, Storage* parent) Entry(storage::Entry* entry, Storage* parent) : entry(entry), x_(parent) {}
: entry(entry), x_(parent) {} Entry(const Entry<T>& other) : entry(other.entry), x_(other.x_) {}
void operator()(storage::Entry* entry, Storage* parent) { void operator()(storage::Entry* entry, Storage* parent) {
this->entry = entry; this->entry = entry;
x_ = parent; x_ = parent;
...@@ -32,7 +33,10 @@ struct Entry { ...@@ -32,7 +33,10 @@ struct Entry {
// Add a value to repeated message field. // Add a value to repeated message field.
void Add(T v); void Add(T v);
void SetMulti(const std::vector<T>& v);
Storage* parent() { return x_; } Storage* parent() { return x_; }
void set_parent(Storage* x) { x_ = x; }
private: private:
Storage* x_; Storage* x_;
......
...@@ -30,7 +30,9 @@ struct Record { ...@@ -30,7 +30,9 @@ struct Record {
DECL_GUARD(Record) DECL_GUARD(Record)
Record() {}
Record(storage::Record* x, Storage* parent) : data_(x), x_(parent) {} Record(storage::Record* x, Storage* parent) : data_(x), x_(parent) {}
Record(const Record& other) : data_(other.data_), x_(other.x_) {}
// write operations // write operations
void SetTimeStamp(int64_t x) { void SetTimeStamp(int64_t x) {
...@@ -59,6 +61,12 @@ struct Record { ...@@ -59,6 +61,12 @@ struct Record {
return Entry<T>(data_->add_data(), parent()); return Entry<T>(data_->add_data(), parent());
} }
template <typename T>
Entry<T> MutableData(int i) {
WRITE_GUARD
return Entry<T>(data_->mutable_data(i), parent());
}
Storage* parent() { return x_; } Storage* parent() { return x_; }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册