提交 22dff409 编写于 作者: J Jiawei Wang 提交者: GitHub

Merge pull request #18 from bjjwwang/my_cool_stuff

Modifications about code-format and 
...@@ -42,12 +42,6 @@ repos: ...@@ -42,12 +42,6 @@ repos:
entry: bash ./tools/codestyle/pylint_pre_commit.hook entry: bash ./tools/codestyle/pylint_pre_commit.hook
language: system language: system
files: \.(py)$ files: \.(py)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
types:
- go
- repo: local - repo: local
hooks: hooks:
- id: copyright_checker - id: copyright_checker
......
...@@ -38,6 +38,7 @@ ExternalProject_Add( ...@@ -38,6 +38,7 @@ ExternalProject_Add(
-DCMAKE_INSTALL_PREFIX=${OPENCV_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX=${OPENCV_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR=${OPENCV_INSTALL_DIR}/lib -DCMAKE_INSTALL_LIBDIR=${OPENCV_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DWITH_GTK=OFF
-DBUILD_TESTS=OFF -DBUILD_TESTS=OFF
-DBUILD_PERF_TESTS=OFF -DBUILD_PERF_TESTS=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(REDISCLIENT_SOURCES_DIR ${THIRD_PARTY_PATH}/redis++)
SET(REDISCLIENT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/redis++)
SET(REDISCLIENT_INCLUDE_DIR "${REDISCLIENT_INSTALL_DIR}/include" CACHE PATH "redis++ include directory." FORCE)
SET(REDISCLIENT_LIBRARIES "${REDISCLIENT_INSTALL_DIR}/lib/libredis++.a" CACHE FILEPATH "redis++ library." FORCE)
INCLUDE_DIRECTORIES(${REDISCLIENT_INCLUDE_DIR})
ExternalProject_Add(
extern_redis++
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${REDISCLIENT_SOURCES_DIR}
GIT_REPOSITORY "https://github.com/sewenew/redis-plus-plus"
GIT_TAG master
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND cmake . && CXXFLAGS=-fPIC make -j ${NUM_OF_PROCESSOR} static
INSTALL_COMMAND mkdir -p ${REDISCLIENT_INSTALL_DIR}/lib/
&& cp ${REDISCLIENT_SOURCES_DIR}/src/extern_redis++/lib/libredis++.a ${REDISCLIENT_LIBRARIES}
&& cp -r ${REDISCLIENT_SOURCES_DIR}/src/extern_redis++/src/sw/redis++/ ${REDISCLIENT_INSTALL_DIR}/include/redis++/
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES(extern_redis++ snappy)
ADD_LIBRARY(redis++ STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET redis++ PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES})
ADD_DEPENDENCIES(redis++ extern_redis++)
LIST(APPEND external_project_dependencies redis++)
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include <sys/types.h> #include <sys/types.h>
#include <unistd.h> #include <unistd.h>
#include <algorithm>
#include <fstream> #include <fstream>
#include "sdk-cpp/builtin_format.pb.h" #include "sdk-cpp/builtin_format.pb.h"
#include "sdk-cpp/echo_kvdb_service.pb.h" #include "sdk-cpp/echo_kvdb_service.pb.h"
#include "sdk-cpp/include/common.h" #include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h" #include "sdk-cpp/include/predictor_sdk.h"
#include <algorithm>
using baidu::paddle_serving::sdk_cpp::Predictor; using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi; using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::echo_kvdb_service::Request; using baidu::paddle_serving::predictor::echo_kvdb_service::Request;
...@@ -48,11 +48,11 @@ void print_res(const Request& req, ...@@ -48,11 +48,11 @@ void print_res(const Request& req,
uint64_t elapse_ms) { uint64_t elapse_ms) {
LOG(INFO) << "Receive Response size: " << res.ress_size(); LOG(INFO) << "Receive Response size: " << res.ress_size();
for (size_t i = 0; i < res.ress_size(); i++) { for (size_t i = 0; i < res.ress_size(); i++) {
KVDBRes val = res.ress(i); KVDBRes val = res.ress(i);
LOG(INFO) << "Receive value from demo-server: " << val.value(); LOG(INFO) << "Receive value from demo-server: " << val.value();
} }
LOG(INFO) << "Succ call predictor[echo_kvdb_service], the tag is: " << route_tag LOG(INFO) << "Succ call predictor[echo_kvdb_service], the tag is: "
<< ", elapse_ms: " << elapse_ms; << route_tag << ", elapse_ms: " << elapse_ms;
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -100,7 +100,7 @@ int main(int argc, char** argv) { ...@@ -100,7 +100,7 @@ int main(int argc, char** argv) {
while (true) { while (true) {
if (global_key > 10000) { if (global_key > 10000) {
break; break;
} }
timeval start; timeval start;
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // You may obtain a copy of the License at
...@@ -23,36 +23,34 @@ using baidu::paddle_serving::predictor::format::KVDBRes; ...@@ -23,36 +23,34 @@ using baidu::paddle_serving::predictor::format::KVDBRes;
using baidu::paddle_serving::predictor::echo_kvdb_service::Request; using baidu::paddle_serving::predictor::echo_kvdb_service::Request;
using baidu::paddle_serving::predictor::echo_kvdb_service::Response; using baidu::paddle_serving::predictor::echo_kvdb_service::Response;
int KVDBEchoOp::inference() { int KVDBEchoOp::inference() { debug(); }
debug();
}
int KVDBEchoOp::debug() { int KVDBEchoOp::debug() {
//TODO: implement DEBUG mode // TODO: implement DEBUG mode
baidu::paddle_serving::predictor::Resource& resource = baidu::paddle_serving::predictor::Resource::instance(); baidu::paddle_serving::predictor::Resource& resource =
std::shared_ptr<RocksDBWrapper> db = resource.getDB(); baidu::paddle_serving::predictor::Resource::instance();
const Request* req = dynamic_cast<const Request*>(get_request_message()); std::shared_ptr<RocksDBWrapper> db = resource.getDB();
Response* res = mutable_data<Response>(); const Request* req = dynamic_cast<const Request*>(get_request_message());
LOG(INFO) << "Receive request in KVDB echo service: " << req->ShortDebugString(); Response* res = mutable_data<Response>();
for (size_t i = 0; i < req->reqs_size(); i++) { LOG(INFO) << "Receive request in KVDB echo service: "
auto kvdbreq = req->reqs(i); << req->ShortDebugString();
std::string op = kvdbreq.op(); for (size_t i = 0; i < req->reqs_size(); i++) {
std::string key = kvdbreq.key(); auto kvdbreq = req->reqs(i);
std::string val = kvdbreq.value(); std::string op = kvdbreq.op();
if (op == "SET") { std::string key = kvdbreq.key();
db->Put(key, val); std::string val = kvdbreq.value();
KVDBRes* kvdb_value_res = res->mutable_ress()->Add(); if (op == "SET") {
kvdb_value_res -> set_value("OK"); db->Put(key, val);
} else if (op == "GET") { KVDBRes* kvdb_value_res = res->mutable_ress()->Add();
std::string getvalue = db->Get(key); kvdb_value_res->set_value("OK");
KVDBRes* kvdb_value_res = res->mutable_ress()->Add(); } else if (op == "GET") {
kvdb_value_res -> set_value(getvalue); std::string getvalue = db->Get(key);
} KVDBRes* kvdb_value_res = res->mutable_ress()->Add();
kvdb_value_res->set_value(getvalue);
} }
return 0; }
return 0;
} }
DEFINE_OP(KVDBEchoOp); DEFINE_OP(KVDBEchoOp);
} }
} }
......
...@@ -14,23 +14,24 @@ ...@@ -14,23 +14,24 @@
#pragma once #pragma once
#include "demo-serving/echo_kvdb_service.pb.h" #include "demo-serving/echo_kvdb_service.pb.h"
#include "predictor/framework/resource.h" #include "kvdb/paddle_rocksdb.h"
#include "predictor/common/inner_common.h" #include "predictor/common/inner_common.h"
#include "predictor/framework/channel.h" #include "predictor/framework/channel.h"
#include "predictor/framework/op_repository.h" #include "predictor/framework/op_repository.h"
#include "predictor/framework/resource.h"
#include "predictor/op/op.h" #include "predictor/op/op.h"
#include "kvdb/paddle_rocksdb.h"
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
namespace predictor { namespace predictor {
class KVDBEchoOp: public OpWithChannel<baidu::paddle_serving::predictor::echo_kvdb_service::Response> { class KVDBEchoOp
public: : public OpWithChannel<
DECLARE_OP(KVDBEchoOp); baidu::paddle_serving::predictor::echo_kvdb_service::Response> {
int inference(); public:
int debug(); DECLARE_OP(KVDBEchoOp);
int inference();
int debug();
}; };
} // namespace predictor } // namespace predictor
} // namespace paddle_serving } // namespace paddle_serving
......
...@@ -64,9 +64,10 @@ int ReaderOp::inference() { ...@@ -64,9 +64,10 @@ int ReaderOp::inference() {
size_t dense_capacity = 3 * resize.width * resize.height; size_t dense_capacity = 3 * resize.width * resize.height;
size_t len = dense_capacity * sizeof(float) * sample_size; size_t len = dense_capacity * sizeof(float) * sample_size;
// Allocate buffer in PaddleTensor, so that buffer will be managed by the Tensor // Allocate buffer in PaddleTensor, so that buffer will be managed by the
// Tensor
in_tensor.data.Resize(len); in_tensor.data.Resize(len);
float *data = reinterpret_cast<float *>(in_tensor.data.data()); float* data = reinterpret_cast<float*>(in_tensor.data.data());
if (in_tensor.data.data() == NULL) { if (in_tensor.data.data() == NULL) {
LOG(ERROR) << "Failed create temp float array, " LOG(ERROR) << "Failed create temp float array, "
<< "size=" << dense_capacity * sample_size * sizeof(float); << "size=" << dense_capacity * sample_size * sizeof(float);
......
...@@ -20,15 +20,15 @@ package baidu.paddle_serving.predictor.echo_kvdb_service; ...@@ -20,15 +20,15 @@ package baidu.paddle_serving.predictor.echo_kvdb_service;
option cc_generic_services = true; option cc_generic_services = true;
message Request { message Request {
repeated baidu.paddle_serving.predictor.format.KVDBReq reqs = 1; repeated baidu.paddle_serving.predictor.format.KVDBReq reqs = 1;
}; };
message Response { message Response {
repeated baidu.paddle_serving.predictor.format.KVDBRes ress = 1; repeated baidu.paddle_serving.predictor.format.KVDBRes ress = 1;
}; };
service EchoKVDBService { service EchoKVDBService {
rpc inference(Request) returns (Response); rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response); rpc debug(Request) returns (Response);
option (pds.options).generate_impl = true; option (pds.options).generate_impl = true;
}; };
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// //
...@@ -14,10 +27,10 @@ ...@@ -14,10 +27,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <vector>
#include <unordered_map>
#include <memory>
#include <chrono> #include <chrono>
#include <memory>
#include <unordered_map>
#include <vector>
class AbstractKVDB; class AbstractKVDB;
class FileReader; class FileReader;
class ParamDict; class ParamDict;
...@@ -27,104 +40,102 @@ typedef std::shared_ptr<FileReader> FileReaderPtr; ...@@ -27,104 +40,102 @@ typedef std::shared_ptr<FileReader> FileReaderPtr;
typedef std::shared_ptr<ParamDict> ParamDictPtr; typedef std::shared_ptr<ParamDict> ParamDictPtr;
class AbstractKVDB { class AbstractKVDB {
public: public:
virtual void CreateDB() = 0; virtual void CreateDB() = 0;
virtual void SetDBName(std::string) = 0; virtual void SetDBName(std::string) = 0;
virtual void Set(std::string key, std::string value) = 0; virtual void Set(std::string key, std::string value) = 0;
virtual std::string Get(std::string key) = 0; virtual std::string Get(std::string key) = 0;
virtual ~AbstractKVDB() = 0; virtual ~AbstractKVDB() = 0;
}; };
// TODO: Implement RedisKVDB
//class RedisKVDB;
class FileReader { class FileReader {
public: public:
inline virtual std::string GetFileName() { inline virtual std::string GetFileName() { return this->filename_; }
return this->filename_;
} inline virtual void SetFileName(std::string filename) {
this->filename_ = filename;
inline virtual void SetFileName(std::string filename) { this->last_md5_val_ = this->GetMD5();
this->filename_ = filename; this->time_stamp_ = std::chrono::system_clock::now();
this->last_md5_val_ = this->GetMD5(); }
this->time_stamp_ = std::chrono::system_clock::now();
} inline virtual std::string GetMD5() {
auto getCmdOut = [](std::string cmd) {
inline virtual std::string GetMD5() { std::string data;
auto getCmdOut = [] (std::string cmd) { FILE *stream = nullptr;
std::string data; const int max_buffer = 256;
FILE *stream = nullptr; char buffer[max_buffer];
const int max_buffer = 256; cmd.append(" 2>&1");
char buffer[max_buffer]; stream = popen(cmd.c_str(), "r");
cmd.append(" 2>&1"); if (stream) {
stream = popen(cmd.c_str(), "r"); if (fgets(buffer, max_buffer, stream) != NULL) {
if (stream) { data.append(buffer);
if (fgets(buffer, max_buffer, stream) != NULL) { }
data.append(buffer); }
} return data;
} };
return data; std::string cmd = "md5sum " + this->filename_;
}; // TODO: throw exception if error occurs during execution of shell command
std::string cmd = "md5sum " + this->filename_; std::string md5val = getCmdOut(cmd);
//TODO: throw exception if error occurs during execution of shell command this->time_stamp_ = md5val == this->last_md5_val_
std::string md5val = getCmdOut(cmd); ? this->time_stamp_
this->time_stamp_ = md5val == this->last_md5_val_? this->time_stamp_: std::chrono::system_clock::now(); : std::chrono::system_clock::now();
this->last_md5_val_ = md5val; this->last_md5_val_ = md5val;
return md5val; return md5val;
} }
inline virtual bool CheckDiff() { inline virtual bool CheckDiff() {
return this->GetMD5() == this->last_md5_val_; return this->GetMD5() == this->last_md5_val_;
} }
inline virtual std::chrono::system_clock::time_point GetTimeStamp() { inline virtual std::chrono::system_clock::time_point GetTimeStamp() {
return this->time_stamp_; return this->time_stamp_;
} }
inline virtual ~FileReader() {}; inline virtual ~FileReader(){};
protected:
std::string filename_; private:
std::string last_md5_val_; std::string filename_;
std::chrono::system_clock::time_point time_stamp_; std::string last_md5_val_;
std::chrono::system_clock::time_point time_stamp_;
}; };
class ParamDict { class ParamDict {
typedef std::string Key; typedef std::string Key;
typedef std::vector<std::string> Value; typedef std::vector<std::string> Value;
public:
virtual std::vector<FileReaderPtr> GetDictReaderLst();
virtual void SetFileReaderLst(std::vector<std::string> lst);
virtual std::vector<float> GetSparseValue(int64_t, int64_t);
virtual std::vector<float> GetSparseValue(std::string, std::string);
virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float>&);
virtual bool InsertSparseValue(std::string, std::string, const std::vector<float>&);
virtual void SetReader(std::function<std::pair<Key, Value>(std::string)>);
virtual void UpdateBaseModel();
virtual void UpdateDeltaModel();
virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
virtual void CreateKVDB();
virtual ~ParamDict();
protected:
std::function<std::pair<Key, Value>(std::string)> read_func_;
std::vector<FileReaderPtr> file_reader_lst_;
AbsKVDBPtr front_db, back_db;
};
public:
virtual std::vector<FileReaderPtr> GetDictReaderLst();
virtual void SetFileReaderLst(std::vector<std::string> lst);
virtual std::vector<float> GetSparseValue(int64_t, int64_t);
virtual std::vector<float> GetSparseValue(std::string, std::string);
class ParamDictMgr { virtual bool InsertSparseValue(int64_t, int64_t, const std::vector<float> &);
public: virtual bool InsertSparseValue(std::string,
void UpdateAll(); std::string,
void InsertParamDict(std::string, ParamDictPtr); const std::vector<float> &);
virtual void SetReader(std::function<std::pair<Key, Value>(std::string)>);
virtual void UpdateBaseModel();
virtual void UpdateDeltaModel();
virtual std::pair<AbsKVDBPtr, AbsKVDBPtr> GetKVDB();
virtual void SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr>);
virtual void CreateKVDB();
virtual ~ParamDict();
protected: private:
std::unordered_map<std::string, ParamDictPtr> ParamDictMap; std::function<std::pair<Key, Value>(std::string)> read_func_;
std::vector<FileReaderPtr> file_reader_lst_;
AbsKVDBPtr front_db, back_db;
}; };
class ParamDictMgr {
public:
void UpdateAll();
void InsertParamDict(std::string, ParamDictPtr);
private:
std::unordered_map<std::string, ParamDictPtr> ParamDictMap;
};
...@@ -15,26 +15,25 @@ ...@@ -15,26 +15,25 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <string> #include <string>
#include "rocksdb/compaction_filter.h"
#include "rocksdb/db.h" #include "rocksdb/db.h"
#include "rocksdb/filter_policy.h"
#include "rocksdb/options.h" #include "rocksdb/options.h"
#include "rocksdb/slice.h" #include "rocksdb/slice.h"
#include "rocksdb/sst_file_writer.h" #include "rocksdb/sst_file_writer.h"
#include "rocksdb/table.h" #include "rocksdb/table.h"
#include "rocksdb/compaction_filter.h"
#include "rocksdb/filter_policy.h"
class RocksDBWrapper { class RocksDBWrapper {
public: public:
RocksDBWrapper(std::string db_name); RocksDBWrapper(std::string db_name);
std::string Get(std::string key); std::string Get(std::string key);
bool Put(std::string key, std::string value); bool Put(std::string key, std::string value);
void SetDBName(std::string db_name); void SetDBName(std::string db_name);
static std::shared_ptr<RocksDBWrapper> RocksDBWrapperFactory(std::string db_name = "SparseMatrix"); static std::shared_ptr<RocksDBWrapper> RocksDBWrapperFactory(
std::string db_name = "SparseMatrix");
protected: private:
rocksdb::DB *db_; rocksdb::DB *db_;
std::string db_name_; std::string db_name_;
}; };
...@@ -13,23 +13,19 @@ ...@@ -13,23 +13,19 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "kvdb_impl.h" #include "kvdb/kvdb_impl.h"
#include "paddle_rocksdb.h" #include "kvdb/paddle_rocksdb.h"
class RocksKVDB: public AbstractKVDB { class RocksKVDB : public AbstractKVDB {
public: public:
void CreateDB(); void CreateDB();
void SetDBName(std::string); void SetDBName(std::string);
void Set(std::string key, std::string value); void Set(std::string key, std::string value);
std::string Get(std::string key); std::string Get(std::string key);
~RocksKVDB(); ~RocksKVDB();
protected: private:
std::shared_ptr<RocksDBWrapper> db_; std::shared_ptr<RocksDBWrapper> db_;
public: public:
static int db_count; static int db_count;
}; };
...@@ -12,65 +12,56 @@ ...@@ -12,65 +12,56 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "kvdb/rocksdb_impl.h"
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <chrono>
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include <chrono> #include <string>
#include <thread> #include <thread>
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include "kvdb/rocksdb_impl.h"
class KVDBTest : public ::testing::Test { class KVDBTest : public ::testing::Test {
protected: protected:
void SetUp() override{ void SetUp() override {}
}
static void SetUpTestCase() {
}
static void SetUpTestCase() {}
}; };
int my_argc; int my_argc;
char** my_argv; char** my_argv;
std::vector<std::string> StringSplit(std::string str, char split) { std::vector<std::string> StringSplit(std::string str, char split) {
std::vector<std::string> strs; std::vector<std::string> strs;
std::istringstream f(str); std::istringstream f(str);
std:: string s; std::string s;
while (getline(f, s, split)) { while (getline(f, s, split)) {
strs.push_back(s); strs.push_back(s);
} }
return strs; return strs;
} }
TEST_F(KVDBTest, AbstractKVDB_Func_Test) { TEST_F(KVDBTest, AbstractKVDB_Func_Test) {
AbsKVDBPtr kvdb = std::make_shared<RocksKVDB>(); AbsKVDBPtr kvdb = std::make_shared<RocksKVDB>();
kvdb->CreateDB(); kvdb->CreateDB();
std::string set_list = "setlist.txt"; std::string set_list = "setlist.txt";
std::string get_list = "getlist.txt"; std::string get_list = "getlist.txt";
std::ifstream set_file(set_list); std::ifstream set_file(set_list);
std::ifstream get_file(get_list); std::ifstream get_file(get_list);
for (std::string line; getline(set_file, line); ) for (std::string line; getline(set_file, line);) {
{ std::vector<std::string> strs = StringSplit(line, ' ');
std::vector<std::string> strs = StringSplit (line, ' '); kvdb->Set(strs[0], strs[1]);
kvdb->Set(strs[0], strs[1]); }
}
for (std::string line; getline(get_file, line); ) { for (std::string line; getline(get_file, line);) {
std::vector<std::string> strs = StringSplit(line, ' '); std::vector<std::string> strs = StringSplit(line, ' ');
std::string val = kvdb->Get(strs[0]); std::string val = kvdb->Get(strs[0]);
ASSERT_EQ(val, strs[1]); ASSERT_EQ(val, strs[1]);
} }
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
my_argc = argc; my_argc = argc;
my_argv = argv; my_argv = argv;
::testing::InitGoogleTest(&argc, argv); ::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -12,63 +12,59 @@ ...@@ -12,63 +12,59 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "kvdb/rocksdb_impl.h"
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string>
#include <fstream>
#include <chrono> #include <chrono>
#include <fstream>
#include <string>
#include <thread> #include <thread>
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include "kvdb/rocksdb_impl.h"
class KVDBTest : public ::testing::Test { class KVDBTest : public ::testing::Test {
protected: protected:
void SetUp() override{ void SetUp() override {}
}
static void SetUpTestCase() {
} static void SetUpTestCase() {}
}; };
int my_argc; int my_argc;
char** my_argv; char** my_argv;
void db_thread_test(AbsKVDBPtr kvdb, int size) { void db_thread_test(AbsKVDBPtr kvdb, int size) {
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
kvdb->Set(std::to_string(i), std::to_string(i)); kvdb->Set(std::to_string(i), std::to_string(i));
kvdb->Get(std::to_string(i)); kvdb->Get(std::to_string(i));
} }
} }
TEST_F(KVDBTest, AbstractKVDB_Thread_Test) { TEST_F(KVDBTest, AbstractKVDB_Thread_Test) {
if (my_argc != 3) { if (my_argc != 3) {
std::cerr << "illegal input! should be db_thread ${num_of_thread} ${num_of_ops_each_thread}" << std::endl; std::cerr << "illegal input! should be db_thread ${num_of_thread} "
return; "${num_of_ops_each_thread}"
} << std::endl;
int num_of_thread = atoi(my_argv[1]);
int nums_of_ops_each_thread = atoi(my_argv[2]);
std::vector<AbsKVDBPtr> kvdbptrs;
for (int i= 0; i < num_of_thread; i++) {
kvdbptrs.push_back(std::make_shared<RocksKVDB>());
kvdbptrs[i]->CreateDB();
}
std::vector<std::thread> tarr;
for (int i = 0; i< num_of_thread; i++) {
tarr.push_back(std::thread(db_thread_test, kvdbptrs[i], nums_of_ops_each_thread));
}
for (int i = 0; i< num_of_thread; i++) {
tarr[i].join();
}
return; return;
}
int num_of_thread = atoi(my_argv[1]);
int nums_of_ops_each_thread = atoi(my_argv[2]);
std::vector<AbsKVDBPtr> kvdbptrs;
for (int i = 0; i < num_of_thread; i++) {
kvdbptrs.push_back(std::make_shared<RocksKVDB>());
kvdbptrs[i]->CreateDB();
}
std::vector<std::thread> tarr;
for (int i = 0; i < num_of_thread; i++) {
tarr.push_back(
std::thread(db_thread_test, kvdbptrs[i], nums_of_ops_each_thread));
}
for (int i = 0; i < num_of_thread; i++) {
tarr[i].join();
}
return;
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
my_argc = argc; my_argc = argc;
my_argv = argv; my_argv = argv;
::testing::InitGoogleTest(&argc, argv); ::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -12,32 +12,29 @@ ...@@ -12,32 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "kvdb/rocksdb_impl.h"
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <chrono>
#include <fstream>
#include <functional> #include <functional>
#include <string> #include <string>
#include <fstream>
#include <chrono>
#include <thread> #include <thread>
#include "kvdb/kvdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include "kvdb/rocksdb_impl.h"
class KVDBTest : public ::testing::Test { class KVDBTest : public ::testing::Test {
protected: protected:
void SetUp() override{ void SetUp() override {}
} static void SetUpTestCase() {
kvdb = std::make_shared<RocksKVDB>();
static void SetUpTestCase() { dict_reader = std::make_shared<FileReader>();
kvdb = std::make_shared<RocksKVDB>(); param_dict = std::make_shared<ParamDict>();
dict_reader = std::make_shared<FileReader>(); }
param_dict = std::make_shared<ParamDict>();
} static AbsKVDBPtr kvdb;
static FileReaderPtr dict_reader;
static AbsKVDBPtr kvdb; static ParamDictPtr param_dict;
static FileReaderPtr dict_reader; static ParamDictMgr dict_mgr;
static ParamDictPtr param_dict;
static ParamDictMgr dict_mgr;
}; };
AbsKVDBPtr KVDBTest::kvdb; AbsKVDBPtr KVDBTest::kvdb;
FileReaderPtr KVDBTest::dict_reader; FileReaderPtr KVDBTest::dict_reader;
...@@ -48,116 +45,117 @@ void GenerateTestIn(std::string); ...@@ -48,116 +45,117 @@ void GenerateTestIn(std::string);
void UpdateTestIn(std::string); void UpdateTestIn(std::string);
TEST_F(KVDBTest, AbstractKVDB_Unit_Test) { TEST_F(KVDBTest, AbstractKVDB_Unit_Test) {
kvdb->CreateDB(); kvdb->CreateDB();
kvdb->SetDBName("test_kvdb"); kvdb->SetDBName("test_kvdb");
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
kvdb->Set(std::to_string(i), std::to_string(i * 2)); kvdb->Set(std::to_string(i), std::to_string(i * 2));
} }
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
std::string val = kvdb->Get(std::to_string(i)); std::string val = kvdb->Get(std::to_string(i));
ASSERT_EQ(val, std::to_string(i * 2)); ASSERT_EQ(val, std::to_string(i * 2));
} }
} }
TEST_F(KVDBTest, FileReader_Unit_Test) { TEST_F(KVDBTest, FileReader_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt"; std::string test_in_filename = "abs_dict_reader_test_in.txt";
GenerateTestIn(test_in_filename); GenerateTestIn(test_in_filename);
dict_reader->SetFileName(test_in_filename); dict_reader->SetFileName(test_in_filename);
std::string md5_1 = dict_reader->GetMD5(); std::string md5_1 = dict_reader->GetMD5();
std::chrono::system_clock::time_point timestamp_1 = dict_reader->GetTimeStamp(); std::chrono::system_clock::time_point timestamp_1 =
dict_reader->GetTimeStamp();
std::string md5_2 = dict_reader->GetMD5();
std::chrono::system_clock::time_point timestamp_2 = dict_reader->GetTimeStamp(); std::string md5_2 = dict_reader->GetMD5();
std::chrono::system_clock::time_point timestamp_2 =
ASSERT_EQ(md5_1, md5_2); dict_reader->GetTimeStamp();
ASSERT_EQ(timestamp_1, timestamp_2);
ASSERT_EQ(md5_1, md5_2);
UpdateTestIn(test_in_filename); ASSERT_EQ(timestamp_1, timestamp_2);
std::string md5_3 = dict_reader->GetMD5(); UpdateTestIn(test_in_filename);
std::chrono::system_clock::time_point timestamp_3 = dict_reader->GetTimeStamp();
std::string md5_3 = dict_reader->GetMD5();
ASSERT_NE(md5_2, md5_3); std::chrono::system_clock::time_point timestamp_3 =
ASSERT_NE(timestamp_2, timestamp_3); dict_reader->GetTimeStamp();
ASSERT_NE(md5_2, md5_3);
ASSERT_NE(timestamp_2, timestamp_3);
} }
#include <cmath> #include <cmath>
TEST_F(KVDBTest, ParamDict_Unit_Test) { TEST_F(KVDBTest, ParamDict_Unit_Test) {
std::string test_in_filename = "abs_dict_reader_test_in.txt"; std::string test_in_filename = "abs_dict_reader_test_in.txt";
param_dict->SetFileReaderLst({test_in_filename}); param_dict->SetFileReaderLst({test_in_filename});
param_dict->SetReader( param_dict->SetReader([](std::string text) {
[] (std::string text) { auto split = [](const std::string& s,
auto split = [](const std::string& s, std::vector<std::string>& sv,
std::vector<std::string>& sv, const char* delim = " ") {
const char* delim = " ") { sv.clear();
sv.clear(); char* buffer = new char[s.size() + 1];
char* buffer = new char[s.size() + 1]; std::copy(s.begin(), s.end(), buffer);
std::copy(s.begin(), s.end(), buffer); char* p = strtok(buffer, delim);
char* p = strtok(buffer, delim); do {
do { sv.push_back(p);
sv.push_back(p); } while ((p = strtok(NULL, delim)));
} while ((p = strtok(NULL, delim))); return;
return; };
}; std::vector<std::string> text_split;
std::vector<std::string> text_split; split(text, text_split, " ");
split(text, text_split, " "); std::string key = text_split[0];
std::string key = text_split[0]; text_split.erase(text_split.begin());
text_split.erase(text_split.begin()); return make_pair(key, text_split);
return make_pair(key, text_split); });
}); param_dict->CreateKVDB();
param_dict->CreateKVDB(); GenerateTestIn(test_in_filename);
GenerateTestIn(test_in_filename);
param_dict->UpdateBaseModel();
param_dict->UpdateBaseModel();
std::this_thread::sleep_for(std::chrono::seconds(2));
std::this_thread::sleep_for(std::chrono::seconds(2));
std::vector<float> test_vec = param_dict->GetSparseValue("1", "");
std::vector<float> test_vec = param_dict->GetSparseValue("1", "");
ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2);
ASSERT_LT(fabs(test_vec[0] - 1.0), 1e-2);
UpdateTestIn(test_in_filename);
UpdateTestIn(test_in_filename); param_dict->UpdateDeltaModel();
param_dict->UpdateDeltaModel();
} }
void GenerateTestIn(std::string filename) { void GenerateTestIn(std::string filename) {
std::ifstream in_file(filename); std::ifstream in_file(filename);
if (in_file.good()) { if (in_file.good()) {
in_file.close(); in_file.close();
std::string cmd = "rm -rf "+ filename; std::string cmd = "rm -rf " + filename;
system(cmd.c_str()); system(cmd.c_str());
}
std::ofstream out_file(filename);
for (size_t i = 0; i < 100000; i++) {
out_file << i << " " << i << " ";
for (size_t j = 0; j < 3; j++) {
out_file << i << " ";
} }
std::ofstream out_file(filename); out_file << std::endl;
for (size_t i = 0; i < 100000; i++) { }
out_file << i << " " << i << " "; out_file.close();
for (size_t j = 0; j < 3; j++) {
out_file << i << " ";
}
out_file << std::endl;
}
out_file.close();
} }
void UpdateTestIn(std::string filename) { void UpdateTestIn(std::string filename) {
std::ifstream in_file(filename); std::ifstream in_file(filename);
if (in_file.good()) { if (in_file.good()) {
in_file.close(); in_file.close();
std::string cmd = "rm -rf " + filename; std::string cmd = "rm -rf " + filename;
system(cmd.c_str()); system(cmd.c_str());
} }
std::ofstream out_file(filename); std::ofstream out_file(filename);
for (size_t i = 0; i < 10000; i++) { for (size_t i = 0; i < 10000; i++) {
out_file << i << " " << i << " "; out_file << i << " " << i << " ";
for (size_t j = 0; j < 3; j++) { for (size_t j = 0; j < 3; j++) {
out_file << i + 1 << " "; out_file << i + 1 << " ";
}
out_file << std::endl;
} }
out_file.close(); out_file << std::endl;
}
out_file.close();
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv); ::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS(); return RUN_ALL_TESTS();
} }
...@@ -12,139 +12,132 @@ ...@@ -12,139 +12,132 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "kvdb/rocksdb_impl.h"
#include <thread>
#include <iterator>
#include <fstream>
#include <algorithm> #include <algorithm>
#include <fstream>
#include <iterator>
#include <sstream> #include <sstream>
#include <thread>
#include "kvdb/rocksdb_impl.h"
std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() { std::vector<FileReaderPtr> ParamDict::GetDictReaderLst() {
return this->file_reader_lst_; return this->file_reader_lst_;
} }
void ParamDict::SetFileReaderLst(std::vector<std::string> lst) { void ParamDict::SetFileReaderLst(std::vector<std::string> lst) {
for (size_t i = 0; i < lst.size(); i++) { for (size_t i = 0; i < lst.size(); i++) {
FileReaderPtr fr = std::make_shared<FileReader>(); FileReaderPtr fr = std::make_shared<FileReader>();
fr->SetFileName(lst[i]); fr->SetFileName(lst[i]);
this->file_reader_lst_.push_back(fr); this->file_reader_lst_.push_back(fr);
} }
} }
std::vector<float> ParamDict::GetSparseValue(std::string feasign, std::string slot) { std::vector<float> ParamDict::GetSparseValue(std::string feasign,
auto BytesToFloat = [](uint8_t* byteArray){ std::string slot) {
return *((float*)byteArray); auto BytesToFloat = [](uint8_t* byte_array) { return *((float*)byte_array); };
}; // TODO: the concatation of feasign and slot is TBD.
//TODO: the concatation of feasign and slot is TBD. std::string result = front_db->Get(feasign + slot);
std::string result = front_db->Get(feasign + slot); std::vector<float> value;
std::vector<float> value; if (result == "NOT_FOUND") return value;
if (result == "NOT_FOUND") uint8_t* raw_values_ptr = reinterpret_cast<uint8_t*>(&result[0]);
return value; for (size_t i = 0; i < result.size(); i += sizeof(float)) {
uint8_t* raw_values_ptr = reinterpret_cast<uint8_t *>(&result[0]); float temp = BytesToFloat(raw_values_ptr + i);
for (size_t i = 0; i < result.size(); i += 4) { value.push_back(temp);
float temp = BytesToFloat(raw_values_ptr + i); }
value.push_back(temp); return value;
}
return value;
} }
void ParamDict::SetReader(std::function<std::pair<Key, Value>(std::string)> func) { void ParamDict::SetReader(
read_func_ = func; std::function<std::pair<Key, Value>(std::string)> func) {
read_func_ = func;
} }
std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) { std::vector<float> ParamDict::GetSparseValue(int64_t feasign, int64_t slot) {
return this->GetSparseValue(std::to_string(feasign), std::to_string(slot)); return this->GetSparseValue(std::to_string(feasign), std::to_string(slot));
} }
bool ParamDict::InsertSparseValue(int64_t feasign, int64_t slot, const std::vector<float>& values) { bool ParamDict::InsertSparseValue(int64_t feasign,
return this->InsertSparseValue(std::to_string(feasign), std::to_string(slot), values); int64_t slot,
const std::vector<float>& values) {
return this->InsertSparseValue(
std::to_string(feasign), std::to_string(slot), values);
} }
bool ParamDict::InsertSparseValue(std::string feasign, std::string slot, const std::vector<float>& values) { bool ParamDict::InsertSparseValue(std::string feasign,
auto FloatToBytes = [](float fvalue, uint8_t *arr){ std::string slot,
unsigned char *pf = nullptr; const std::vector<float>& values) {
unsigned char *px = nullptr; auto FloatToBytes = [](float fvalue, uint8_t* arr) {
unsigned char i = 0; unsigned char* pf = nullptr;
pf =(unsigned char *)&fvalue; unsigned char* px = nullptr;
px = arr; unsigned char i = 0;
for (i = 0; i < 4; i++) pf = (unsigned char*)&fvalue;
{ px = arr;
*(px+i)=*(pf+i); for (i = 0; i < sizeof(float); i++) {
} *(px + i) = *(pf + i);
};
std::string key = feasign + slot;
uint8_t* values_ptr = new uint8_t[values.size() * 4];
std::string value;
for (size_t i = 0; i < values.size(); i++) {
FloatToBytes(values[i], values_ptr + 4 * i);
} }
char* raw_values_ptr = reinterpret_cast<char*>(values_ptr); };
for (size_t i = 0; i < values.size()*4; i++) {
value.push_back(raw_values_ptr[i]); std::string key = feasign + slot;
} uint8_t* values_ptr = new uint8_t[values.size() * sizeof(float)];
back_db->Set(key, value); std::string value;
//TODO: change stateless to stateful for (size_t i = 0; i < values.size(); i++) {
return true; FloatToBytes(values[i], values_ptr + sizeof(float) * i);
}
char* raw_values_ptr = reinterpret_cast<char*>(values_ptr);
for (size_t i = 0; i < values.size() * sizeof(float); i++) {
value.push_back(raw_values_ptr[i]);
}
back_db->Set(key, value);
// TODO: change stateless to stateful
return true;
} }
void ParamDict::UpdateBaseModel() { void ParamDict::UpdateBaseModel() {
auto is_number = [] (const std::string& s) auto is_number = [](const std::string& s) {
{ return !s.empty() && std::find_if(s.begin(), s.end(), [](char c) {
return !s.empty() && std::find_if(s.begin(), return !std::isdigit(c);
s.end(), [](char c) { return !std::isdigit(c); }) == s.end(); }) == s.end();
}; };
std::thread t([&] () { std::thread t([&]() {
for (FileReaderPtr file_reader: this->file_reader_lst_) { for (FileReaderPtr file_reader : this->file_reader_lst_) {
std::string line; std::string line;
std::ifstream infile(file_reader->GetFileName()); std::ifstream infile(file_reader->GetFileName());
if (infile.is_open()) { if (infile.is_open()) {
while (getline(infile, line)) { while (getline(infile, line)) {
std::pair<Key, Value> kvpair = read_func_(line); std::pair<Key, Value> kvpair = read_func_(line);
std::vector<float> nums; std::vector<float> nums;
for (size_t i = 0; i < kvpair.second.size(); i++) { for (size_t i = 0; i < kvpair.second.size(); i++) {
if (is_number(kvpair.second[i])) { if (is_number(kvpair.second[i])) {
nums.push_back(std::stof(kvpair.second[i])); nums.push_back(std::stof(kvpair.second[i]));
}
}
this->InsertSparseValue(kvpair.first, "", nums);
}
} }
infile.close(); }
this->InsertSparseValue(kvpair.first, "", nums);
} }
AbsKVDBPtr temp = front_db; }
front_db = back_db; infile.close();
back_db = temp; }
}); AbsKVDBPtr temp = front_db;
t.detach(); front_db = back_db;
back_db = temp;
});
t.detach();
} }
void ParamDict::UpdateDeltaModel() { UpdateBaseModel(); }
void ParamDict::UpdateDeltaModel() { std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB() {
UpdateBaseModel(); return {front_db, back_db};
}
std::pair<AbsKVDBPtr, AbsKVDBPtr> ParamDict::GetKVDB() {
return {front_db, back_db};
} }
void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) { void ParamDict::SetKVDB(std::pair<AbsKVDBPtr, AbsKVDBPtr> kvdbs) {
this->front_db = kvdbs.first; this->front_db = kvdbs.first;
this->back_db = kvdbs.second; this->back_db = kvdbs.second;
} }
void ParamDict::CreateKVDB() { void ParamDict::CreateKVDB() {
this->front_db = std::make_shared<RocksKVDB>(); this->front_db = std::make_shared<RocksKVDB>();
this->back_db = std::make_shared<RocksKVDB>(); this->back_db = std::make_shared<RocksKVDB>();
this->front_db->CreateDB(); this->front_db->CreateDB();
this->back_db->CreateDB(); this->back_db->CreateDB();
} }
ParamDict::~ParamDict() { ParamDict::~ParamDict() {}
}
...@@ -15,39 +15,40 @@ ...@@ -15,39 +15,40 @@
#include "kvdb/paddle_rocksdb.h" #include "kvdb/paddle_rocksdb.h"
RocksDBWrapper::RocksDBWrapper(std::string db_name) { RocksDBWrapper::RocksDBWrapper(std::string db_name) {
rocksdb::Options options; rocksdb::Options options;
options.create_if_missing = true; options.create_if_missing = true;
db_name_ = db_name; db_name_ = db_name;
db_ = nullptr; db_ = nullptr;
rocksdb::Status s = rocksdb::DB::Open(options, db_name, &db_); rocksdb::Status s = rocksdb::DB::Open(options, db_name, &db_);
return; return;
} }
std::string RocksDBWrapper::Get(std::string key) { std::string RocksDBWrapper::Get(std::string key) {
rocksdb::ReadOptions options; rocksdb::ReadOptions options;
options.verify_checksums = true; options.verify_checksums = true;
std::string result; std::string result;
rocksdb::Status s = db_->Get(options, key, &result); rocksdb::Status s = db_->Get(options, key, &result);
if (s.IsNotFound()) { if (s.IsNotFound()) {
result = "NOT_FOUND"; result = "NOT_FOUND";
} }
return result; return result;
} }
bool RocksDBWrapper::Put(std::string key, std::string value) { bool RocksDBWrapper::Put(std::string key, std::string value) {
rocksdb::WriteOptions options; rocksdb::WriteOptions options;
rocksdb::Status s = db_->Put(options, key, value); rocksdb::Status s = db_->Put(options, key, value);
if (s.ok()) { if (s.ok()) {
return true; return true;
} else { } else {
return false; return false;
} }
} }
void RocksDBWrapper::SetDBName(std::string db_name) { void RocksDBWrapper::SetDBName(std::string db_name) {
this->db_name_ = db_name; this->db_name_ = db_name;
} }
std::shared_ptr<RocksDBWrapper> RocksDBWrapper::RocksDBWrapperFactory(std::string db_name) { std::shared_ptr<RocksDBWrapper> RocksDBWrapper::RocksDBWrapperFactory(
return std::make_shared<RocksDBWrapper>(db_name); std::string db_name) {
return std::make_shared<RocksDBWrapper>(db_name);
} }
...@@ -15,14 +15,14 @@ ...@@ -15,14 +15,14 @@
#include "kvdb/kvdb_impl.h" #include "kvdb/kvdb_impl.h"
void ParamDictMgr::UpdateAll() { void ParamDictMgr::UpdateAll() {
for (auto it = this->ParamDictMap.begin(); it!= this->ParamDictMap.end(); ++it) { for (auto it = this->ParamDictMap.begin(); it != this->ParamDictMap.end();
it->second->UpdateBaseModel(); ++it) {
} it->second->UpdateBaseModel();
}
} }
void ParamDictMgr::InsertParamDict(std::string key, ParamDictPtr value) { void ParamDictMgr::InsertParamDict(std::string key, ParamDictPtr value) {
this->ParamDictMap.insert(std::make_pair(key, value)); this->ParamDictMap.insert(std::make_pair(key, value));
} }
AbstractKVDB::~AbstractKVDB() {} AbstractKVDB::~AbstractKVDB() {}
...@@ -16,29 +16,22 @@ ...@@ -16,29 +16,22 @@
int RocksKVDB::db_count; int RocksKVDB::db_count;
void RocksKVDB::CreateDB() { void RocksKVDB::CreateDB() {
this->db_ = RocksDBWrapper::RocksDBWrapperFactory("RocksDB_" + std::to_string(RocksKVDB::db_count)); this->db_ = RocksDBWrapper::RocksDBWrapperFactory(
RocksKVDB::db_count ++; "RocksDB_" + std::to_string(RocksKVDB::db_count));
return; RocksKVDB::db_count++;
return;
} }
void RocksKVDB::SetDBName(std::string db_name) { void RocksKVDB::SetDBName(std::string db_name) {
this->db_->SetDBName(db_name); this->db_->SetDBName(db_name);
return; return;
} }
void RocksKVDB::Set(std::string key, std::string value) { void RocksKVDB::Set(std::string key, std::string value) {
this->db_->Put(key, value); this->db_->Put(key, value);
return; return;
} }
std::string RocksKVDB::Get(std::string key) { std::string RocksKVDB::Get(std::string key) { return this->db_->Get(key); }
return this->db_->Get(key);
}
RocksKVDB::~RocksKVDB() {
}
RocksKVDB::~RocksKVDB() {}
...@@ -12,37 +12,34 @@ ...@@ -12,37 +12,34 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "kvdb/rocksdb_impl.h"
#include "kvdb/paddle_rocksdb.h"
#include <iostream> #include <iostream>
#include "kvdb/paddle_rocksdb.h"
#include "kvdb/rocksdb_impl.h"
void test_rockskvdb() { void test_rockskvdb() {
RocksKVDB db; RocksKVDB db;
db.CreateDB(); db.CreateDB();
db.SetDBName("Sparse Matrix"); db.SetDBName("Sparse Matrix");
db.Set("1", "One"); db.Set("1", "One");
std::cout << db.Get("1") << std::endl; std::cout << db.Get("1") << std::endl;
return ; return;
} }
void test_rocksdbwrapper() { void test_rocksdbwrapper() {
std::shared_ptr<RocksDBWrapper> db = RocksDBWrapper::RocksDBWrapperFactory("TEST"); std::shared_ptr<RocksDBWrapper> db =
for (size_t i = 0; i < 1000; i++) { RocksDBWrapper::RocksDBWrapperFactory("TEST");
db->Put(std::to_string(i), std::to_string(i * 2)); for (size_t i = 0; i < 1000; i++) {
} db->Put(std::to_string(i), std::to_string(i * 2));
for (size_t i = 0; i < 1000; i++) { }
std::string res = db->Get(std::to_string(i)); for (size_t i = 0; i < 1000; i++) {
std::cout << res << " "; std::string res = db->Get(std::to_string(i));
} std::cout << res << " ";
std::cout << std::endl; }
std::cout << std::endl;
} }
#ifdef RAW_TEST #ifdef RAW_TEST
int main() { int main() {
test_rockskvdb(); test_rockskvdb();
test_rocksdbwrapper(); test_rocksdbwrapper();
} }
#endif #endif
...@@ -484,7 +484,8 @@ class PdsCodeGenerator : public CodeGenerator { ...@@ -484,7 +484,8 @@ class PdsCodeGenerator : public CodeGenerator {
"response);\n" "response);\n"
"}\n" "}\n"
"tt.stop();\n" "tt.stop();\n"
"if (ret.flags != baidu::rpc::SKIP_SUB_CHANNEL && ret.method != NULL) {\n" "if (ret.flags != baidu::rpc::SKIP_SUB_CHANNEL && ret.method != "
"NULL) {\n"
" _stub_handler->update_latency(tt.u_elapsed(), \"pack_map\");\n" " _stub_handler->update_latency(tt.u_elapsed(), \"pack_map\");\n"
"}\n" "}\n"
"return ret;\n"); "return ret;\n");
...@@ -498,7 +499,8 @@ class PdsCodeGenerator : public CodeGenerator { ...@@ -498,7 +499,8 @@ class PdsCodeGenerator : public CodeGenerator {
//////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////
printer->Print( printer->Print(
"class $name$_StubResponseMerger : public baidu::rpc::ResponseMerger {\n" "class $name$_StubResponseMerger : public baidu::rpc::ResponseMerger "
"{\n"
"private:\n" "private:\n"
" uint32_t _package_size;\n" " uint32_t _package_size;\n"
" baidu::paddle_serving::sdk_cpp::Stub* _stub_handler;\n" " baidu::paddle_serving::sdk_cpp::Stub* _stub_handler;\n"
...@@ -600,7 +602,8 @@ class PdsCodeGenerator : public CodeGenerator { ...@@ -600,7 +602,8 @@ class PdsCodeGenerator : public CodeGenerator {
"}\n"); "}\n");
"LOG(INFO) \n" "LOG(INFO) \n"
" << \"[default] Succ map, channel_index: \" << channel_index;\n"; " << \"[default] Succ map, channel_index: \" << channel_index;\n";
printer->Print("return baidu::rpc::SubCall(method, request, cur_res, 0);\n"); printer->Print(
"return baidu::rpc::SubCall(method, request, cur_res, 0);\n");
return true; return true;
} }
bool generate_paddle_serving_stub_default_merger( bool generate_paddle_serving_stub_default_merger(
......
...@@ -37,4 +37,3 @@ install(FILES ${CMAKE_CURRENT_LIST_DIR}/mempool/mempool.h ...@@ -37,4 +37,3 @@ install(FILES ${CMAKE_CURRENT_LIST_DIR}/mempool/mempool.h
DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/include/predictor/mempool) DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/include/predictor/mempool)
install(FILES ${CMAKE_CURRENT_LIST_DIR}/op/op.h install(FILES ${CMAKE_CURRENT_LIST_DIR}/op/op.h
DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/include/predictor/op) DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/include/predictor/op)
...@@ -20,11 +20,10 @@ namespace paddle_serving { ...@@ -20,11 +20,10 @@ namespace paddle_serving {
namespace predictor { namespace predictor {
struct MempoolRegion { struct MempoolRegion {
MempoolRegion(im::fugue::memory::Region *region, MempoolRegion(im::fugue::memory::Region* region, im::Mempool* mempool)
im::Mempool *mempool) : : _region(region), _mempool(mempool) {}
_region(region), _mempool(mempool){} im::fugue::memory::Region* region() { return _region; }
im::fugue::memory::Region *region() {return _region;} im::Mempool* mempool() { return _mempool; }
im::Mempool *mempool() {return _mempool;}
im::fugue::memory::Region* _region; im::fugue::memory::Region* _region;
im::Mempool* _mempool; im::Mempool* _mempool;
...@@ -54,10 +53,10 @@ int MempoolWrapper::initialize() { ...@@ -54,10 +53,10 @@ int MempoolWrapper::initialize() {
} }
int MempoolWrapper::thread_initialize() { int MempoolWrapper::thread_initialize() {
im::fugue::memory::Region *region = new im::fugue::memory::Region(); im::fugue::memory::Region* region = new im::fugue::memory::Region();
region->init(); region->init();
im::Mempool* mempool = new (std::nothrow) im::Mempool(region); im::Mempool* mempool = new (std::nothrow) im::Mempool(region);
MempoolRegion *mempool_region = new MempoolRegion(region, mempool); MempoolRegion* mempool_region = new MempoolRegion(region, mempool);
if (mempool == NULL) { if (mempool == NULL) {
LOG(ERROR) << "Failed create thread mempool"; LOG(ERROR) << "Failed create thread mempool";
return -1; return -1;
...@@ -76,7 +75,8 @@ int MempoolWrapper::thread_initialize() { ...@@ -76,7 +75,8 @@ int MempoolWrapper::thread_initialize() {
} }
int MempoolWrapper::thread_clear() { int MempoolWrapper::thread_clear() {
MempoolRegion* mempool_region = (MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key); MempoolRegion* mempool_region =
(MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key);
if (mempool_region == NULL) { if (mempool_region == NULL) {
LOG(WARNING) << "THREAD_GETSPECIFIC() returned NULL"; LOG(WARNING) << "THREAD_GETSPECIFIC() returned NULL";
return -1; return -1;
...@@ -91,7 +91,8 @@ int MempoolWrapper::thread_clear() { ...@@ -91,7 +91,8 @@ int MempoolWrapper::thread_clear() {
} }
void* MempoolWrapper::malloc(size_t size) { void* MempoolWrapper::malloc(size_t size) {
MempoolRegion* mempool_region = (MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key); MempoolRegion* mempool_region =
(MempoolRegion*)THREAD_GETSPECIFIC(_bspec_key);
if (mempool_region == NULL) { if (mempool_region == NULL) {
LOG(WARNING) << "THREAD_GETSPECIFIC() returned NULL"; LOG(WARNING) << "THREAD_GETSPECIFIC() returned NULL";
return NULL; return NULL;
......
...@@ -39,7 +39,7 @@ class MempoolWrapper { ...@@ -39,7 +39,7 @@ class MempoolWrapper {
void* malloc(size_t size); void* malloc(size_t size);
private: private:
//im::fugue::memory::Region _region; // im::fugue::memory::Region _region;
THREAD_KEY_T _bspec_key; THREAD_KEY_T _bspec_key;
}; };
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <butil/scoped_lock.h> // BAIDU_SCOPED_LOCK #include <butil/scoped_lock.h> // BAIDU_SCOPED_LOCK
#endif #endif
#include <bvar/bvar.h> // bvar #include <bvar/bvar.h> // bvar
#include <string> #include <string>
#ifdef BCLOUD #ifdef BCLOUD
......
...@@ -36,13 +36,9 @@ DynamicResource::DynamicResource() {} ...@@ -36,13 +36,9 @@ DynamicResource::DynamicResource() {}
DynamicResource::~DynamicResource() {} DynamicResource::~DynamicResource() {}
int DynamicResource::initialize() { int DynamicResource::initialize() { return 0; }
return 0;
}
std::shared_ptr<RocksDBWrapper> Resource::getDB() { std::shared_ptr<RocksDBWrapper> Resource::getDB() { return db; }
return db;
}
int DynamicResource::clear() { return 0; } int DynamicResource::clear() { return 0; }
...@@ -86,7 +82,7 @@ int Resource::initialize(const std::string& path, const std::string& file) { ...@@ -86,7 +82,7 @@ int Resource::initialize(const std::string& path, const std::string& file) {
LOG(ERROR) << "unable to create tls_bthread_key of thrd_data"; LOG(ERROR) << "unable to create tls_bthread_key of thrd_data";
return -1; return -1;
} }
//init rocksDB instance // init rocksDB instance
if (db.get() == nullptr) { if (db.get() == nullptr) {
db = RocksDBWrapper::RocksDBWrapperFactory("kvdb"); db = RocksDBWrapper::RocksDBWrapperFactory("kvdb");
} }
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include <string> #include <string>
#include "kvdb/paddle_rocksdb.h"
#include "predictor/common/inner_common.h" #include "predictor/common/inner_common.h"
#include "predictor/framework/memory.h" #include "predictor/framework/memory.h"
#include "kvdb/paddle_rocksdb.h"
namespace baidu { namespace baidu {
namespace paddle_serving { namespace paddle_serving {
...@@ -31,7 +31,6 @@ struct DynamicResource { ...@@ -31,7 +31,6 @@ struct DynamicResource {
int initialize(); int initialize();
int clear(); int clear();
}; };
class Resource { class Resource {
...@@ -65,7 +64,7 @@ class Resource { ...@@ -65,7 +64,7 @@ class Resource {
private: private:
int thread_finalize() { return 0; } int thread_finalize() { return 0; }
std::shared_ptr<RocksDBWrapper> db; std::shared_ptr<RocksDBWrapper> db;
THREAD_KEY_T _tls_bspec_key; THREAD_KEY_T _tls_bspec_key;
}; };
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#ifdef BCLOUD #ifdef BCLOUD
#include <baidu/rpc/policy/nova_pbrpc_protocol.h> // NovaServiceAdaptor #include <baidu/rpc/policy/nova_pbrpc_protocol.h> // NovaServiceAdaptor
#include <baidu/rpc/policy/nshead_mcpack_protocol.h> // NsheadMcpackAdaptor #include <baidu/rpc/policy/nshead_mcpack_protocol.h> // NsheadMcpackAdaptor
#include <baidu/rpc/policy/public_pbrpc_protocol.h> // PublicPbrpcServiceAdaptor #include <baidu/rpc/policy/public_pbrpc_protocol.h> // PublicPbrpcServiceAdaptor
#else #else
#include <brpc/policy/nova_pbrpc_protocol.h> // NovaServiceAdaptor #include <brpc/policy/nova_pbrpc_protocol.h> // NovaServiceAdaptor
#include <brpc/policy/nshead_mcpack_protocol.h> // NsheadMcpackAdaptor #include <brpc/policy/nshead_mcpack_protocol.h> // NsheadMcpackAdaptor
......
...@@ -17,16 +17,11 @@ package baidu.paddle_serving.predictor.format; ...@@ -17,16 +17,11 @@ package baidu.paddle_serving.predictor.format;
// echo kvdb formant // echo kvdb formant
message KVDBReq { message KVDBReq {
required string op = 1; required string op = 1;
required string key = 2; required string key = 2;
optional string value = 3; optional string value = 3;
}; };
message KVDBRes{ message KVDBRes { required string value = 2; };
required string value = 2;
};
// dense format // dense format
message DenseInstance { repeated float features = 1; }; message DenseInstance { repeated float features = 1; };
......
...@@ -143,7 +143,8 @@ int main(int argc, char** argv) { ...@@ -143,7 +143,8 @@ int main(int argc, char** argv) {
std::string filename(argv[0]); std::string filename(argv[0]);
filename = filename.substr(filename.find_last_of('/') + 1); filename = filename.substr(filename.find_last_of('/') + 1);
settings.log_file = strdup((std::string("./log/") + filename + ".log").c_str()); settings.log_file =
strdup((std::string("./log/") + filename + ".log").c_str());
settings.delete_old = logging::DELETE_OLD_LOG_FILE; settings.delete_old = logging::DELETE_OLD_LOG_FILE;
logging::InitLogging(settings); logging::InitLogging(settings);
......
...@@ -31,11 +31,11 @@ ...@@ -31,11 +31,11 @@
#include "baidu/rpc/channel.h" #include "baidu/rpc/channel.h"
#include "baidu/rpc/parallel_channel.h" #include "baidu/rpc/parallel_channel.h"
#include "baidu/rpc/traceprintf.h" #include "baidu/rpc/traceprintf.h"
#include "bthread.h"
#include "base/logging.h"
#include "base/comlog_sink.h" #include "base/comlog_sink.h"
#include "base/logging.h"
#include "base/object_pool.h" #include "base/object_pool.h"
#include "base/time.h" #include "base/time.h"
#include "bthread.h"
#else #else
#include "brpc/channel.h" #include "brpc/channel.h"
#include "brpc/parallel_channel.h" #include "brpc/parallel_channel.h"
......
...@@ -20,15 +20,15 @@ package baidu.paddle_serving.predictor.echo_kvdb_service; ...@@ -20,15 +20,15 @@ package baidu.paddle_serving.predictor.echo_kvdb_service;
option cc_generic_services = true; option cc_generic_services = true;
message Request { message Request {
repeated baidu.paddle_serving.predictor.format.KVDBReq reqs = 1; repeated baidu.paddle_serving.predictor.format.KVDBReq reqs = 1;
}; };
message Response { message Response {
repeated baidu.paddle_serving.predictor.format.KVDBRes ress = 1; repeated baidu.paddle_serving.predictor.format.KVDBRes ress = 1;
}; };
service EchoKVDBService { service EchoKVDBService {
rpc inference(Request) returns (Response); rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response); rpc debug(Request) returns (Response);
option (pds.options).generate_stub = true; option (pds.options).generate_stub = true;
}; };
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册