diff --git a/kvdb/include/kvdb/kvdb_impl.h b/kvdb/include/kvdb/kvdb_impl.h index 215cd8e07762e99688f02eed0bd5ffc3a18a0f3e..b610120afc15314a46134a88db1c5979ba6b1287 100644 --- a/kvdb/include/kvdb/kvdb_impl.h +++ b/kvdb/include/kvdb/kvdb_impl.h @@ -19,12 +19,12 @@ #include #include class AbstractKVDB; -class AbstractDictReader; -class AbstractParamDict; +class FileReader; +class ParamDict; typedef std::shared_ptr AbsKVDBPtr; -typedef std::shared_ptr AbsDictReaderPtr; -typedef std::shared_ptr AbsParamDictPtr; +typedef std::shared_ptr FileReaderPtr; +typedef std::shared_ptr ParamDictPtr; class AbstractKVDB { public: @@ -38,15 +38,50 @@ public: // TODO: Implement RedisKVDB //class RedisKVDB; -class AbstractDictReader { +class FileReader { public: - virtual std::string GetFileName() = 0; - virtual void SetFileName(std::string) = 0; - virtual std::string GetMD5() = 0; - virtual bool CheckDiff() = 0; - virtual std::chrono::system_clock::time_point GetTimeStamp() = 0; - virtual void Read(AbstractParamDict*) = 0; - virtual ~AbstractDictReader() = 0; + inline virtual std::string GetFileName() { + return this->filename_; + } + + inline virtual void SetFileName(std::string filename) { + this->filename_ = filename; + this->last_md5_val_ = this->GetMD5(); + this->time_stamp_ = std::chrono::system_clock::now(); + } + + inline virtual std::string GetMD5() { + auto getCmdOut = [] (std::string cmd) { + std::string data; + FILE *stream = nullptr; + const int max_buffer = 256; + char buffer[max_buffer]; + cmd.append(" 2>&1"); + stream = popen(cmd.c_str(), "r"); + if (stream) { + if (fgets(buffer, max_buffer, stream) != NULL) { + data.append(buffer); + } + } + return data; + }; + std::string cmd = "md5sum " + this->filename_; + //TODO: throw exception if error occurs during execution of shell command + std::string md5val = getCmdOut(cmd); + this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now(); + this->last_md5_val_ = md5val; + return md5val; + } + + inline virtual bool CheckDiff() { + return this->GetMD5() == this->last_md5_val_; + } + + inline virtual std::chrono::system_clock::time_point GetTimeStamp() { + return this->time_stamp_; + } + + inline virtual ~FileReader() {}; protected: std::string filename_; std::string last_md5_val_; @@ -54,27 +89,31 @@ protected: }; -class AbstractParamDict { +class ParamDict { + typedef std::string Key; + typedef std::vector Value; public: - virtual std::vector GetDictReaderLst() = 0; - virtual void SetDictReaderLst(std::vector) = 0; + virtual std::vector GetDictReaderLst(); + virtual void SetFileReaderLst(std::vector lst); - virtual std::vector GetSparseValue(int64_t, int64_t) = 0; - virtual std::vector GetSparseValue(std::string, std::string) = 0; + virtual std::vector GetSparseValue(int64_t, int64_t); + virtual std::vector GetSparseValue(std::string, std::string); - virtual bool InsertSparseValue(int64_t, int64_t, const std::vector&) = 0; - virtual bool InsertSparseValue(std::string, std::string, const std::vector&) = 0; - - virtual void UpdateBaseModel() = 0; - virtual void UpdateDeltaModel() = 0; + virtual bool InsertSparseValue(int64_t, int64_t, const std::vector&); + virtual bool InsertSparseValue(std::string, std::string, const std::vector&); - virtual std::pair GetKVDB() = 0; - virtual void SetKVDB(std::pair) = 0; - virtual void CreateKVDB() = 0; + virtual void SetReader(std::function(std::string)>); + virtual void UpdateBaseModel(); + virtual void UpdateDeltaModel(); - virtual ~AbstractParamDict() = 0; + virtual std::pair GetKVDB(); + virtual void SetKVDB(std::pair); + virtual void CreateKVDB(); + + virtual ~ParamDict(); protected: - std::vector dict_reader_lst_; + std::function(std::string)> read_func_; + std::vector file_reader_lst_; AbsKVDBPtr front_db, back_db; }; @@ -83,9 +122,9 @@ protected: class ParamDictMgr { public: void UpdateAll(); - void InsertParamDict(std::string, AbsParamDictPtr); + void InsertParamDict(std::string, ParamDictPtr); protected: - std::unordered_map ParamDictMap; + std::unordered_map ParamDictMap; }; diff --git a/kvdb/include/kvdb/rocksdb_impl.h b/kvdb/include/kvdb/rocksdb_impl.h index 890c5f86cab6c41b916359fbecbe36cf800e4d75..71adb23a5175d4b4542f71947f415b9313db74f9 100644 --- a/kvdb/include/kvdb/rocksdb_impl.h +++ b/kvdb/include/kvdb/rocksdb_impl.h @@ -30,36 +30,6 @@ public: static int db_count; }; -class RocksDBDictReader : public AbstractDictReader{ -public: - std::string GetFileName(); - void SetFileName(std::string); - std::string GetMD5(); - bool CheckDiff(); - std::chrono::system_clock::time_point GetTimeStamp(); - void Read(AbstractParamDict*); - ~RocksDBDictReader(); -}; - -class RocksDBParamDict : public AbstractParamDict{ -public: - std::vector GetDictReaderLst(); - void SetDictReaderLst(std::vector); - - std::vector GetSparseValue(int64_t, int64_t); - std::vector GetSparseValue(std::string, std::string); - bool InsertSparseValue(int64_t, int64_t, const std::vector&); - bool InsertSparseValue(std::string, std::string, const std::vector&); - - void UpdateBaseModel(); - void UpdateDeltaModel(); - - std::pair GetKVDB(); - void SetKVDB(std::pair); - void CreateKVDB(); - - ~RocksDBParamDict(); -}; diff --git a/kvdb/src/.gtest_kvdb.cpp.swp b/kvdb/src/.gtest_kvdb.cpp.swp deleted file mode 100644 index 22618177ce8d4d22c0a46df11f667b6e6e7ae5ea..0000000000000000000000000000000000000000 Binary files a/kvdb/src/.gtest_kvdb.cpp.swp and /dev/null differ diff --git a/kvdb/src/gtest_kvdb.cpp b/kvdb/src/gtest_kvdb.cpp index b0070a82f958b1d5fa87b3c3563c882f85719501..4e7d0aca5ae778a17ed5a4040f6424ee96579c47 100644 --- a/kvdb/src/gtest_kvdb.cpp +++ b/kvdb/src/gtest_kvdb.cpp @@ -16,6 +16,7 @@ #include "kvdb/kvdb_impl.h" #include "kvdb/paddle_rocksdb.h" #include +#include #include #include #include @@ -28,19 +29,19 @@ protected: static void SetUpTestCase() { kvdb = std::make_shared(); - dict_reader = std::make_shared(); - param_dict = std::make_shared(); + dict_reader = std::make_shared(); + param_dict = std::make_shared(); } static AbsKVDBPtr kvdb; - static AbsDictReaderPtr dict_reader; - static AbsParamDictPtr param_dict; + static FileReaderPtr dict_reader; + static ParamDictPtr param_dict; static ParamDictMgr dict_mgr; }; AbsKVDBPtr KVDBTest::kvdb; -AbsDictReaderPtr KVDBTest::dict_reader; -AbsParamDictPtr KVDBTest::param_dict; +FileReaderPtr KVDBTest::dict_reader; +ParamDictPtr KVDBTest::param_dict; ParamDictMgr KVDBTest::dict_mgr; void GenerateTestIn(std::string); @@ -58,7 +59,7 @@ TEST_F(KVDBTest, AbstractKVDB_Unit_Test) { } } -TEST_F(KVDBTest, AbstractDictReader_Unit_Test) { +TEST_F(KVDBTest, FileReader_Unit_Test) { std::string test_in_filename = "abs_dict_reader_test_in.txt"; GenerateTestIn(test_in_filename); dict_reader->SetFileName(test_in_filename); @@ -81,9 +82,29 @@ TEST_F(KVDBTest, AbstractDictReader_Unit_Test) { ASSERT_NE(timestamp_2, timestamp_3); } #include -TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) { +TEST_F(KVDBTest, ParamDict_Unit_Test) { std::string test_in_filename = "abs_dict_reader_test_in.txt"; - param_dict->SetDictReaderLst({dict_reader}); + param_dict->SetFileReaderLst({test_in_filename}); + param_dict->SetReader( + [] (std::string text) { + auto split = [](const std::string& s, + std::vector& sv, + const char* delim = " ") { + sv.clear(); + char* buffer = new char[s.size() + 1]; + std::copy(s.begin(), s.end(), buffer); + char* p = strtok(buffer, delim); + do { + sv.push_back(p); + } while ((p = strtok(NULL, delim))); + return; + }; + std::vector text_split; + split(text, text_split, " "); + std::string key = text_split[0]; + text_split.erase(text_split.begin()); + return make_pair(key, text_split); + }); param_dict->CreateKVDB(); GenerateTestIn(test_in_filename); @@ -91,7 +112,7 @@ TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) { std::this_thread::sleep_for(std::chrono::seconds(2)); - std::vector test_vec = param_dict->GetSparseValue(1, 1); + std::vector test_vec = param_dict->GetSparseValue("1", ""); ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2); diff --git a/kvdb/src/mock_param_dict_impl.cpp b/kvdb/src/mock_param_dict_impl.cpp index 07f5b7b0304fb835f47d178c52529d3619059310..e1763b2b6bd1b668a3a2d3baa25ccdbbd63e95fe 100644 --- a/kvdb/src/mock_param_dict_impl.cpp +++ b/kvdb/src/mock_param_dict_impl.cpp @@ -16,74 +16,22 @@ #include #include #include +#include #include -std::string RocksDBDictReader::GetFileName() { - return this->filename_; -} - -void RocksDBDictReader::SetFileName(std::string filename) { - this->filename_ = filename; - this->last_md5_val_ = this->GetMD5(); - this->time_stamp_ = std::chrono::system_clock::now(); -} - -std::string RocksDBDictReader::GetMD5() { - auto getCmdOut = [] (std::string cmd) { - std::string data; - FILE *stream = nullptr; - const int max_buffer = 256; - char buffer[max_buffer]; - cmd.append(" 2>&1"); - stream = popen(cmd.c_str(), "r"); - if (stream) { - if (fgets(buffer, max_buffer, stream) != NULL) { - data.append(buffer); - } - } - return data; - }; - std::string cmd = "md5sum " + this->filename_; -//TODO: throw exception if error occurs during execution of shell command - std::string md5val = getCmdOut(cmd); - this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now(); - this->last_md5_val_ = md5val; - return md5val; -} -bool RocksDBDictReader::CheckDiff() { - return this->GetMD5() == this->last_md5_val_; +std::vector ParamDict::GetDictReaderLst() { + return this->file_reader_lst_; } -std::chrono::system_clock::time_point RocksDBDictReader::GetTimeStamp() { - return this->time_stamp_; -} - -void RocksDBDictReader::Read(AbstractParamDict* param_dict) { - std::string line; - std::ifstream infile(this->filename_); - if (infile.is_open()) { - while (getline(infile, line)) { - //TODO: Write String Parse Here - // param_dict->InsertSparseValue(); - } +void ParamDict::SetFileReaderLst(std::vector lst) { + for (size_t i = 0; i < lst.size(); i++) { + FileReaderPtr fr = std::make_shared(); + fr->SetFileName(lst[i]); + this->file_reader_lst_.push_back(fr); } - infile.close(); } -RocksDBDictReader::~RocksDBDictReader() { -//TODO: I imageine nothing to do here -} - - -std::vector RocksDBParamDict::GetDictReaderLst() { - return this->dict_reader_lst_; -} - -void RocksDBParamDict::SetDictReaderLst(std::vector lst) { - this->dict_reader_lst_ = lst; -} - -std::vector RocksDBParamDict::GetSparseValue(std::string feasign, std::string slot) { +std::vector ParamDict::GetSparseValue(std::string feasign, std::string slot) { auto BytesToFloat = [](uint8_t* byteArray){ return *((float*)byteArray); }; @@ -100,15 +48,19 @@ std::vector RocksDBParamDict::GetSparseValue(std::string feasign, std::st return value; } -std::vector RocksDBParamDict::GetSparseValue(int64_t feasign, int64_t slot) { +void ParamDict::SetReader(std::function(std::string)> func) { + read_func_ = func; +} + +std::vector ParamDict::GetSparseValue(int64_t feasign, int64_t slot) { return this->GetSparseValue(std::to_string(feasign), std::to_string(slot)); } -bool RocksDBParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector& values) { +bool ParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector& values) { return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values); } -bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector& values) { +bool ParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector& values) { auto FloatToBytes = [](float fvalue, uint8_t *arr){ unsigned char *pf = nullptr; unsigned char *px = nullptr; @@ -136,23 +88,29 @@ bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot, return true; } -void RocksDBParamDict::UpdateBaseModel() { +void ParamDict::UpdateBaseModel() { + auto is_number = [] (const std::string& s) + { + return !s.empty() && std::find_if(s.begin(), + s.end(), [](char c) { return !std::isdigit(c); }) == s.end(); + }; std::thread t([&] () { - for (AbsDictReaderPtr dict_reader: this->dict_reader_lst_) { - if (dict_reader->CheckDiff()) { - std::vector strs; - dict_reader->Read(this); - for (const std::string& str: strs) { - std::vector arr; - std::istringstream in(str); - copy(std::istream_iterator(in), std::istream_iterator(), back_inserter(arr)); + for (FileReaderPtr file_reader: this->file_reader_lst_) { + std::string line; + std::ifstream infile(file_reader->GetFileName()); + if (infile.is_open()) { + while (getline(infile, line)) { + std::pair kvpair = read_func_(line); std::vector nums; - for (size_t i = 2; i < arr.size(); i++) { - nums.push_back(std::stof(arr[i])); + for (size_t i = 0; i < kvpair.second.size(); i++) { + if (is_number(kvpair.second[i])) { + nums.push_back(std::stof(kvpair.second[i])); + } } - this->InsertSparseValue(arr[0], arr[1], nums); + this->InsertSparseValue(kvpair.first, "", nums); } } + infile.close(); } AbsKVDBPtr temp = front_db; front_db = back_db; @@ -162,27 +120,27 @@ void RocksDBParamDict::UpdateBaseModel() { } -void RocksDBParamDict::UpdateDeltaModel() { +void ParamDict::UpdateDeltaModel() { UpdateBaseModel(); } -std::pair RocksDBParamDict::GetKVDB() { +std::pair ParamDict::GetKVDB() { return {front_db, back_db}; } -void RocksDBParamDict::SetKVDB(std::pair kvdbs) { +void ParamDict::SetKVDB(std::pair kvdbs) { this->front_db = kvdbs.first; this->back_db = kvdbs.second; } -void RocksDBParamDict::CreateKVDB() { +void ParamDict::CreateKVDB() { this->front_db = std::make_shared(); this->back_db = std::make_shared(); this->front_db->CreateDB(); this->back_db->CreateDB(); } -RocksDBParamDict::~RocksDBParamDict() { +ParamDict::~ParamDict() { } diff --git a/kvdb/src/param_dict_mgr_impl.cpp b/kvdb/src/param_dict_mgr_impl.cpp index 0ee37f42342fcf1d51f8a7538b8bc0e1e93f495b..67fd8744fe8dfb0ebe028650506432e5f5ac3186 100644 --- a/kvdb/src/param_dict_mgr_impl.cpp +++ b/kvdb/src/param_dict_mgr_impl.cpp @@ -21,10 +21,8 @@ void ParamDictMgr::UpdateAll() { } -void ParamDictMgr::InsertParamDict(std::string key, AbsParamDictPtr value) { +void ParamDictMgr::InsertParamDict(std::string key, ParamDictPtr value) { this->ParamDictMap.insert(std::make_pair(key, value)); } AbstractKVDB::~AbstractKVDB() {} -AbstractDictReader::~AbstractDictReader() {} -AbstractParamDict::~AbstractParamDict() {}