sdk.cc 4.1 KB
Newer Older
S
superjom 已提交
1 2
#include "visualdl/logic/sdk.h"

#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <iterator>
#include <numeric>
#include <string>
#include <type_traits>
#include <vector>

#include "visualdl/utils/image.h"

S
superjom 已提交
5 6
namespace visualdl {

S
superjom 已提交
7
namespace components {
S
superjom 已提交
8

S
superjom 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
// Returns the scalar value (data entry 0) of every record, in record order.
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
  // Query the record count once and reserve so the vector allocates once.
  const auto total = reader_.total_records();
  std::vector<T> res;
  res.reserve(total);
  for (int i = 0; i < total; i++) {
    res.push_back(reader_.record(i).data<T>(0).Get());
  }
  return res;
}

// Returns the id of every record, in record order, converted to T.
// NOTE(review): ScalarReader is instantiated for float/double too, where very
// large ids would lose precision in this conversion — confirm acceptable.
template <typename T>
std::vector<T> ScalarReader<T>::ids() const {
  // Query the record count once and reserve so the vector allocates once.
  const auto total = reader_.total_records();
  std::vector<T> res;
  res.reserve(total);
  for (int i = 0; i < total; i++) {
    res.push_back(reader_.record(i).id());
  }
  return res;
}

// Returns the timestamp of every record, in record order, converted to T.
template <typename T>
std::vector<T> ScalarReader<T>::timestamps() const {
  // Query the record count once and reserve so the vector allocates once.
  const auto total = reader_.total_records();
  std::vector<T> res;
  res.reserve(total);
  for (int i = 0; i < total; i++) {
    res.push_back(reader_.record(i).timestamp());
  }
  return res;
}

// Returns the first caption of the underlying reader; aborts via CHECK when
// the reader has no caption at all.
template <typename T>
std::string ScalarReader<T>::caption() const {
  const auto& all_captions = reader_.captions();
  CHECK(!all_captions.empty()) << "no caption";
  return all_captions.front();
}

// Number of records held by the underlying reader.
template <typename T>
size_t ScalarReader<T>::size() const {
  return static_cast<size_t>(reader_.total_records());
}

// Explicitly instantiate ScalarReader so the member definitions above are
// emitted in this translation unit for these value types.
template class ScalarReader<float>;
template class ScalarReader<double>;
template class ScalarReader<int>;
template class ScalarReader<int64_t>;
S
superjom 已提交
51

S
superjom 已提交
52
void Image::StartSampling() {
S
superjom 已提交
53 54
  if (!ToSampleThisStep()) return;

S
superjom 已提交
55
  step_ = writer_.AddRecord();
S
superjom 已提交
56
  step_.SetId(step_id_);
S
superjom 已提交
57 58 59 60

  time_t time = std::time(nullptr);
  step_.SetTimeStamp(time);

S
superjom 已提交
61 62 63 64
  // resize record
  for (int i = 0; i < num_samples_; i++) {
    step_.AddData<value_t>();
  }
S
superjom 已提交
65 66 67 68
  num_records_ = 0;
}

int Image::IsSampleTaken() {
S
superjom 已提交
69
  if (!ToSampleThisStep()) return -1;
S
superjom 已提交
70 71 72 73 74
  num_records_++;
  if (num_records_ <= num_samples_) {
    return num_records_ - 1;
  }
  float prob = float(num_samples_) / num_records_;
S
superjom 已提交
75 76
  float randv = (float)rand() / RAND_MAX;
  if (randv < prob) {
S
superjom 已提交
77 78 79 80 81 82 83 84
    // take this sample
    int index = rand() % num_samples_;
    return index;
  }
  return -1;
}

void Image::FinishSampling() {
S
superjom 已提交
85 86 87 88 89
  step_id_++;
  if (ToSampleThisStep()) {
    // TODO(ChunweiYan) much optimizement here.
    writer_.parent()->PersistToDisk();
  }
S
superjom 已提交
90 91 92 93 94 95 96 97 98 99 100 101
}

// Compile-time type-equality trait: is_same_type<T, U>::value is true iff T
// and U are the same type. Kept as a thin wrapper over std::is_same so existing
// `is_same_type<...>::value` uses (the static_assert in Image::SetSample) keep
// compiling.
template <typename T, typename U>
struct is_same_type : std::is_same<T, U> {};

void Image::SetSample(int index,
S
superjom 已提交
102
                      const std::vector<shape_t>& shape,
S
superjom 已提交
103 104 105
                      const std::vector<value_t>& data) {
  // production
  int size = std::accumulate(
106
      shape.begin(), shape.end(), 1., [](int a, int b) { return a * b; });
S
superjom 已提交
107
  CHECK_GT(size, 0);
S
superjom 已提交
108 109 110 111
  CHECK_EQ(shape.size(), 3)
      << "shape should be something like (width, height, num_channel)";
  CHECK_LE(shape.back(), 3);
  CHECK_GE(shape.back(), 2);
S
superjom 已提交
112 113 114 115
  CHECK_EQ(size, data.size()) << "image's shape not match data";
  CHECK_LT(index, num_samples_);
  CHECK_LE(index, num_records_);

S
superjom 已提交
116
  auto entry = step_.MutableData<std::vector<byte_t>>(index);
S
superjom 已提交
117
  // trick to store int8 to protobuf
S
superjom 已提交
118
  std::vector<byte_t> data_str(data.size());
S
superjom 已提交
119 120 121
  for (int i = 0; i < data.size(); i++) {
    data_str[i] = data[i];
  }
S
superjom 已提交
122
  entry.SetRaw(std::string(data_str.begin(), data_str.end()));
S
superjom 已提交
123 124

  static_assert(
S
superjom 已提交
125
      !is_same_type<value_t, shape_t>::value,
S
superjom 已提交
126 127 128
      "value_t should not use int64_t field, this type is used to store shape");

  // set meta with hack
S
superjom 已提交
129
  Entry<shape_t> meta;
S
superjom 已提交
130 131 132 133 134
  meta.set_parent(entry.parent());
  meta.entry = entry.entry;
  meta.SetMulti(shape);
}

S
superjom 已提交
135 136 137 138 139 140 141 142 143 144
// Returns this reader's single caption in readable form. If the caption
// matches mode_, it is rendered with Reader::GenReadableTag; otherwise it is
// passed through string::TagDecode and returned. Aborts via CHECK unless
// there is exactly one caption.
std::string ImageReader::caption() {
  const auto& all_captions = reader_.captions();
  CHECK_EQ(all_captions.size(), 1);
  auto tag = all_captions.front();
  if (!Reader::TagMatchMode(tag, mode_)) {
    // NOTE(review): TagDecode's return value is discarded, so presumably it
    // decodes `tag` in place — confirm against its declaration.
    string::TagDecode(tag);
    return tag;
  }
  return Reader::GenReadableTag(mode_, tag);
}

S
superjom 已提交
145 146 147
// Reads the image stored in data slot `index` of the `offset`-th record.
// Returns its raw bytes (each byte_t widened to int), its shape, and the id
// of the record it came from.
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
  auto rec = reader_.record(offset);
  auto raw_entry = rec.data<std::vector<byte_t>>(index);
  auto dims_entry = rec.data<shape_t>(index);
  const auto raw = raw_entry.GetRaw();

  ImageRecord out;
  for (const byte_t b : raw) {
    out.data.push_back(static_cast<int>(b));
  }
  out.shape = dims_entry.GetMulti();
  out.step_id = rec.id();
  return out;
}

S
superjom 已提交
160 161 162
}  // namespace components

}  // namespace visualdl