// File: mock_param_dict_impl.cpp
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kvdb/rocksdb_impl.h"

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iterator>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

22 23
std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() {
    return this->file_reader_lst_;
24 25
}

26 27 28 29 30
void ParamDict::SetFileReaderLst(std::vector<std::string> lst) {
    for (size_t i = 0; i < lst.size(); i++) {
        FileReaderPtr fr = std::make_shared<FileReader>();
        fr->SetFileName(lst[i]);
        this->file_reader_lst_.push_back(fr);
31 32 33
    }
}

34
std::vector<float> ParamDict::GetSparseValue(std::string feasign, std::string slot) {
35 36 37 38 39 40
    auto BytesToFloat = [](uint8_t* byteArray){
        return *((float*)byteArray);
    };
    //TODO: the concatation of feasign and slot is TBD.
    std::string result = front_db->Get(feasign + slot);
    std::vector<float> value;
W
wangjiawei04 已提交
41
    if (result == "NOT_FOUND") 
42 43
        return value;
    uint8_t* raw_values_ptr = reinterpret_cast<uint8_t *>(&result[0]);
W
wangjiawei04 已提交
44
    for (size_t i = 0; i < result.size(); i += 4) {
45 46 47 48 49 50
        float temp = BytesToFloat(raw_values_ptr + i);
        value.push_back(temp);
    }
    return value;
}

51 52 53 54 55
void ParamDict::SetReader(std::function<std::pair<Key, Value>(std::string)> func) {
    read_func_ = func;
}

std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
56 57 58
    return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
}

59
bool ParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) {
60 61 62
    return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values);       
}

63
bool ParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) {
64
    auto FloatToBytes = [](float fvalue, uint8_t *arr){
65 66 67
        unsigned char  *pf = nullptr;
        unsigned char *px = nullptr;
        unsigned char i = 0;
68 69
        pf =(unsigned char *)&fvalue;
        px = arr;
70
        for (i = 0; i < 4; i++)
71 72 73 74 75 76 77 78
        {
            *(px+i)=*(pf+i);
        }
    };

    std::string key = feasign + slot;
    uint8_t* values_ptr = new uint8_t[values.size() * 4];
    std::string value;
W
wangjiawei04 已提交
79
    for (size_t i = 0; i < values.size(); i++) {
80 81 82
        FloatToBytes(values[i], values_ptr + 4 * i);
    }
    char* raw_values_ptr = reinterpret_cast<char*>(values_ptr);
83
    for (size_t i = 0; i < values.size()*4; i++) {
84 85 86 87 88 89 90
        value.push_back(raw_values_ptr[i]);
    }
    back_db->Set(key, value);
//TODO: change stateless to stateful
    return true;
}

91 92 93 94 95 96
void ParamDict::UpdateBaseModel() {
    auto is_number = [] (const std::string& s)
    {
        return !s.empty() && std::find_if(s.begin(), 
                        s.end(), [](char c) { return !std::isdigit(c); }) == s.end();
    };
97
   std::thread t([&] () {
98 99 100 101 102 103
        for (FileReaderPtr file_reader: this->file_reader_lst_) {
            std::string line;
            std::ifstream infile(file_reader->GetFileName());
            if (infile.is_open()) {
                while (getline(infile, line)) {
                    std::pair<Key, Value> kvpair = read_func_(line);
104
                    std::vector<float> nums;
105 106 107 108
                    for (size_t i = 0; i < kvpair.second.size(); i++) {
                        if (is_number(kvpair.second[i])) {
                            nums.push_back(std::stof(kvpair.second[i]));
                        }
109
                    }
110
                    this->InsertSparseValue(kvpair.first, "", nums);
111 112
                }
            }
113
            infile.close();
114 115 116 117 118 119 120 121 122
        }
        AbsKVDBPtr temp = front_db;
        front_db = back_db;
        back_db = temp;
   });
   t.detach();
}


123
void ParamDict::UpdateDeltaModel() {
124 125 126
    UpdateBaseModel();
}

127
std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB()  {
128 129 130
    return {front_db, back_db};
}

131
void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
132 133 134 135
    this->front_db = kvdbs.first;
    this->back_db = kvdbs.second;
}

136
void ParamDict::CreateKVDB() {
137 138 139 140 141 142
    this->front_db = std::make_shared<RocksKVDB>();
    this->back_db = std::make_shared<RocksKVDB>();
    this->front_db->CreateDB();
    this->back_db->CreateDB();
}

143
ParamDict::~ParamDict() {
144 145 146 147 148 149 150

}