提交 d3175966 编写于 作者: W wangjiawei04

modify API name and prepare for linking with CTR

Change-Id: Ia1c0d39a40d69937edcb43050d6d7d7088b55909
上级 81cd31a8
...@@ -19,12 +19,12 @@ ...@@ -19,12 +19,12 @@
#include <memory> #include <memory>
#include <chrono> #include <chrono>
class AbstractKVDB; class AbstractKVDB;
class AbstractDictReader; class FileReader;
class AbstractParamDict; class ParamDict;
typedef std::shared_ptr<AbstractKVDB> AbsKVDBPtr; typedef std::shared_ptr<AbstractKVDB> AbsKVDBPtr;
typedef std::shared_ptr<AbstractDictReader> AbsDictReaderPtr; typedef std::shared_ptr<FileReader> FileReaderPtr;
typedef std::shared_ptr<AbstractParamDict> AbsParamDictPtr; typedef std::shared_ptr<ParamDict> ParamDictPtr;
class AbstractKVDB { class AbstractKVDB {
public: public:
...@@ -38,15 +38,50 @@ public: ...@@ -38,15 +38,50 @@ public:
// TODO: Implement RedisKVDB // TODO: Implement RedisKVDB
//class RedisKVDB; //class RedisKVDB;
class AbstractDictReader { class FileReader {
public: public:
virtual std::string GetFileName() = 0; inline virtual std::string GetFileName() {
virtual void SetFileName(std::string) = 0; return this->filename_;
virtual std::string GetMD5() = 0; }
virtual bool CheckDiff() = 0;
virtual std::chrono::system_clock::time_point GetTimeStamp() = 0; inline virtual void SetFileName(std::string filename) {
virtual void Read(AbstractParamDict*) = 0; this->filename_ = filename;
virtual ~AbstractDictReader() = 0; this->last_md5_val_ = this->GetMD5();
this->time_stamp_ = std::chrono::system_clock::now();
}
inline virtual std::string GetMD5() {
auto getCmdOut = [] (std::string cmd) {
std::string data;
FILE *stream = nullptr;
const int max_buffer = 256;
char buffer[max_buffer];
cmd.append(" 2>&1");
stream = popen(cmd.c_str(), "r");
if (stream) {
if (fgets(buffer, max_buffer, stream) != NULL) {
data.append(buffer);
}
}
return data;
};
std::string cmd = "md5sum " + this->filename_;
//TODO: throw exception if error occurs during execution of shell command
std::string md5val = getCmdOut(cmd);
this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now();
this->last_md5_val_ = md5val;
return md5val;
}
inline virtual bool CheckDiff() {
return this->GetMD5() == this->last_md5_val_;
}
inline virtual std::chrono::system_clock::time_point GetTimeStamp() {
return this->time_stamp_;
}
inline virtual ~FileReader() {};
protected: protected:
std::string filename_; std::string filename_;
std::string last_md5_val_; std::string last_md5_val_;
...@@ -54,27 +89,31 @@ protected: ...@@ -54,27 +89,31 @@ protected:
}; };
class AbstractParamDict { class ParamDict {
typedef std::string Key;
typedef std::vector<std::string> Value;
public: public:
virtual std::vector<AbsDictReaderPtr> GetDictReaderLst() = 0; virtual std::vector<FileReaderPtr> GetDictReaderLst();
virtual void SetDictReaderLst(std::vector<AbsDictReaderPtr>) = 0; virtual void SetFileReaderLst(std::vector<std::string> lst);
virtual std::vector<float> GetSparseValue(int64_t, int64_t) = 0; virtual std::vector<float> GetSparseValue(int64_t, int64_t);
virtual std::vector<float> GetSparseValue(std::string, std::string) = 0; virtual std::vector<float> GetSparseValue(std::string, std::string);
virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&) = 0; virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&);
virtual bool InsertSparseValue(std::string, std::string, const std::vector<float>&) = 0; virtual bool InsertSparseValue(std::string, std::string, const std::vector<float>&);
virtual void UpdateBaseModel() = 0; virtual void SetReader(std::function<std::pair<Key, Value>(std::string)>);
virtual void UpdateDeltaModel() = 0; virtual void UpdateBaseModel();
virtual void UpdateDeltaModel();
virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB() = 0; virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>) = 0; virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
virtual void CreateKVDB() = 0; virtual void CreateKVDB();
virtual ~AbstractParamDict() = 0; virtual ~ParamDict();
protected: protected:
std::vector<AbsDictReaderPtr> dict_reader_lst_; std::function<std::pair<Key, Value>(std::string)> read_func_;
std::vector<FileReaderPtr> file_reader_lst_;
AbsKVDBPtr front_db, back_db; AbsKVDBPtr front_db, back_db;
}; };
...@@ -83,9 +122,9 @@ protected: ...@@ -83,9 +122,9 @@ protected:
class ParamDictMgr { class ParamDictMgr {
public: public:
void UpdateAll(); void UpdateAll();
void InsertParamDict(std::string, AbsParamDictPtr); void InsertParamDict(std::string, ParamDictPtr);
protected: protected:
std::unordered_map<std::string, AbsParamDictPtr> ParamDictMap; std::unordered_map<std::string, ParamDictPtr> ParamDictMap;
}; };
...@@ -30,36 +30,6 @@ public: ...@@ -30,36 +30,6 @@ public:
static int db_count; static int db_count;
}; };
class RocksDBDictReader : public AbstractDictReader{
public:
std::string GetFileName();
void SetFileName(std::string);
std::string GetMD5();
bool CheckDiff();
std::chrono::system_clock::time_point GetTimeStamp();
void Read(AbstractParamDict*);
~RocksDBDictReader();
};
class RocksDBParamDict : public AbstractParamDict{
public:
std::vector<AbsDictReaderPtr> GetDictReaderLst();
void SetDictReaderLst(std::vector<AbsDictReaderPtr>);
std::vector<float> GetSparseValue(int64_t, int64_t);
std::vector<float> GetSparseValue(std::string, std::string);
bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&);
bool InsertSparseValue(std::string, std::string, const std::vector<float>&);
void UpdateBaseModel();
void UpdateDeltaModel();
std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
void CreateKVDB();
~RocksDBParamDict();
};
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "kvdb/kvdb_impl.h" #include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h" #include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <functional>
#include <string> #include <string>
#include <fstream> #include <fstream>
#include <chrono> #include <chrono>
...@@ -28,19 +29,19 @@ protected: ...@@ -28,19 +29,19 @@ protected:
static void SetUpTestCase() { static void SetUpTestCase() {
kvdb = std::make_shared<RocksKVDB>(); kvdb = std::make_shared<RocksKVDB>();
dict_reader = std::make_shared<RocksDBDictReader>(); dict_reader = std::make_shared<FileReader>();
param_dict = std::make_shared<RocksDBParamDict>(); param_dict = std::make_shared<ParamDict>();
} }
static AbsKVDBPtr kvdb; static AbsKVDBPtr kvdb;
static AbsDictReaderPtr dict_reader; static FileReaderPtr dict_reader;
static AbsParamDictPtr param_dict; static ParamDictPtr param_dict;
static ParamDictMgr dict_mgr; static ParamDictMgr dict_mgr;
}; };
AbsKVDBPtr KVDBTest::kvdb; AbsKVDBPtr KVDBTest::kvdb;
AbsDictReaderPtr KVDBTest::dict_reader; FileReaderPtr KVDBTest::dict_reader;
AbsParamDictPtr KVDBTest::param_dict; ParamDictPtr KVDBTest::param_dict;
ParamDictMgr KVDBTest::dict_mgr; ParamDictMgr KVDBTest::dict_mgr;
void GenerateTestIn(std::string); void GenerateTestIn(std::string);
...@@ -58,7 +59,7 @@ TEST_F(KVDBTest, AbstractKVDB_Unit_Test) { ...@@ -58,7 +59,7 @@ TEST_F(KVDBTest, AbstractKVDB_Unit_Test) {
} }
} }
TEST_F(KVDBTest, AbstractDictReader_Unit_Test) { TEST_F(KVDBTest, FileReader_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt"; std::string test_in_filename = "abs_dict_reader_test_in.txt";
GenerateTestIn(test_in_filename); GenerateTestIn(test_in_filename);
dict_reader->SetFileName(test_in_filename); dict_reader->SetFileName(test_in_filename);
...@@ -81,9 +82,29 @@ TEST_F(KVDBTest, AbstractDictReader_Unit_Test) { ...@@ -81,9 +82,29 @@ TEST_F(KVDBTest, AbstractDictReader_Unit_Test) {
ASSERT_NE(timestamp_2, timestamp_3); ASSERT_NE(timestamp_2, timestamp_3);
} }
#include <cmath> #include <cmath>
TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) { TEST_F(KVDBTest, ParamDict_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt"; std::string test_in_filename = "abs_dict_reader_test_in.txt";
param_dict->SetDictReaderLst({dict_reader}); param_dict->SetFileReaderLst({test_in_filename});
param_dict->SetReader(
[] (std::string text) {
auto split = [](const std::string& s,
std::vector<std::string>& sv,
const char* delim = " ") {
sv.clear();
char* buffer = new char[s.size() + 1];
std::copy(s.begin(), s.end(), buffer);
char* p = strtok(buffer, delim);
do {
sv.push_back(p);
} while ((p = strtok(NULL, delim)));
return;
};
std::vector<std::string> text_split;
split(text, text_split, " ");
std::string key = text_split[0];
text_split.erase(text_split.begin());
return make_pair(key, text_split);
});
param_dict->CreateKVDB(); param_dict->CreateKVDB();
GenerateTestIn(test_in_filename); GenerateTestIn(test_in_filename);
...@@ -91,7 +112,7 @@ TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) { ...@@ -91,7 +112,7 @@ TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) {
std::this_thread::sleep_for(std::chrono::seconds(2)); std::this_thread::sleep_for(std::chrono::seconds(2));
std::vector<float> test_vec = param_dict->GetSparseValue(1, 1); std::vector<float> test_vec = param_dict->GetSparseValue("1", "");
ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2); ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2);
......
...@@ -16,74 +16,22 @@ ...@@ -16,74 +16,22 @@
#include <thread> #include <thread>
#include <iterator> #include <iterator>
#include <fstream> #include <fstream>
#include <algorithm>
#include <sstream> #include <sstream>
std::string RocksDBDictReader::GetFileName() {
return this->filename_;
}
void RocksDBDictReader::SetFileName(std::string filename) {
this->filename_ = filename;
this->last_md5_val_ = this->GetMD5();
this->time_stamp_ = std::chrono::system_clock::now();
}
std::string RocksDBDictReader::GetMD5() {
auto getCmdOut = [] (std::string cmd) {
std::string data;
FILE *stream = nullptr;
const int max_buffer = 256;
char buffer[max_buffer];
cmd.append(" 2>&1");
stream = popen(cmd.c_str(), "r");
if (stream) {
if (fgets(buffer, max_buffer, stream) != NULL) {
data.append(buffer);
}
}
return data;
};
std::string cmd = "md5sum " + this->filename_;
//TODO: throw exception if error occurs during execution of shell command
std::string md5val = getCmdOut(cmd);
this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now();
this->last_md5_val_ = md5val;
return md5val;
}
bool RocksDBDictReader::CheckDiff() { std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() {
return this->GetMD5() == this->last_md5_val_; return this->file_reader_lst_;
} }
std::chrono::system_clock::time_point RocksDBDictReader::GetTimeStamp() { void ParamDict::SetFileReaderLst(std::vector<std::string> lst) {
return this->time_stamp_; for (size_t i = 0; i < lst.size(); i++) {
} FileReaderPtr fr = std::make_shared<FileReader>();
fr->SetFileName(lst[i]);
void RocksDBDictReader::Read(AbstractParamDict* param_dict) { this->file_reader_lst_.push_back(fr);
std::string line;
std::ifstream infile(this->filename_);
if (infile.is_open()) {
while (getline(infile, line)) {
//TODO: Write String Parse Here
// param_dict->InsertSparseValue();
}
} }
infile.close();
} }
RocksDBDictReader::~RocksDBDictReader() { std::vector<float> ParamDict::GetSparseValue(std::string feasign, std::string slot) {
//TODO: I imageine nothing to do here
}
std::vector<AbsDictReaderPtr> RocksDBParamDict::GetDictReaderLst() {
return this->dict_reader_lst_;
}
void RocksDBParamDict::SetDictReaderLst(std::vector<AbsDictReaderPtr> lst) {
this->dict_reader_lst_ = lst;
}
std::vector<float> RocksDBParamDict::GetSparseValue(std::string feasign, std::string slot) {
auto BytesToFloat = [](uint8_t* byteArray){ auto BytesToFloat = [](uint8_t* byteArray){
return *((float*)byteArray); return *((float*)byteArray);
}; };
...@@ -100,15 +48,19 @@ std::vector<float> RocksDBParamDict::GetSparseValue(std::string feasign, std::st ...@@ -100,15 +48,19 @@ std::vector<float> RocksDBParamDict::GetSparseValue(std::string feasign, std::st
return value; return value;
} }
std::vector<float> RocksDBParamDict::GetSparseValue(int64_t feasign, int64_t slot) { void ParamDict::SetReader(std::function<std::pair<Key, Value>(std::string)> func) {
read_func_ = func;
}
std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
return this->GetSparseValue(std::to_string(feasign), std::to_string(slot)); return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
} }
bool RocksDBParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) { bool ParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) {
return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values); return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values);
} }
bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) { bool ParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) {
auto FloatToBytes = [](float fvalue, uint8_t *arr){ auto FloatToBytes = [](float fvalue, uint8_t *arr){
unsigned char *pf = nullptr; unsigned char *pf = nullptr;
unsigned char *px = nullptr; unsigned char *px = nullptr;
...@@ -136,23 +88,29 @@ bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot, ...@@ -136,23 +88,29 @@ bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot,
return true; return true;
} }
void RocksDBParamDict::UpdateBaseModel() { void ParamDict::UpdateBaseModel() {
auto is_number = [] (const std::string& s)
{
return !s.empty() && std::find_if(s.begin(),
s.end(), [](char c) { return !std::isdigit(c); }) == s.end();
};
std::thread t([&] () { std::thread t([&] () {
for (AbsDictReaderPtr dict_reader: this->dict_reader_lst_) { for (FileReaderPtr file_reader: this->file_reader_lst_) {
if (dict_reader->CheckDiff()) { std::string line;
std::vector<std::string> strs; std::ifstream infile(file_reader->GetFileName());
dict_reader->Read(this); if (infile.is_open()) {
for (const std::string& str: strs) { while (getline(infile, line)) {
std::vector<std::string> arr; std::pair<Key, Value> kvpair = read_func_(line);
std::istringstream in(str);
copy(std::istream_iterator<std::string>(in), std::istream_iterator<std::string>(), back_inserter(arr));
std::vector<float> nums; std::vector<float> nums;
for (size_t i = 2; i < arr.size(); i++) { for (size_t i = 0; i < kvpair.second.size(); i++) {
nums.push_back(std::stof(arr[i])); if (is_number(kvpair.second[i])) {
nums.push_back(std::stof(kvpair.second[i]));
}
} }
this->InsertSparseValue(arr[0], arr[1], nums); this->InsertSparseValue(kvpair.first, "", nums);
} }
} }
infile.close();
} }
AbsKVDBPtr temp = front_db; AbsKVDBPtr temp = front_db;
front_db = back_db; front_db = back_db;
...@@ -162,27 +120,27 @@ void RocksDBParamDict::UpdateBaseModel() { ...@@ -162,27 +120,27 @@ void RocksDBParamDict::UpdateBaseModel() {
} }
void RocksDBParamDict::UpdateDeltaModel() { void ParamDict::UpdateDeltaModel() {
UpdateBaseModel(); UpdateBaseModel();
} }
std::pair<AbsKVDBPtr, AbsKVDBPtr> RocksDBParamDict::GetKVDB() { std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB() {
return {front_db, back_db}; return {front_db, back_db};
} }
void RocksDBParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) { void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
this->front_db = kvdbs.first; this->front_db = kvdbs.first;
this->back_db = kvdbs.second; this->back_db = kvdbs.second;
} }
void RocksDBParamDict::CreateKVDB() { void ParamDict::CreateKVDB() {
this->front_db = std::make_shared<RocksKVDB>(); this->front_db = std::make_shared<RocksKVDB>();
this->back_db = std::make_shared<RocksKVDB>(); this->back_db = std::make_shared<RocksKVDB>();
this->front_db->CreateDB(); this->front_db->CreateDB();
this->back_db->CreateDB(); this->back_db->CreateDB();
} }
RocksDBParamDict::~RocksDBParamDict() { ParamDict::~ParamDict() {
} }
......
...@@ -21,10 +21,8 @@ void ParamDictMgr::UpdateAll() { ...@@ -21,10 +21,8 @@ void ParamDictMgr::UpdateAll() {
} }
void ParamDictMgr::InsertParamDict(std::string key, AbsParamDictPtr value) { void ParamDictMgr::InsertParamDict(std::string key, ParamDictPtr value) {
this->ParamDictMap.insert(std::make_pair(key, value)); this->ParamDictMap.insert(std::make_pair(key, value));
} }
AbstractKVDB::~AbstractKVDB() {} AbstractKVDB::~AbstractKVDB() {}
AbstractDictReader::~AbstractDictReader() {}
AbstractParamDict::~AbstractParamDict() {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册