#include "visualdl/logic/sdk.h"

namespace visualdl {

namespace components {

S
superjom 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
template <typename T>
std::vector<T> ScalarReader<T>::records() const {
  std::vector<T> res;
  for (int i = 0; i < reader_.total_records(); i++) {
    res.push_back(reader_.record(i).data<T>(0).Get());
  }
  return res;
}

template <typename T>
std::vector<T> ScalarReader<T>::ids() const {
  // Id of every record, in record order; each id is converted to T as it is
  // appended.
  std::vector<T> step_ids;
  for (int idx = 0; idx < reader_.total_records(); ++idx) {
    auto rec = reader_.record(idx);
    step_ids.push_back(rec.id());
  }
  return step_ids;
}

template <typename T>
std::vector<T> ScalarReader<T>::timestamps() const {
  // Timestamp of every record, in record order; each value is converted to T
  // as it is appended.
  std::vector<T> stamps;
  for (int idx = 0; idx < reader_.total_records(); ++idx) {
    auto rec = reader_.record(idx);
    stamps.push_back(rec.timestamp());
  }
  return stamps;
}

template <typename T>
std::string ScalarReader<T>::caption() const {
  // Report the first caption only; abort via CHECK when there is none.
  const auto& captions = reader_.captions();
  CHECK(!captions.empty()) << "no caption";
  return captions.front();
}

template <typename T>
size_t ScalarReader<T>::size() const {
  // Number of records exposed by the underlying reader.
  const size_t count = reader_.total_records();
  return count;
}

// Explicit instantiations. The ScalarReader<T> member definitions live in
// this translation unit, so only the element types listed here can be used
// (and linked against) from other files.
template class ScalarReader<int>;
template class ScalarReader<int64_t>;
template class ScalarReader<float>;
template class ScalarReader<double>;
S
superjom 已提交
49

S
superjom 已提交
50
void Image::StartSampling() {
S
superjom 已提交
51 52 53
  // TODO(ChunweiYan) big bug here, every step will be stored in protobuf
  // and that might result in explosion in some scenerios, Just sampling
  // some steps should be better.
S
superjom 已提交
54
  step_ = writer_.AddRecord();
S
superjom 已提交
55
  step_.SetId(step_id_);
S
superjom 已提交
56 57 58 59

  time_t time = std::time(nullptr);
  step_.SetTimeStamp(time);

S
superjom 已提交
60 61 62 63
  // resize record
  for (int i = 0; i < num_samples_; i++) {
    step_.AddData<value_t>();
  }
S
superjom 已提交
64 65 66 67
  num_records_ = 0;
}

int Image::IsSampleTaken() {
S
superjom 已提交
68
  if (!ToSampleThisStep()) return -1;
S
superjom 已提交
69 70 71 72 73
  num_records_++;
  if (num_records_ <= num_samples_) {
    return num_records_ - 1;
  }
  float prob = float(num_samples_) / num_records_;
S
superjom 已提交
74 75
  float randv = (float)rand() / RAND_MAX;
  if (randv < prob) {
S
superjom 已提交
76 77 78 79 80 81 82 83
    // take this sample
    int index = rand() % num_samples_;
    return index;
  }
  return -1;
}

void Image::FinishSampling() {
S
superjom 已提交
84 85 86 87 88
  step_id_++;
  if (ToSampleThisStep()) {
    // TODO(ChunweiYan) much optimizement here.
    writer_.parent()->PersistToDisk();
  }
S
superjom 已提交
89 90 91 92 93 94 95 96 97 98 99 100
}

// Compile-time type-equality trait: `value` is true only when both template
// arguments name exactly the same type (cv-qualifiers included).
template <typename Lhs, typename Rhs>
struct is_same_type {
  static constexpr bool value = false;
};

// Partial specialization chosen when both arguments are identical.
template <typename Same>
struct is_same_type<Same, Same> {
  static constexpr bool value = true;
};

void Image::SetSample(int index,
S
superjom 已提交
101
                      const std::vector<shape_t>& shape,
S
superjom 已提交
102 103 104 105 106 107 108 109 110
                      const std::vector<value_t>& data) {
  // production
  int size = std::accumulate(
      shape.begin(), shape.end(), 1., [](float a, float b) { return a * b; });
  CHECK_EQ(size, data.size()) << "image's shape not match data";
  CHECK_LT(index, num_samples_);
  CHECK_LE(index, num_records_);

  // set data
S
superjom 已提交
111 112 113 114 115 116 117
  auto entry = step_.MutableData<std::vector<char>>(index);
  // trick to store int8 to protobuf
  std::vector<char> data_str(data.size());
  for (int i = 0; i < data.size(); i++) {
    data_str[i] = data[i];
  }
  entry.Set(data_str);
S
superjom 已提交
118 119

  static_assert(
S
superjom 已提交
120
      !is_same_type<value_t, shape_t>::value,
S
superjom 已提交
121 122 123
      "value_t should not use int64_t field, this type is used to store shape");

  // set meta with hack
S
superjom 已提交
124
  Entry<shape_t> meta;
S
superjom 已提交
125 126 127 128 129
  meta.set_parent(entry.parent());
  meta.entry = entry.entry;
  meta.SetMulti(shape);
}

S
superjom 已提交
130 131 132 133 134 135 136 137 138 139
std::string ImageReader::caption() {
  CHECK_EQ(reader_.captions().size(), 1);
  auto caption = reader_.captions().front();
  if (Reader::TagMatchMode(caption, mode_)) {
    return Reader::GenReadableTag(mode_, caption);
  }
  string::TagDecode(caption);
  return caption;
}

S
superjom 已提交
140 141 142
ImageReader::ImageRecord ImageReader::record(int offset, int index) {
  ImageRecord res;
  auto record = reader_.record(offset);
S
superjom 已提交
143
  auto data_entry = record.data<std::vector<char>>(index);
S
superjom 已提交
144
  auto shape_entry = record.data<shape_t>(index);
S
superjom 已提交
145 146 147 148 149
  auto data_str = data_entry.Get();
  std::transform(data_str.begin(),
                 data_str.end(),
                 std::back_inserter(res.data),
                 [](char i) { return (int)((unsigned char)i); });
S
superjom 已提交
150 151 152
  res.shape = shape_entry.GetMulti();
  res.step_id = record.id();
  return res;
S
superjom 已提交
153 154
}

}  // namespace components

}  // namespace visualdl