提交 194e3792 编写于 作者: S superjom

image component ready

上级 7a08e5e9
...@@ -49,6 +49,10 @@ template class ScalarReader<double>; ...@@ -49,6 +49,10 @@ template class ScalarReader<double>;
void Image::StartSampling() { void Image::StartSampling() {
step_ = writer_.AddRecord(); step_ = writer_.AddRecord();
// resize record
for (int i = 0; i < num_samples_; i++) {
step_.AddData<value_t>();
}
num_records_ = 0; num_records_ = 0;
} }
...@@ -58,8 +62,8 @@ int Image::IsSampleTaken() { ...@@ -58,8 +62,8 @@ int Image::IsSampleTaken() {
return num_records_ - 1; return num_records_ - 1;
} }
float prob = float(num_samples_) / num_records_; float prob = float(num_samples_) / num_records_;
float thre = (float)rand() / RAND_MAX; float randv = (float)rand() / RAND_MAX;
if (prob < thre) { if (randv < prob) {
// take this sample // take this sample
int index = rand() % num_samples_; int index = rand() % num_samples_;
return index; return index;
...@@ -82,7 +86,7 @@ struct is_same_type<T, T> { ...@@ -82,7 +86,7 @@ struct is_same_type<T, T> {
}; };
void Image::SetSample(int index, void Image::SetSample(int index,
const std::vector<int64_t>& shape, const std::vector<shape_t>& shape,
const std::vector<value_t>& data) { const std::vector<value_t>& data) {
// production // production
int size = std::accumulate( int size = std::accumulate(
...@@ -92,26 +96,32 @@ void Image::SetSample(int index, ...@@ -92,26 +96,32 @@ void Image::SetSample(int index,
CHECK_LE(index, num_records_); CHECK_LE(index, num_records_);
// set data // set data
Entry<value_t> entry; auto entry = step_.MutableData<value_t>(index);
if (index == num_records_) {
// add one entry
entry = step_.AddData<value_t>();
} else {
entry = step_.MutableData<value_t>(index);
}
entry.SetMulti(data); entry.SetMulti(data);
static_assert( static_assert(
!is_same_type<value_t, int64_t>::value, !is_same_type<value_t, shape_t>::value,
"value_t should not use int64_t field, this type is used to store shape"); "value_t should not use int64_t field, this type is used to store shape");
// set meta with hack // set meta with hack
Entry<int64_t> meta; Entry<shape_t> meta;
meta.set_parent(entry.parent()); meta.set_parent(entry.parent());
meta.entry = entry.entry; meta.entry = entry.entry;
meta.SetMulti(shape); meta.SetMulti(shape);
} }
std::vector<ImageReader::value_t> ImageReader::data(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<value_t>(index);
return entry.GetMulti();
}
std::vector<ImageReader::shape_t> ImageReader::shape(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<shape_t>(index);
return entry.GetMulti();
}
} // namespace components } // namespace components
} // namespace visualdl } // namespace visualdl
...@@ -145,12 +145,16 @@ private: ...@@ -145,12 +145,16 @@ private:
*/ */
struct Image { struct Image {
using value_t = float; using value_t = float;
using shape_t = int64_t;
Image(Tablet tablet, int num_samples) : writer_(tablet) { Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage); writer_.SetType(Tablet::Type::kImage);
writer_.SetNumSamples(num_samples); writer_.SetNumSamples(num_samples);
num_samples_ = num_samples; num_samples_ = num_samples;
} }
void SetCaption(const std::string& c) {
writer_.SetCaptions(std::vector<std::string>({c}));
}
/* /*
* Start a sample period. * Start a sample period.
*/ */
...@@ -165,7 +169,7 @@ struct Image { ...@@ -165,7 +169,7 @@ struct Image {
void FinishSampling(); void FinishSampling();
void SetSample(int index, void SetSample(int index,
const std::vector<int64_t>& shape, const std::vector<shape_t>& shape,
const std::vector<value_t>& data); const std::vector<value_t>& data);
private: private:
...@@ -175,6 +179,31 @@ private: ...@@ -175,6 +179,31 @@ private:
int num_samples_{0}; int num_samples_{0};
}; };
/*
* Image reader.
*/
struct ImageReader {
using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t;
ImageReader(TabletReader tablet) : reader_(tablet) {}
std::string caption() {
CHECK_EQ(reader_.captions().size(), 1);
return reader_.captions().front();
}
// number of steps.
int num_records() { return reader_.total_records(); }
std::vector<value_t> data(int step, int index);
std::vector<shape_t> shape(int step, int index);
private:
TabletReader reader_;
};
} // namespace components } // namespace components
} // namespace visualdl } // namespace visualdl
......
...@@ -44,25 +44,39 @@ TEST(Scalar, write) { ...@@ -44,25 +44,39 @@ TEST(Scalar, write) {
TEST(Image, test) { TEST(Image, test) {
const auto dir = "./tmp/sdk_test.image"; const auto dir = "./tmp/sdk_test.image";
Writer writer__(dir, 1); Writer writer__(dir, 4);
auto writer = writer__.AsMode("train"); auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0"); auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3); components::Image image(tablet, 3);
const int num_steps = 10;
image.StartSampling(); LOG(INFO) << "write images";
for (int i = 0; i < 100; i++) { image.SetCaption("this is an image");
vector<int64_t> shape({3, 5, 5}); for (int step = 0; step < num_steps; step++) {
vector<float> data; image.StartSampling();
for (int j = 0; j < 3 * 5 * 5; j++) { for (int i = 0; i < 7; i++) {
data.push_back(float(rand()) / RAND_MAX); vector<int64_t> shape({3, 5, 5});
} vector<float> data;
int index = image.IsSampleTaken(); for (int j = 0; j < 3 * 5 * 5; j++) {
if (index != -1) { data.push_back(float(rand()) / RAND_MAX);
image.SetSample(index, shape, data); }
int index = image.IsSampleTaken();
if (index != -1) {
image.SetSample(index, shape, data);
}
} }
image.FinishSampling();
} }
image.FinishSampling();
LOG(INFO) << "read images";
// read it
Reader reader__(dir);
auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0");
components::ImageReader image2read(tablet2read);
CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps);
} }
} // namespace visualdl } // namespace visualdl
...@@ -60,6 +60,7 @@ IMPL_ENTRY_GET(bool, b); ...@@ -60,6 +60,7 @@ IMPL_ENTRY_GET(bool, b);
} }
IMPL_ENTRY_GET_MULTI(int, i32s); IMPL_ENTRY_GET_MULTI(int, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(float, fs); IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds); IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss); IMPL_ENTRY_GET_MULTI(std::string, ss);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册