提交 194e3792 编写于 作者: S superjom

image component ready

上级 7a08e5e9
......@@ -49,6 +49,10 @@ template class ScalarReader<double>;
void Image::StartSampling() {
step_ = writer_.AddRecord();
// resize record
for (int i = 0; i < num_samples_; i++) {
step_.AddData<value_t>();
}
num_records_ = 0;
}
......@@ -58,8 +62,8 @@ int Image::IsSampleTaken() {
return num_records_ - 1;
}
float prob = float(num_samples_) / num_records_;
float thre = (float)rand() / RAND_MAX;
if (prob < thre) {
float randv = (float)rand() / RAND_MAX;
if (randv < prob) {
// take this sample
int index = rand() % num_samples_;
return index;
......@@ -82,7 +86,7 @@ struct is_same_type<T, T> {
};
void Image::SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<shape_t>& shape,
const std::vector<value_t>& data) {
// production
int size = std::accumulate(
......@@ -92,26 +96,32 @@ void Image::SetSample(int index,
CHECK_LE(index, num_records_);
// set data
Entry<value_t> entry;
if (index == num_records_) {
// add one entry
entry = step_.AddData<value_t>();
} else {
entry = step_.MutableData<value_t>(index);
}
auto entry = step_.MutableData<value_t>(index);
entry.SetMulti(data);
static_assert(
!is_same_type<value_t, int64_t>::value,
!is_same_type<value_t, shape_t>::value,
"value_t should not use int64_t field, this type is used to store shape");
// set meta with hack
Entry<int64_t> meta;
Entry<shape_t> meta;
meta.set_parent(entry.parent());
meta.entry = entry.entry;
meta.SetMulti(shape);
}
std::vector<ImageReader::value_t> ImageReader::data(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<value_t>(index);
return entry.GetMulti();
}
std::vector<ImageReader::shape_t> ImageReader::shape(int step, int index) {
auto record = reader_.record(step);
auto entry = record.data<shape_t>(index);
return entry.GetMulti();
}
} // namespace components
} // namespace visualdl
......@@ -145,12 +145,16 @@ private:
*/
struct Image {
using value_t = float;
using shape_t = int64_t;
Image(Tablet tablet, int num_samples) : writer_(tablet) {
writer_.SetType(Tablet::Type::kImage);
writer_.SetNumSamples(num_samples);
num_samples_ = num_samples;
}
void SetCaption(const std::string& c) {
writer_.SetCaptions(std::vector<std::string>({c}));
}
/*
* Start a sample period.
*/
......@@ -165,7 +169,7 @@ struct Image {
void FinishSampling();
void SetSample(int index,
const std::vector<int64_t>& shape,
const std::vector<shape_t>& shape,
const std::vector<value_t>& data);
private:
......@@ -175,6 +179,31 @@ private:
int num_samples_{0};
};
/*
* Image reader.
*/
struct ImageReader {
using value_t = typename Image::value_t;
using shape_t = typename Image::shape_t;
ImageReader(TabletReader tablet) : reader_(tablet) {}
std::string caption() {
CHECK_EQ(reader_.captions().size(), 1);
return reader_.captions().front();
}
// number of steps.
int num_records() { return reader_.total_records(); }
std::vector<value_t> data(int step, int index);
std::vector<shape_t> shape(int step, int index);
private:
TabletReader reader_;
};
} // namespace components
} // namespace visualdl
......
......@@ -44,25 +44,39 @@ TEST(Scalar, write) {
TEST(Image, test) {
const auto dir = "./tmp/sdk_test.image";
Writer writer__(dir, 1);
Writer writer__(dir, 4);
auto writer = writer__.AsMode("train");
auto tablet = writer.AddTablet("image0");
components::Image image(tablet, 3);
const int num_steps = 10;
image.StartSampling();
for (int i = 0; i < 100; i++) {
vector<int64_t> shape({3, 5, 5});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
int index = image.IsSampleTaken();
if (index != -1) {
image.SetSample(index, shape, data);
LOG(INFO) << "write images";
image.SetCaption("this is an image");
for (int step = 0; step < num_steps; step++) {
image.StartSampling();
for (int i = 0; i < 7; i++) {
vector<int64_t> shape({3, 5, 5});
vector<float> data;
for (int j = 0; j < 3 * 5 * 5; j++) {
data.push_back(float(rand()) / RAND_MAX);
}
int index = image.IsSampleTaken();
if (index != -1) {
image.SetSample(index, shape, data);
}
}
image.FinishSampling();
}
image.FinishSampling();
LOG(INFO) << "read images";
// read it
Reader reader__(dir);
auto reader = reader__.AsMode("train");
auto tablet2read = reader.tablet("image0");
components::ImageReader image2read(tablet2read);
CHECK_EQ(image2read.caption(), "this is an image");
CHECK_EQ(image2read.num_records(), num_steps);
}
} // namespace visualdl
......@@ -60,6 +60,7 @@ IMPL_ENTRY_GET(bool, b);
}
IMPL_ENTRY_GET_MULTI(int, i32s);
IMPL_ENTRY_GET_MULTI(int64_t, i64s);
IMPL_ENTRY_GET_MULTI(float, fs);
IMPL_ENTRY_GET_MULTI(double, ds);
IMPL_ENTRY_GET_MULTI(std::string, ss);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册