提交 d3175966 编写于 作者: W wangjiawei04

modify API name and prepare for linking with CTR

Change-Id: Ia1c0d39a40d69937edcb43050d6d7d7088b55909
上级 81cd31a8
......@@ -19,12 +19,12 @@
#include <memory>
#include <chrono>
class AbstractKVDB;
class AbstractDictReader;
class AbstractParamDict;
class FileReader;
class ParamDict;
typedef std::shared_ptr<AbstractKVDB> AbsKVDBPtr;
typedef std::shared_ptr<AbstractDictReader> AbsDictReaderPtr;
typedef std::shared_ptr<AbstractParamDict> AbsParamDictPtr;
typedef std::shared_ptr<FileReader> FileReaderPtr;
typedef std::shared_ptr<ParamDict> ParamDictPtr;
class AbstractKVDB {
public:
......@@ -38,15 +38,50 @@ public:
// TODO: Implement RedisKVDB
//class RedisKVDB;
class AbstractDictReader {
class FileReader {
public:
virtual std::string GetFileName() = 0;
virtual void SetFileName(std::string) = 0;
virtual std::string GetMD5() = 0;
virtual bool CheckDiff() = 0;
virtual std::chrono::system_clock::time_point GetTimeStamp() = 0;
virtual void Read(AbstractParamDict*) = 0;
virtual ~AbstractDictReader() = 0;
inline virtual std::string GetFileName() {
return this->filename_;
}
inline virtual void SetFileName(std::string filename) {
this->filename_ = filename;
this->last_md5_val_ = this->GetMD5();
this->time_stamp_ = std::chrono::system_clock::now();
}
inline virtual std::string GetMD5() {
auto getCmdOut = [] (std::string cmd) {
std::string data;
FILE *stream = nullptr;
const int max_buffer = 256;
char buffer[max_buffer];
cmd.append(" 2>&1");
stream = popen(cmd.c_str(), "r");
if (stream) {
if (fgets(buffer, max_buffer, stream) != NULL) {
data.append(buffer);
}
}
return data;
};
std::string cmd = "md5sum " + this->filename_;
//TODO: throw exception if error occurs during execution of shell command
std::string md5val = getCmdOut(cmd);
this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now();
this->last_md5_val_ = md5val;
return md5val;
}
inline virtual bool CheckDiff() {
return this->GetMD5() == this->last_md5_val_;
}
inline virtual std::chrono::system_clock::time_point GetTimeStamp() {
return this->time_stamp_;
}
inline virtual ~FileReader() {};
protected:
std::string filename_;
std::string last_md5_val_;
......@@ -54,27 +89,31 @@ protected:
};
class AbstractParamDict {
class ParamDict {
typedef std::string Key;
typedef std::vector<std::string> Value;
public:
virtual std::vector<AbsDictReaderPtr> GetDictReaderLst() = 0;
virtual void SetDictReaderLst(std::vector<AbsDictReaderPtr>) = 0;
virtual std::vector<FileReaderPtr> GetDictReaderLst();
virtual void SetFileReaderLst(std::vector<std::string> lst);
virtual std::vector<float> GetSparseValue(int64_t, int64_t) = 0;
virtual std::vector<float> GetSparseValue(std::string, std::string) = 0;
virtual std::vector<float> GetSparseValue(int64_t, int64_t);
virtual std::vector<float> GetSparseValue(std::string, std::string);
virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&) = 0;
virtual bool InsertSparseValue(std::string, std::string, const std::vector<float>&) = 0;
virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&);
virtual bool InsertSparseValue(std::string, std::string, const std::vector<float>&);
virtual void UpdateBaseModel() = 0;
virtual void UpdateDeltaModel() = 0;
virtual void SetReader(std::function<std::pair<Key, Value>(std::string)>);
virtual void UpdateBaseModel();
virtual void UpdateDeltaModel();
virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB() = 0;
virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>) = 0;
virtual void CreateKVDB() = 0;
virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
virtual void CreateKVDB();
virtual ~AbstractParamDict() = 0;
virtual ~ParamDict();
protected:
std::vector<AbsDictReaderPtr> dict_reader_lst_;
std::function<std::pair<Key, Value>(std::string)> read_func_;
std::vector<FileReaderPtr> file_reader_lst_;
AbsKVDBPtr front_db, back_db;
};
......@@ -83,9 +122,9 @@ protected:
class ParamDictMgr {
public:
void UpdateAll();
void InsertParamDict(std::string, AbsParamDictPtr);
void InsertParamDict(std::string, ParamDictPtr);
protected:
std::unordered_map<std::string, AbsParamDictPtr> ParamDictMap;
std::unordered_map<std::string, ParamDictPtr> ParamDictMap;
};
......@@ -30,36 +30,6 @@ public:
static int db_count;
};
class RocksDBDictReader : public AbstractDictReader{
public:
std::string GetFileName();
void SetFileName(std::string);
std::string GetMD5();
bool CheckDiff();
std::chrono::system_clock::time_point GetTimeStamp();
void Read(AbstractParamDict*);
~RocksDBDictReader();
};
class RocksDBParamDict : public AbstractParamDict{
public:
std::vector<AbsDictReaderPtr> GetDictReaderLst();
void SetDictReaderLst(std::vector<AbsDictReaderPtr>);
std::vector<float> GetSparseValue(int64_t, int64_t);
std::vector<float> GetSparseValue(std::string, std::string);
bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&);
bool InsertSparseValue(std::string, std::string, const std::vector<float>&);
void UpdateBaseModel();
void UpdateDeltaModel();
std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
void CreateKVDB();
~RocksDBParamDict();
};
......@@ -16,6 +16,7 @@
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h>
#include <functional>
#include <string>
#include <fstream>
#include <chrono>
......@@ -28,19 +29,19 @@ protected:
static void SetUpTestCase() {
kvdb = std::make_shared<RocksKVDB>();
dict_reader = std::make_shared<RocksDBDictReader>();
param_dict = std::make_shared<RocksDBParamDict>();
dict_reader = std::make_shared<FileReader>();
param_dict = std::make_shared<ParamDict>();
}
static AbsKVDBPtr kvdb;
static AbsDictReaderPtr dict_reader;
static AbsParamDictPtr param_dict;
static FileReaderPtr dict_reader;
static ParamDictPtr param_dict;
static ParamDictMgr dict_mgr;
};
AbsKVDBPtr KVDBTest::kvdb;
AbsDictReaderPtr KVDBTest::dict_reader;
AbsParamDictPtr KVDBTest::param_dict;
FileReaderPtr KVDBTest::dict_reader;
ParamDictPtr KVDBTest::param_dict;
ParamDictMgr KVDBTest::dict_mgr;
void GenerateTestIn(std::string);
......@@ -58,7 +59,7 @@ TEST_F(KVDBTest, AbstractKVDB_Unit_Test) {
}
}
TEST_F(KVDBTest, AbstractDictReader_Unit_Test) {
TEST_F(KVDBTest, FileReader_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt";
GenerateTestIn(test_in_filename);
dict_reader->SetFileName(test_in_filename);
......@@ -81,9 +82,29 @@ TEST_F(KVDBTest, AbstractDictReader_Unit_Test) {
ASSERT_NE(timestamp_2, timestamp_3);
}
#include <cmath>
TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) {
TEST_F(KVDBTest, ParamDict_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt";
param_dict->SetDictReaderLst({dict_reader});
param_dict->SetFileReaderLst({test_in_filename});
param_dict->SetReader(
[] (std::string text) {
auto split = [](const std::string& s,
std::vector<std::string>& sv,
const char* delim = " ") {
sv.clear();
char* buffer = new char[s.size() + 1];
std::copy(s.begin(), s.end(), buffer);
char* p = strtok(buffer, delim);
do {
sv.push_back(p);
} while ((p = strtok(NULL, delim)));
return;
};
std::vector<std::string> text_split;
split(text, text_split, " ");
std::string key = text_split[0];
text_split.erase(text_split.begin());
return make_pair(key, text_split);
});
param_dict->CreateKVDB();
GenerateTestIn(test_in_filename);
......@@ -91,7 +112,7 @@ TEST_F(KVDBTest, RocksDBParamDict_Unit_Test) {
std::this_thread::sleep_for(std::chrono::seconds(2));
std::vector<float> test_vec = param_dict->GetSparseValue(1, 1);
std::vector<float> test_vec = param_dict->GetSparseValue("1", "");
ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2);
......
......@@ -16,74 +16,22 @@
#include <thread>
#include <iterator>
#include <fstream>
#include <algorithm>
#include <sstream>
std::string RocksDBDictReader::GetFileName() {
return this->filename_;
}
void RocksDBDictReader::SetFileName(std::string filename) {
this->filename_ = filename;
this->last_md5_val_ = this->GetMD5();
this->time_stamp_ = std::chrono::system_clock::now();
}
std::string RocksDBDictReader::GetMD5() {
auto getCmdOut = [] (std::string cmd) {
std::string data;
FILE *stream = nullptr;
const int max_buffer = 256;
char buffer[max_buffer];
cmd.append(" 2>&1");
stream = popen(cmd.c_str(), "r");
if (stream) {
if (fgets(buffer, max_buffer, stream) != NULL) {
data.append(buffer);
}
}
return data;
};
std::string cmd = "md5sum " + this->filename_;
//TODO: throw exception if error occurs during execution of shell command
std::string md5val = getCmdOut(cmd);
this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now();
this->last_md5_val_ = md5val;
return md5val;
}
bool RocksDBDictReader::CheckDiff() {
return this->GetMD5() == this->last_md5_val_;
std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() {
return this->file_reader_lst_;
}
std::chrono::system_clock::time_point RocksDBDictReader::GetTimeStamp() {
return this->time_stamp_;
}
void RocksDBDictReader::Read(AbstractParamDict* param_dict) {
std::string line;
std::ifstream infile(this->filename_);
if (infile.is_open()) {
while (getline(infile, line)) {
//TODO: Write String Parse Here
// param_dict->InsertSparseValue();
}
void ParamDict::SetFileReaderLst(std::vector<std::string> lst) {
for (size_t i = 0; i < lst.size(); i++) {
FileReaderPtr fr = std::make_shared<FileReader>();
fr->SetFileName(lst[i]);
this->file_reader_lst_.push_back(fr);
}
infile.close();
}
RocksDBDictReader::~RocksDBDictReader() {
//TODO: I imageine nothing to do here
}
std::vector<AbsDictReaderPtr> RocksDBParamDict::GetDictReaderLst() {
return this->dict_reader_lst_;
}
void RocksDBParamDict::SetDictReaderLst(std::vector<AbsDictReaderPtr> lst) {
this->dict_reader_lst_ = lst;
}
std::vector<float> RocksDBParamDict::GetSparseValue(std::string feasign, std::string slot) {
std::vector<float> ParamDict::GetSparseValue(std::string feasign, std::string slot) {
auto BytesToFloat = [](uint8_t* byteArray){
return *((float*)byteArray);
};
......@@ -100,15 +48,19 @@ std::vector<float> RocksDBParamDict::GetSparseValue(std::string feasign, std::st
return value;
}
std::vector<float> RocksDBParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
void ParamDict::SetReader(std::function<std::pair<Key, Value>(std::string)> func) {
read_func_ = func;
}
std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
}
bool RocksDBParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) {
bool ParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) {
return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values);
}
bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) {
bool ParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) {
auto FloatToBytes = [](float fvalue, uint8_t *arr){
unsigned char *pf = nullptr;
unsigned char *px = nullptr;
......@@ -136,23 +88,29 @@ bool RocksDBParamDict::InsertSparseValue(std::string feasign, std::string slot,
return true;
}
void RocksDBParamDict::UpdateBaseModel() {
void ParamDict::UpdateBaseModel() {
auto is_number = [] (const std::string& s)
{
return !s.empty() && std::find_if(s.begin(),
s.end(), [](char c) { return !std::isdigit(c); }) == s.end();
};
std::thread t([&] () {
for (AbsDictReaderPtr dict_reader: this->dict_reader_lst_) {
if (dict_reader->CheckDiff()) {
std::vector<std::string> strs;
dict_reader->Read(this);
for (const std::string& str: strs) {
std::vector<std::string> arr;
std::istringstream in(str);
copy(std::istream_iterator<std::string>(in), std::istream_iterator<std::string>(), back_inserter(arr));
for (FileReaderPtr file_reader: this->file_reader_lst_) {
std::string line;
std::ifstream infile(file_reader->GetFileName());
if (infile.is_open()) {
while (getline(infile, line)) {
std::pair<Key, Value> kvpair = read_func_(line);
std::vector<float> nums;
for (size_t i = 2; i < arr.size(); i++) {
nums.push_back(std::stof(arr[i]));
for (size_t i = 0; i < kvpair.second.size(); i++) {
if (is_number(kvpair.second[i])) {
nums.push_back(std::stof(kvpair.second[i]));
}
}
this->InsertSparseValue(arr[0], arr[1], nums);
this->InsertSparseValue(kvpair.first, "", nums);
}
}
infile.close();
}
AbsKVDBPtr temp = front_db;
front_db = back_db;
......@@ -162,27 +120,27 @@ void RocksDBParamDict::UpdateBaseModel() {
}
void RocksDBParamDict::UpdateDeltaModel() {
void ParamDict::UpdateDeltaModel() {
UpdateBaseModel();
}
std::pair<AbsKVDBPtr, AbsKVDBPtr> RocksDBParamDict::GetKVDB() {
std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB() {
return {front_db, back_db};
}
void RocksDBParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
this->front_db = kvdbs.first;
this->back_db = kvdbs.second;
}
void RocksDBParamDict::CreateKVDB() {
void ParamDict::CreateKVDB() {
this->front_db = std::make_shared<RocksKVDB>();
this->back_db = std::make_shared<RocksKVDB>();
this->front_db->CreateDB();
this->back_db->CreateDB();
}
RocksDBParamDict::~RocksDBParamDict() {
ParamDict::~ParamDict() {
}
......
......@@ -21,10 +21,8 @@ void ParamDictMgr::UpdateAll() {
}
void ParamDictMgr::InsertParamDict(std::string key, AbsParamDictPtr value) {
void ParamDictMgr::InsertParamDict(std::string key, ParamDictPtr value) {
this->ParamDictMap.insert(std::make_pair(key, value));
}
AbstractKVDB::~AbstractKVDB() {}
AbstractDictReader::~AbstractDictReader() {}
AbstractParamDict::~AbstractParamDict() {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册