mock_param_dict_impl.cpp 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kvdb/mock_kvdb_impl.h"

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <fstream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
std::string MockDictReader::GetFileName() {
    return this->filename_;
}

void MockDictReader::SetFileName(std::string filename) {
    this->filename_ = filename;
    this->last_md5_val_ = this->GetMD5();
    this->time_stamp_ = std::chrono::system_clock::now();
}

std::string MockDictReader::GetMD5() {
   auto getCmdOut = [] (std::string cmd) {
        std::string data;
33
        FILE *stream = nullptr;
34 35 36 37
        const int max_buffer = 256;
        char buffer[max_buffer];
        cmd.append(" 2>&1");
        stream = popen(cmd.c_str(), "r");
W
wangjiawei04 已提交
38 39
        if (stream) {
            if (fgets(buffer, max_buffer, stream) != NULL) {
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
                data.append(buffer);
            }
        }
        return data;
   }; 
    std::string cmd = "md5sum " + this->filename_;
//TODO: throw exception if error occurs during execution of shell command
    std::string md5val = getCmdOut(cmd);
    this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now();
    this->last_md5_val_ = md5val;
    return md5val;
}

bool MockDictReader::CheckDiff() {
    return this->GetMD5() == this->last_md5_val_;
}

std::chrono::system_clock::time_point MockDictReader::GetTimeStamp() {
//TODO: Implement Get Time Stamp of dict file
    return this->time_stamp_;  
}

void MockDictReader::Read(std::vector<std::string>& res) {
    std::string line;
    std::ifstream infile(this->filename_);
W
wangjiawei04 已提交
65 66
    if (infile.is_open()) {
        while (getline(infile, line)) {
67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
            res.push_back(line);
        }
    }
    infile.close();
}

// TODO: presumably nothing needs to be released here; members clean up themselves.
MockDictReader::~MockDictReader() = default;


std::vector<AbsDictReaderPtr> MockParamDict::GetDictReaderLst() {
    return this->dict_reader_lst_;
}

void MockParamDict::SetDictReaderLst(std::vector<AbsDictReaderPtr> lst) {
    this->dict_reader_lst_ = lst;
}

/// Looks up the float vector stored in front_db under the key `feasign + slot`.
///
/// The stored value is treated as a packed array of native-endian floats.
/// Bytes are copied with std::memcpy rather than dereferenced through a cast
/// float pointer, which avoids unaligned/aliasing-violating reads, and only
/// whole floats are decoded so a value whose length is not a multiple of
/// sizeof(float) can no longer cause a read past the end of the buffer.
///
/// @param feasign Feature-sign part of the key.
/// @param slot    Slot part of the key.
/// @return The decoded floats, or an empty vector when the store answers with
///         the literal "NOT_FOUND".
std::vector<float> MockParamDict::GetSparseValue(std::string feasign, std::string slot) {
    // TODO: the concatenation of feasign and slot is TBD.
    const std::string result = front_db->Get(feasign + slot);
    std::vector<float> value;
    if (result == "NOT_FOUND") {
        return value;
    }

    // Trailing bytes that do not form a complete float are ignored.
    const size_t float_count = result.size() / sizeof(float);
    value.resize(float_count);
    if (float_count > 0) {
        std::memcpy(value.data(), result.data(), float_count * sizeof(float));
    }
    return value;
}

std::vector<float> MockParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
    return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
}

bool MockParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) {
    return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values);       
}

/// Stores `values` in back_db under the key `feasign + slot`.
///
/// The floats are serialized as their raw native-endian bytes, packed back to
/// back. The bytes are copied straight into the std::string that is handed to
/// the store; the previous implementation staged them in a `new[]` buffer that
/// was never freed, leaking values.size() * 4 bytes on every call.
///
/// @param feasign Feature-sign part of the key.
/// @param slot    Slot part of the key.
/// @param values  Floats to store; may be empty, in which case an empty value
///                is written.
/// @return Always true.
bool MockParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) {
    const std::string key = feasign + slot;

    std::string value(values.size() * sizeof(float), '\0');
    if (!values.empty()) {
        // memcpy is the well-defined way to view the floats as bytes.
        std::memcpy(&value[0], values.data(), value.size());
    }

    back_db->Set(key, value);
    // TODO: change stateless to stateful
    return true;
}

void MockParamDict::UpdateBaseModel() {
   std::thread t([&] () {
W
wangjiawei04 已提交
141 142
        for (AbsDictReaderPtr dict_reader: this->dict_reader_lst_) {
            if (dict_reader->CheckDiff()) {
143 144
                std::vector<std::string> strs;
                dict_reader->Read(strs);
W
wangjiawei04 已提交
145
                for (const std::string& str: strs) {
146 147 148 149
                    std::vector<std::string> arr;
                    std::istringstream in(str);
                    copy(std::istream_iterator<std::string>(in), std::istream_iterator<std::string>(), back_inserter(arr));
                    std::vector<float> nums;
W
wangjiawei04 已提交
150
                    for (size_t i = 2; i < arr.size(); i++) {
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
                        nums.push_back(std::stof(arr[i]));
                    }
                    this->InsertSparseValue(arr[0], arr[1], nums);
                }
            }
        }
        AbsKVDBPtr temp = front_db;
        front_db = back_db;
        back_db = temp;
   });
   t.detach();
}


void MockParamDict::UpdateDeltaModel() {
    UpdateBaseModel();
}

std::pair<AbsKVDBPtr, AbsKVDBPtr> MockParamDict::GetKVDB()  {
    return {front_db, back_db};
}

void MockParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
    this->front_db = kvdbs.first;
    this->back_db = kvdbs.second;
}

void MockParamDict::CreateKVDB() {
    this->front_db = std::make_shared<RocksKVDB>();
    this->back_db = std::make_shared<RocksKVDB>();
    this->front_db->CreateDB();
    this->back_db->CreateDB();
}

// No explicit teardown is performed here; members are destroyed as usual.
MockParamDict::~MockParamDict() = default;