“b857ff1b45290c087a067b31fec7912d95b362ad”上不存在“git@gitcode.net:s920243400/PaddleDetection.git”
提交 f59c9b74 编写于 作者: S superjom

add image

上级 193df892
......@@ -47,6 +47,71 @@ template class ScalarReader<int64_t>;
template class ScalarReader<float>;
template class ScalarReader<double>;
void Image::StartSampling() {
step_ = writer_.AddRecord();
num_records_ = 0;
}
int Image::IsSampleTaken() {
num_records_++;
if (num_records_ <= num_samples_) {
return num_records_ - 1;
}
float prob = float(num_samples_) / num_records_;
float thre = (float)rand() / RAND_MAX;
if (prob < thre) {
// take this sample
int index = rand() % num_samples_;
return index;
}
return -1;
}
void Image::FinishSampling() {
// TODO(ChunweiYan) much optimizement here.
writer_.parent()->PersistToDisk();
}
template <typename T, typename U>
struct is_same_type {
static const bool value = false;
};
template <typename T>
struct is_same_type<T, T> {
static const bool value = true;
};
void Image::SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<value_t>& data) {
// production
int size = std::accumulate(
shape.begin(), shape.end(), 1., [](float a, float b) { return a * b; });
CHECK_EQ(size, data.size()) << "image's shape not match data";
CHECK_LT(index, num_samples_);
CHECK_LE(index, num_records_);
// set data
Entry<value_t> entry;
if (index == num_records_) {
// add one entry
entry = step_.AddData<value_t>();
} else {
entry = step_.MutableData<value_t>(index);
}
entry.SetMulti(data);
static_assert(
!is_same_type<value_t, int64_t>::value,
"value_t should not use int64_t field, this type is used to store shape");
// set meta with hack
Entry<int64_t> meta;
meta.set_parent(entry.parent());
meta.entry = entry.entry;
meta.SetMulti(shape);
}
} // namespace components
} // namespace visualdl
......@@ -140,6 +140,41 @@ private:
TabletReader reader_;
};
/*
* Image component writer.
*/
struct Image {
using value_t = float;
Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage);
writer_.SetNumSamples(num_samples);
num_samples_ = num_samples;
}
/*
* Start a sample period.
*/
void StartSampling();
/*
* Will this sample will be taken.
*/
int IsSampleTaken();
/*
* End a sample period.
*/
void FinishSampling();
void SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<value_t>& data);
private:
Tablet writer_;
Record step_;
int num_records_{0};
int num_samples_{0};
};
} // namespace components
} // namespace visualdl
......
......@@ -2,6 +2,8 @@
#include <gtest/gtest.h>
using namespace std;
namespace visualdl {
TEST(Scalar, write) {
......@@ -40,4 +42,27 @@ TEST(Scalar, write) {
ASSERT_EQ(scalar_reader1.caption(), "customized caption");
}
TEST(Image, test) {
const auto dir = "./tmp/sdk_test.image";
Writer writer__(dir, 1);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3);
image.StartSampling();
for (int i = 0; i < 100; i++) {
vector<int64_t> shape({3, 5, 5});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
int index = image.IsSampleTaken();
if (index != -1) {
image.SetSample(index, shape, data);
}
}
image.FinishSampling();
}
} // namespace visualdl
......@@ -10,6 +10,17 @@ namespace visualdl {
WRITE_GUARD \
}
#define IMPL_ENTRY_SETMUL(ctype__, dtype__, field__) \
template <> \
void Entry<ctype__>::SetMulti(const std::vector<ctype__>& vs) { \
entry->set_dtype(storage::DataType::dtype__); \
entry->clear_##field__(); \
for (auto v : vs) { \
entry->add_##field__(v); \
} \
WRITE_GUARD \
}
IMPL_ENTRY_SET_OR_ADD(Set, int, kInt32, set_i32);
IMPL_ENTRY_SET_OR_ADD(Set, int64_t, kInt64, set_i64);
IMPL_ENTRY_SET_OR_ADD(Set, bool, kBool, set_b);
......@@ -22,6 +33,12 @@ IMPL_ENTRY_SET_OR_ADD(Add, double, kDoubles, add_ds);
IMPL_ENTRY_SET_OR_ADD(Add, std::string, kStrings, add_ss);
IMPL_ENTRY_SET_OR_ADD(Add, bool, kBools, add_bs);
IMPL_ENTRY_SETMUL(int, kInt32, i32s);
IMPL_ENTRY_SETMUL(int64_t, kInt64, i64s);
IMPL_ENTRY_SETMUL(float, kFloat, fs);
IMPL_ENTRY_SETMUL(double, kDouble, ds);
IMPL_ENTRY_SETMUL(bool, kBool, bs);
#define IMPL_ENTRY_GET(T, fieldname__) \
template <> \
T EntryReader<T>::Get() const { \
......
......@@ -19,8 +19,9 @@ struct Entry {
storage::Entry* entry{nullptr};
Entry() {}
explicit Entry(storage::Entry* entry, Storage* parent)
: entry(entry), x_(parent) {}
Entry(storage::Entry* entry, Storage* parent) : entry(entry), x_(parent) {}
Entry(const Entry<T>& other) : entry(other.entry), x_(other.x_) {}
void operator()(storage::Entry* entry, Storage* parent) {
this->entry = entry;
x_ = parent;
......@@ -32,7 +33,10 @@ struct Entry {
// Add a value to repeated message field.
void Add(T v);
void SetMulti(const std::vector<T>& v);
Storage* parent() { return x_; }
void set_parent(Storage* x) { x_ = x; }
private:
Storage* x_;
......
......@@ -30,7 +30,9 @@ struct Record {
DECL_GUARD(Record)
Record() {}
Record(storage::Record* x, Storage* parent) : data_(x), x_(parent) {}
Record(const Record& other) : data_(other.data_), x_(other.x_) {}
// write operations
void SetTimeStamp(int64_t x) {
......@@ -59,6 +61,12 @@ struct Record {
return Entry<T>(data_->add_data(), parent());
}
template <typename T>
Entry<T> MutableData(int i) {
WRITE_GUARD
return Entry<T>(data_->mutable_data(i), parent());
}
Storage* parent() { return x_; }
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册