// File: gtest_kvdb.cpp -- unit tests for the kvdb module.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "kvdb/rocksdb_impl.h"
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h>

#include <chrono>
#include <cstdio>
#include <fstream>
#include <functional>
#include <string>
#include <thread>
#include <utility>
#include <vector>
class KVDBTest : public ::testing::Test {
protected:
    void SetUp() override{
                
    }
    
    static void SetUpTestCase() {
        kvdb = std::make_shared<RocksKVDB>();
32 33
        dict_reader = std::make_shared<FileReader>();
        param_dict = std::make_shared<ParamDict>();
34 35 36
    }
    
    static AbsKVDBPtr kvdb;
37 38
    static FileReaderPtr dict_reader;
    static ParamDictPtr param_dict;
39 40 41 42
    static ParamDictMgr dict_mgr;

};
AbsKVDBPtr KVDBTest::kvdb;
43 44
FileReaderPtr KVDBTest::dict_reader;
ParamDictPtr KVDBTest::param_dict;
45 46 47 48 49 50 51 52
ParamDictMgr KVDBTest::dict_mgr;

void GenerateTestIn(std::string);
void UpdateTestIn(std::string);

// Round-trips integer keys through the kvdb created in SetUpTestCase():
// every value stored with Set() must come back unchanged from Get().
TEST_F(KVDBTest, AbstractKVDB_Unit_Test) {
    const int kNumKeys = 100;

    kvdb->CreateDB();
    kvdb->SetDBName("test_kvdb");

    for (int i = 0; i < kNumKeys; i++) {
        kvdb->Set(std::to_string(i), std::to_string(i * 2));
    }
    for (int i = 0; i < kNumKeys; i++) {
        std::string val = kvdb->Get(std::to_string(i));
        ASSERT_EQ(val, std::to_string(i * 2)) << "mismatch for key " << i;
    }
}

62
TEST_F(KVDBTest, FileReader_Unit_Test) {
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
    std::string test_in_filename = "abs_dict_reader_test_in.txt";
    GenerateTestIn(test_in_filename);
    dict_reader->SetFileName(test_in_filename);

    std::string md5_1 = dict_reader->GetMD5();
    std::chrono::system_clock::time_point timestamp_1 = dict_reader->GetTimeStamp();

    std::string md5_2 = dict_reader->GetMD5();
    std::chrono::system_clock::time_point timestamp_2 = dict_reader->GetTimeStamp();
    
    ASSERT_EQ(md5_1, md5_2);
    ASSERT_EQ(timestamp_1, timestamp_2);

    UpdateTestIn(test_in_filename);

    std::string md5_3 = dict_reader->GetMD5();
    std::chrono::system_clock::time_point timestamp_3 = dict_reader->GetTimeStamp();
    
    ASSERT_NE(md5_2, md5_3);
    ASSERT_NE(timestamp_2, timestamp_3);   
}
#include <cmath>
85
TEST_F(KVDBTest, ParamDict_Unit_Test) {
86
    std::string test_in_filename = "abs_dict_reader_test_in.txt";
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
    param_dict->SetFileReaderLst({test_in_filename});
    param_dict->SetReader(
    [] (std::string text) {
      auto split = [](const std::string& s,
                 std::vector<std::string>& sv,
                 const char* delim = " ") {
        sv.clear();
        char* buffer = new char[s.size() + 1];
        std::copy(s.begin(), s.end(), buffer);
        char* p = strtok(buffer, delim);
        do {
          sv.push_back(p);
        } while ((p = strtok(NULL, delim)));
        return;
      };
      std::vector<std::string> text_split;
      split(text, text_split, " ");
      std::string key = text_split[0];
      text_split.erase(text_split.begin());
      return make_pair(key, text_split);            
    });
108 109 110 111 112 113 114
    param_dict->CreateKVDB();
    GenerateTestIn(test_in_filename);

    param_dict->UpdateBaseModel();

    std::this_thread::sleep_for(std::chrono::seconds(2));
    
115
    std::vector<float> test_vec = param_dict->GetSparseValue("1", "");
116 117 118 119 120 121 122 123 124

    ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2);

    UpdateTestIn(test_in_filename);
    param_dict->UpdateDeltaModel();
}

namespace {

// Recreates `filename` with `num_lines` lines. Line i is
//   "<i> <i> <v> <v> <v> "   where v = i + value_offset.
// Any existing file is deleted first, so the content always lands in a
// freshly created file. std::remove does this without spawning a shell on an
// unquoted filename; a missing file is not an error here.
void WriteTestIn(const std::string& filename, size_t num_lines, size_t value_offset) {
    std::remove(filename.c_str());
    std::ofstream out_file(filename);
    for (size_t i = 0; i < num_lines; i++) {
        out_file << i << " " << i << " ";
        for (size_t j = 0; j < 3; j++) {
            out_file << i + value_offset << " ";
        }
        // '\n' rather than std::endl: one flush at close instead of one per line.
        out_file << '\n';
    }
}

}  // namespace

// Writes the base test input: 100000 lines, all values on line i equal to i.
void GenerateTestIn(std::string filename) {
    WriteTestIn(filename, 100000, 0);
}

// Writes the updated test input: 10000 lines, the last three values on line i
// equal to i + 1, so both the content and the MD5 differ from GenerateTestIn.
void UpdateTestIn(std::string filename) {
    WriteTestIn(filename, 10000, 1);
}

int main(int argc, char** argv) {
     ::testing::InitGoogleTest(&argc, argv);
     return RUN_ALL_TESTS();
}