sdk.cc 4.1 KB
Newer Older
S
superjom 已提交
1 2 3 4
#include "visualdl/logic/sdk.h"

#include <functional>
#include <numeric>
#include <type_traits>

namespace visualdl {

S
superjom 已提交
5
namespace components {
S
superjom 已提交
6

S
superjom 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
// Returns the first data value (data slot 0) of every record, in record
// order, read as T.
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
  std::vector<T> values;
  int idx = 0;
  while (idx < reader_.total_records()) {
    values.push_back(reader_.record(idx).data<T>(0).Get());
    ++idx;
  }
  return values;
}

// Returns the id of every record, in record order.
// NOTE(review): each id is implicitly converted to T. If ids are wider than T
// (e.g. 64-bit ids with T = float or int), large ids lose precision — confirm
// callers accept this.
template <typename T>
std::vector<T> ScalarReader<T>::ids() const {
  std::vector<T> out;
  int idx = 0;
  while (idx < reader_.total_records()) {
    out.push_back(reader_.record(idx).id());
    ++idx;
  }
  return out;
}

// Returns the timestamp of every record, in record order.
// NOTE(review): each timestamp is implicitly converted to T. If these are
// epoch seconds (as Image::StartSampling stores via std::time), T = float
// cannot represent them exactly — confirm this is acceptable.
template <typename T>
std::vector<T> ScalarReader<T>::timestamps() const {
  std::vector<T> stamps;
  int idx = 0;
  while (idx < reader_.total_records()) {
    stamps.push_back(reader_.record(idx).timestamp());
    ++idx;
  }
  return stamps;
}

// Returns the first caption attached to this component; aborts via CHECK if
// the component has no caption at all.
template <typename T>
std::string ScalarReader<T>::caption() const {
  CHECK(!reader_.captions().empty()) << "no caption";
  // Non-empty is guaranteed by the CHECK above.
  return *reader_.captions().begin();
}

// Number of records visible through this reader.
template <typename T>
size_t ScalarReader<T>::size() const {
  const size_t count = reader_.total_records();
  return count;
}

// Explicit instantiations: ScalarReader's member templates are defined in
// this .cc file, so the element types supported for other translation units
// are instantiated here.
// Floating-point element types.
template class ScalarReader<float>;
template class ScalarReader<double>;
// Integral element types.
template class ScalarReader<int>;
template class ScalarReader<int64_t>;
S
superjom 已提交
49

S
superjom 已提交
50
void Image::StartSampling() {
S
superjom 已提交
51 52
  if (!ToSampleThisStep()) return;

S
superjom 已提交
53
  step_ = writer_.AddRecord();
S
superjom 已提交
54
  step_.SetId(step_id_);
S
superjom 已提交
55 56 57 58

  time_t time = std::time(nullptr);
  step_.SetTimeStamp(time);

S
superjom 已提交
59 60 61 62
  // resize record
  for (int i = 0; i < num_samples_; i++) {
    step_.AddData<value_t>();
  }
S
superjom 已提交
63 64 65 66
  num_records_ = 0;
}

int Image::IsSampleTaken() {
S
superjom 已提交
67
  if (!ToSampleThisStep()) return -1;
S
superjom 已提交
68 69 70 71 72
  num_records_++;
  if (num_records_ <= num_samples_) {
    return num_records_ - 1;
  }
  float prob = float(num_samples_) / num_records_;
S
superjom 已提交
73 74
  float randv = (float)rand() / RAND_MAX;
  if (randv < prob) {
S
superjom 已提交
75 76 77 78 79 80 81 82
    // take this sample
    int index = rand() % num_samples_;
    return index;
  }
  return -1;
}

void Image::FinishSampling() {
S
superjom 已提交
83 84 85 86 87
  step_id_++;
  if (ToSampleThisStep()) {
    // TODO(ChunweiYan) much optimizement here.
    writer_.parent()->PersistToDisk();
  }
S
superjom 已提交
88 89 90 91 92 93 94 95 96 97 98 99
}

template <typename T, typename U>
struct is_same_type {
  static const bool value = false;
};
template <typename T>
struct is_same_type<T, T> {
  static const bool value = true;
};

void Image::SetSample(int index,
S
superjom 已提交
100
                      const std::vector<shape_t>& shape,
S
superjom 已提交
101 102 103 104
                      const std::vector<value_t>& data) {
  // production
  int size = std::accumulate(
      shape.begin(), shape.end(), 1., [](float a, float b) { return a * b; });
S
superjom 已提交
105
  CHECK_GT(size, 0);
S
superjom 已提交
106 107 108 109
  CHECK_EQ(shape.size(), 3)
      << "shape should be something like (width, height, num_channel)";
  CHECK_LE(shape.back(), 3);
  CHECK_GE(shape.back(), 2);
S
superjom 已提交
110 111 112 113 114
  CHECK_EQ(size, data.size()) << "image's shape not match data";
  CHECK_LT(index, num_samples_);
  CHECK_LE(index, num_records_);

  // set data
S
superjom 已提交
115 116 117 118 119 120 121
  auto entry = step_.MutableData<std::vector<char>>(index);
  // trick to store int8 to protobuf
  std::vector<char> data_str(data.size());
  for (int i = 0; i < data.size(); i++) {
    data_str[i] = data[i];
  }
  entry.Set(data_str);
S
superjom 已提交
122 123

  static_assert(
S
superjom 已提交
124
      !is_same_type<value_t, shape_t>::value,
S
superjom 已提交
125 126 127
      "value_t should not use int64_t field, this type is used to store shape");

  // set meta with hack
S
superjom 已提交
128
  Entry<shape_t> meta;
S
superjom 已提交
129 130 131 132 133
  meta.set_parent(entry.parent());
  meta.entry = entry.entry;
  meta.SetMulti(shape);
}

S
superjom 已提交
134 135 136 137 138 139 140 141 142 143
// Returns this image component's caption in displayable form.
// Exactly one caption is required (CHECK). If the tag matches the reader's
// mode, Reader::GenReadableTag produces the result; otherwise the tag is
// decoded in place with string::TagDecode and returned.
std::string ImageReader::caption() {
  CHECK_EQ(reader_.captions().size(), 1);
  auto tag = reader_.captions().front();
  if (!Reader::TagMatchMode(tag, mode_)) {
    string::TagDecode(tag);
    return tag;
  }
  return Reader::GenReadableTag(mode_, tag);
}

S
superjom 已提交
144 145 146
// Reads image `index` from record `offset`. The result contains:
//  - the pixel bytes, widened to int as unsigned values 0..255;
//  - the shape stored alongside them;
//  - the id of the record (step) they belong to.
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
  ImageRecord out;
  auto rec = reader_.record(offset);
  auto bytes_entry = rec.data<std::vector<char>>(index);
  auto dims_entry = rec.data<shape_t>(index);

  auto bytes = bytes_entry.Get();
  for (char b : bytes) {
    // Bytes were stored as char; reinterpret as unsigned to recover 0..255.
    out.data.push_back(static_cast<int>(static_cast<unsigned char>(b)));
  }

  out.shape = dims_entry.GetMulti();
  out.step_id = rec.id();
  return out;
}

S
superjom 已提交
159 160 161
}  // namespace components

}  // namespace visualdl