mock_param_dict_impl.cpp 4.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <algorithm>
#include <cctype>
#include <cstring>
#include <fstream>
#include <iterator>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "kvdb/rocksdb_impl.h"
21

22
std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() {
23
  return this->file_reader_lst_;
24 25
}

26
void ParamDict::SetFileReaderLst(std::vector<std::string> lst) {
27 28 29 30 31
  for (size_t i = 0; i < lst.size(); i++) {
    FileReaderPtr fr = std::make_shared<FileReader>();
    fr->SetFileName(lst[i]);
    this->file_reader_lst_.push_back(fr);
  }
32 33
}

34 35 36 37 38 39 40 41 42 43 44 45 46
// Looks up the value stored in front_db under the key `feasign + slot` and
// decodes it as a packed array of floats in native byte order (the same raw
// layout InsertSparseValue writes).
//
// Returns an empty vector when the database answers with the "NOT_FOUND"
// sentinel. Trailing bytes that do not make up a whole float are ignored
// instead of being read past the end of the buffer.
std::vector<float> ParamDict::GetSparseValue(std::string feasign,
                                             std::string slot) {
  // TODO: the concatenation of feasign and slot is TBD.
  const std::string result = front_db->Get(feasign + slot);
  std::vector<float> value;
  if (result == "NOT_FOUND") return value;

  // Only whole floats are decoded; memcpy avoids the unaligned,
  // aliasing-violating read that a float* cast on string storage would be.
  const size_t float_count = result.size() / sizeof(float);
  value.resize(float_count);
  if (float_count > 0) {
    std::memcpy(value.data(), result.data(), float_count * sizeof(float));
  }
  return value;
}

49 50 51
// Installs the callback stored in read_func_. UpdateBaseModel invokes it once
// per input line to turn that line into a (Key, Value) pair.
void ParamDict::SetReader(
    std::function<std::pair<Key, Value>(std::string)> func) {
  this->read_func_ = func;
}

std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
55
  return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
56 57
}

58 59 60 61 62
bool ParamDict::InsertSparseValue(int64_t feasign,
                                  int64_t slot,
                                  const std::vector<float>& values) {
  return this->InsertSparseValue(
      std::to_string(feasign), std::to_string(slot), values);
63 64
}

65 66 67 68 69 70 71 72 73 74 75
// Serializes `values` as a packed array of floats in native byte order and
// stores it in back_db under the key `feasign + slot`.
//
// An empty `values` vector is stored as an empty string. Always returns true;
// the result of the underlying Set call is not inspected.
bool ParamDict::InsertSparseValue(std::string feasign,
                                  std::string slot,
                                  const std::vector<float>& values) {
  std::string key = feasign + slot;

  // Copy the raw float bytes straight into the string. This replaces a
  // heap-allocated scratch buffer that was never freed.
  std::string value;
  if (!values.empty()) {
    value.assign(reinterpret_cast<const char*>(values.data()),
                 values.size() * sizeof(float));
  }

  back_db->Set(key, value);
  // TODO: change stateless to stateful
  return true;
}

94
void ParamDict::UpdateBaseModel() {
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
  auto is_number = [](const std::string& s) {
    return !s.empty() && std::find_if(s.begin(), s.end(), [](char c) {
                           return !std::isdigit(c);
                         }) == s.end();
  };
  std::thread t([&]() {
    for (FileReaderPtr file_reader : this->file_reader_lst_) {
      std::string line;
      std::ifstream infile(file_reader->GetFileName());
      if (infile.is_open()) {
        while (getline(infile, line)) {
          std::pair<Key, Value> kvpair = read_func_(line);
          std::vector<float> nums;
          for (size_t i = 0; i < kvpair.second.size(); i++) {
            if (is_number(kvpair.second[i])) {
              nums.push_back(std::stof(kvpair.second[i]));
111
            }
112 113
          }
          this->InsertSparseValue(kvpair.first, "", nums);
114
        }
115 116 117 118 119 120 121 122
      }
      infile.close();
    }
    AbsKVDBPtr temp = front_db;
    front_db = back_db;
    back_db = temp;
  });
  t.detach();
123 124
}

125
void ParamDict::UpdateDeltaModel() { UpdateBaseModel(); }
126

127 128
std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB() {
  return {front_db, back_db};
129 130
}

131
void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
132 133
  this->front_db = kvdbs.first;
  this->back_db = kvdbs.second;
134 135
}

136
void ParamDict::CreateKVDB() {
137 138 139 140
  this->front_db = std::make_shared<RocksKVDB>();
  this->back_db = std::make_shared<RocksKVDB>();
  this->front_db->CreateDB();
  this->back_db->CreateDB();
141 142
}

143 144
// Nothing to release explicitly; members clean up through their own
// destructors.
ParamDict::~ParamDict() = default;