Unverified · Commit 988b5fe1 authored by Thunderbrook, committed by GitHub

[PsCore] support ssd (#33031)

* support ssd in PsCore

* remove log

* remove bz2

* default value

* code style

* parse table class

* code style

* add define
Parent: b425215a
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
SET(ROCKSDB_SOURCES_DIR ${THIRD_PARTY_PATH}/rocksdb)
SET(ROCKSDB_INSTALL_DIR ${THIRD_PARTY_PATH}/install/rocksdb)
SET(ROCKSDB_INCLUDE_DIR "${ROCKSDB_INSTALL_DIR}/include" CACHE PATH "rocksdb include directory." FORCE)
SET(ROCKSDB_LIBRARIES "${ROCKSDB_INSTALL_DIR}/lib/librocksdb.a" CACHE FILEPATH "rocksdb library." FORCE)
SET(ROCKSDB_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
INCLUDE_DIRECTORIES(${ROCKSDB_INCLUDE_DIR})
ExternalProject_Add(
extern_rocksdb
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${ROCKSDB_SOURCES_DIR}
GIT_REPOSITORY "https://github.com/facebook/rocksdb"
GIT_TAG v6.10.1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DWITH_BZ2=OFF
-DWITH_GFLAGS=OFF
-DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
# BUILD_BYPRODUCTS ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a
INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/
&& cp ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/librocksdb.a ${ROCKSDB_LIBRARIES}
&& cp -r ${ROCKSDB_SOURCES_DIR}/src/extern_rocksdb/include ${ROCKSDB_INSTALL_DIR}/
BUILD_IN_SOURCE 1
)
ADD_DEPENDENCIES(extern_rocksdb snappy)
ADD_LIBRARY(rocksdb STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET rocksdb PROPERTY IMPORTED_LOCATION ${ROCKSDB_LIBRARIES})
ADD_DEPENDENCIES(rocksdb extern_rocksdb)
LIST(APPEND external_project_dependencies rocksdb)
@@ -317,6 +317,11 @@ if (WITH_PSCORE)
    include(external/libmct)     # download, build, install libmct
    list(APPEND third_party_deps extern_libmct)

    if (WITH_HETERPS)
        include(external/rocksdb)     # download, build, install rocksdb
        list(APPEND third_party_deps extern_rocksdb)
    endif()
endif()

if(WITH_XBYAK)
......
@@ -417,8 +417,10 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
  return;
}

-void FleetWrapper::LoadModel(const std::string& path, const int mode) {
-  auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
+void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret = communicator->_worker_ptr->load(path, mode);
+  // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
  ret.wait();
  if (ret.get() != 0) {
    LOG(ERROR) << "load model from path:" << path << " failed";

@@ -429,8 +431,11 @@ void FleetWrapper::LoadModel(const std::string& path, const int mode) {
void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
                                     const std::string& path, const int mode) {
+  auto* communicator = Communicator::GetInstance();
  auto ret =
-      pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
+      communicator->_worker_ptr->load(table_id, path, std::to_string(mode));
+  // auto ret =
+  //     pserver_ptr_->_worker_ptr->load(table_id, path, std::to_string(mode));
  ret.wait();
  if (ret.get() != 0) {
    LOG(ERROR) << "load model of table id: " << table_id
......
@@ -200,7 +200,7 @@ class FleetWrapper {
  void PrintTableStat(const uint64_t table_id);
  // mode = 0, load all feature
  // mode = 1, load delta feature, which means load diff
-  void LoadModel(const std::string& path, const int mode);
+  void LoadModel(const std::string& path, const std::string& mode);
  // mode = 0, load all feature
  // mode = 1, load delta feature, which means load diff
  void LoadModelOneTable(const uint64_t table_id, const std::string& path,
......
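After this change, FleetWrapper::LoadModel takes the mode as a string while LoadModelOneTable keeps the integer form; per the comments above, mode 0 loads all features and mode 1 loads only the delta. A hedged caller sketch (not part of the diff; it assumes only the FleetWrapper::GetInstance() accessor that the pybind bindings further below also use):

```cpp
// Illustrative only: load a full model, then reload one table as a delta.
// "0" / 0 = load all features, "1" / 1 = load delta features (the diff).
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
fleet->LoadModel("/path/to/model_dir", "0");
fleet->LoadModelOneTable(/*table_id=*/1, "/path/to/model_dir", /*mode=*/1);
```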
@@ -42,17 +42,17 @@ int32_t PsLocalClient::initialize() {
::std::future<int32_t> PsLocalClient::load(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
-  // for (auto& it : _table_map) {
-  //   load(it.first, epoch, mode);
-  //}
+  for (auto& it : _table_map) {
+    load(it.first, epoch, mode);
+  }
  return done();
}

::std::future<int32_t> PsLocalClient::load(uint32_t table_id,
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
-  // auto* table_ptr = table(table_id);
-  // table_ptr->load(epoch, mode);
+  auto* table_ptr = table(table_id);
+  table_ptr->load(epoch, mode);
  return done();
}

@@ -245,7 +245,6 @@ int32_t PsLocalClient::initialize() {
::std::future<int32_t> PsLocalClient::push_sparse_raw_gradient(
    size_t table_id, const uint64_t* keys, const float** update_values,
    size_t num, void* callback) {
-  VLOG(1) << "wxx push_sparse_raw_gradient";
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
  auto* accessor = table_accessor(table_id);
  auto* table_ptr = table(table_id);
......
@@ -26,9 +26,14 @@ class PsLocalServer : public PSServer {
  PsLocalServer() {}
  virtual ~PsLocalServer() {}
  virtual uint64_t start() { return 0; }
-  virtual uint64_t start(const std::string& ip, uint32_t port) { return 0; }
+  virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
  virtual int32_t stop() { return 0; }
  virtual int32_t port() { return 0; }
  virtual int32_t configure(
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
      const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
    return 0;
  }

 private:
  virtual int32_t initialize() { return 0; }
......
@@ -70,7 +70,7 @@ class PSServer {
  virtual int32_t configure(
      const PSParameter &config, PSEnvironment &env, size_t server_rank,
-      const std::vector<framework::ProgramDesc> &server_sub_program = {}) final;
+      const std::vector<framework::ProgramDesc> &server_sub_program = {});

  // return server_ip
  virtual std::string ip() { return butil::my_ip_cstr(); }
......
@@ -9,15 +9,24 @@ set_source_files_properties(${graphDir}/graph_node.cc PROPERTIES COMPILE_FLAGS $
cc_library(graph_node SRCS ${graphDir}/graph_node.cc DEPS WeightedSampler)
set_source_files_properties(common_dense_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(barrier_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(common_graph_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)

-cc_library(common_table SRCS common_sparse_table.cc common_dense_table.cc
-           sparse_geo_table.cc barrier_table.cc common_graph_table.cc DEPS ${TABLE_DEPS}
-           ${RPC_DEPS} graph_edge graph_node device_context string_helper simple_threadpool xxhash generator)
+set(EXTERN_DEP "")
+if(WITH_HETERPS)
+    set(TABLE_SRC common_sparse_table.cc ssd_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
+    set(EXTERN_DEP rocksdb)
+else()
+    set(TABLE_SRC common_sparse_table.cc common_dense_table.cc sparse_geo_table.cc barrier_table.cc common_graph_table.cc)
+endif()
+
+cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
+           ${RPC_DEPS} graph_edge graph_node device_context string_helper
+           simple_threadpool xxhash generator ${EXTERN_DEP})

set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
@@ -25,83 +25,12 @@ class ValueBlock;
}  // namespace distributed
}  // namespace paddle

#define PSERVER_SAVE_SUFFIX ".shard"
using boost::lexical_cast;

namespace paddle {
namespace distributed {

enum SaveMode { all, base, delta };

struct Meta {
std::string param;
int shard_id;
std::vector<std::string> names;
std::vector<int> dims;
uint64_t count;
std::unordered_map<std::string, int> dims_map;
explicit Meta(const std::string& metapath) {
std::ifstream file(metapath);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
if (StartWith(line, "#")) {
continue;
}
auto pairs = paddle::string::split_string<std::string>(line, "=");
PADDLE_ENFORCE_EQ(
pairs.size(), 2,
paddle::platform::errors::InvalidArgument(
"info in %s except k=v, but got %s", metapath, line));
if (pairs[0] == "param") {
param = pairs[1];
}
if (pairs[0] == "shard_id") {
shard_id = std::stoi(pairs[1]);
}
if (pairs[0] == "row_names") {
names = paddle::string::split_string<std::string>(pairs[1], ",");
}
if (pairs[0] == "row_dims") {
auto dims_strs =
paddle::string::split_string<std::string>(pairs[1], ",");
for (auto& str : dims_strs) {
dims.push_back(std::stoi(str));
}
}
if (pairs[0] == "count") {
count = std::stoull(pairs[1]);
}
}
for (int x = 0; x < names.size(); ++x) {
dims_map[names[x]] = dims[x];
}
}
Meta(std::string param, int shard_id, std::vector<std::string> row_names,
std::vector<int> dims, uint64_t count) {
this->param = param;
this->shard_id = shard_id;
this->names = row_names;
this->dims = dims;
this->count = count;
}
std::string ToString() {
std::stringstream ss;
ss << "param=" << param << "\n";
ss << "shard_id=" << shard_id << "\n";
ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n";
ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n";
ss << "count=" << count << "\n";
return ss.str();
}
};
-void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
-                  const int64_t id, std::vector<std::vector<float>>* values) {
+void CommonSparseTable::ProcessALine(const std::vector<std::string>& columns,
+                                     const Meta& meta, const int64_t id,
+                                     std::vector<std::vector<float>>* values) {
  auto colunmn_size = columns.size();
  auto load_values =
      paddle::string::split_string<std::string>(columns[colunmn_size - 1], ",");
@@ -134,8 +63,10 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
  }
}

-void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
-                    const size_t shard_idx, const int64_t total) {
+void CommonSparseTable::SaveMetaToText(std::ostream* os,
+                                       const CommonAccessorParameter& common,
+                                       const size_t shard_idx,
+                                       const int64_t total) {
  // save meta
  std::stringstream stream;
  stream << "param=" << common.table_name() << "\n";

@@ -148,8 +79,10 @@ void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
  os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
}

-int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
-                        std::shared_ptr<::ThreadPool> pool, const int mode) {
+int64_t CommonSparseTable::SaveValueToText(std::ostream* os,
+                                           std::shared_ptr<ValueBlock> block,
+                                           std::shared_ptr<::ThreadPool> pool,
+                                           const int mode, int shard_id) {
  int64_t save_num = 0;
  for (auto& table : block->values_) {
    for (auto& value : table) {

@@ -186,9 +119,9 @@ int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
  return save_num;
}

-int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
-                     const int pserver_id, const int pserver_num,
-                     const int local_shard_num,
+int64_t CommonSparseTable::LoadFromText(
+    const std::string& valuepath, const std::string& metapath,
+    const int pserver_id, const int pserver_num, const int local_shard_num,
    std::vector<std::shared_ptr<ValueBlock>>* blocks) {
  Meta meta = Meta(metapath);

@@ -198,7 +131,7 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
  while (std::getline(file, line)) {
    auto values = paddle::string::split_string<std::string>(line, "\t");

-    auto id = lexical_cast<int64_t>(values[0]);
+    auto id = lexical_cast<uint64_t>(values[0]);
    if (id % pserver_num != pserver_id) {
      VLOG(3) << "will not load " << values[0] << " from " << valuepath

@@ -388,8 +321,9 @@ int32_t CommonSparseTable::save(const std::string& dirname,
  int64_t total_ins = 0;
  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
    // save values
-    auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id],
-                                          _shards_task_pool[shard_id], mode);
+    auto shard_save_num =
+        SaveValueToText(vs.get(), shard_values_[shard_id],
+                        _shards_task_pool[shard_id], mode, shard_id);
    total_ins += shard_save_num;
  }
  vs->close();
......
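For reference, the shard meta file that Meta(metapath) parses and SaveMetaToText writes is plain key=value text. A hedged example of its shape (the keys come from the code above; the values shown here are illustrative only):

```text
param=embedding_0.w_0
shard_id=0
row_names=Param,Moment1,Moment2
row_dims=64,64,64
count=10000
```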
@@ -32,11 +32,83 @@
#include "paddle/fluid/framework/rw_lock.h"
#include "paddle/fluid/string/string_helper.h"

#define PSERVER_SAVE_SUFFIX ".shard"
using boost::lexical_cast;

namespace paddle {
namespace distributed {

class SparseOptimizer;
enum SaveMode { all, base, delta };
struct Meta {
std::string param;
int shard_id;
std::vector<std::string> names;
std::vector<int> dims;
uint64_t count;
std::unordered_map<std::string, int> dims_map;
explicit Meta(const std::string& metapath) {
std::ifstream file(metapath);
std::string line;
int num_lines = 0;
while (std::getline(file, line)) {
if (StartWith(line, "#")) {
continue;
}
auto pairs = paddle::string::split_string<std::string>(line, "=");
PADDLE_ENFORCE_EQ(
pairs.size(), 2,
paddle::platform::errors::InvalidArgument(
"info in %s except k=v, but got %s", metapath, line));
if (pairs[0] == "param") {
param = pairs[1];
}
if (pairs[0] == "shard_id") {
shard_id = std::stoi(pairs[1]);
}
if (pairs[0] == "row_names") {
names = paddle::string::split_string<std::string>(pairs[1], ",");
}
if (pairs[0] == "row_dims") {
auto dims_strs =
paddle::string::split_string<std::string>(pairs[1], ",");
for (auto& str : dims_strs) {
dims.push_back(std::stoi(str));
}
}
if (pairs[0] == "count") {
count = std::stoull(pairs[1]);
}
}
for (int x = 0; x < names.size(); ++x) {
dims_map[names[x]] = dims[x];
}
}
Meta(std::string param, int shard_id, std::vector<std::string> row_names,
std::vector<int> dims, uint64_t count) {
this->param = param;
this->shard_id = shard_id;
this->names = row_names;
this->dims = dims;
this->count = count;
}
std::string ToString() {
std::stringstream ss;
ss << "param=" << param << "\n";
ss << "shard_id=" << shard_id << "\n";
ss << "row_names=" << paddle::string::join_strings(names, ',') << "\n";
ss << "row_dims=" << paddle::string::join_strings(dims, ',') << "\n";
ss << "count=" << count << "\n";
return ss.str();
}
};
class CommonSparseTable : public SparseTable {
 public:
  CommonSparseTable() { rwlock_.reset(new framework::RWLock); }

@@ -56,9 +128,25 @@ class CommonSparseTable : public SparseTable {
  virtual int32_t initialize_optimizer();
  virtual int32_t initialize_recorder();

-  int32_t load(const std::string& path, const std::string& param);
-  int32_t save(const std::string& path, const std::string& param);
+  virtual int32_t load(const std::string& path, const std::string& param);
+  virtual int32_t save(const std::string& path, const std::string& param);
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total);
int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool, const int mode,
int shard_id);
virtual void ProcessALine(const std::vector<std::string>& columns,
const Meta& meta, const int64_t id,
std::vector<std::vector<float>>* values);
virtual int64_t LoadFromText(
const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks);
  virtual std::pair<int64_t, int64_t> print_table_stat();
  virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);

@@ -89,7 +177,7 @@ class CommonSparseTable : public SparseTable {
  virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
                               size_t num);

- private:
+ protected:
  const int task_pool_size_ = 11;
  std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
......
@@ -83,6 +83,7 @@ inline bool probility_entry(VALUE *value, float threshold) {
class ValueBlock {
 public:
  typedef typename robin_hood::unordered_map<uint64_t, VALUE *> map_type;
  explicit ValueBlock(const std::vector<std::string> &value_names,
                      const std::vector<int> &value_dims,
                      const std::vector<int> &value_offsets,

@@ -261,6 +262,18 @@ class ValueBlock {
    value->is_entry_ = state;
  }
void erase(uint64_t feasign) {
size_t hash = _hasher(feasign);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto iter = table.find(feasign);
if (iter != table.end()) {
butil::return_object(iter->second);
iter = table.erase(iter);
}
}
  void Shrink(const int threshold) {
    for (auto &table : values_) {
      for (auto iter = table.begin(); iter != table.end();) {

@@ -289,6 +302,23 @@ class ValueBlock {
    }
  }
map_type::iterator end() {
return values_[SPARSE_SHARD_BUCKET_NUM - 1].end();
}
map_type::iterator Find(uint64_t id) {
size_t hash = _hasher(id);
size_t bucket = compute_bucket(hash);
auto &table = values_[bucket];
auto got = table.find(id);
if (got == table.end()) {
return end();
} else {
return got;
}
}
 private:
  bool Has(const uint64_t id) {
    size_t hash = _hasher(id);

@@ -304,7 +334,7 @@ class ValueBlock {
  }

 public:
-  robin_hood::unordered_map<uint64_t, VALUE *> values_[SPARSE_SHARD_BUCKET_NUM];
+  map_type values_[SPARSE_SHARD_BUCKET_NUM];
  size_t value_length_ = 0;
  std::hash<uint64_t> _hasher;
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <rocksdb/db.h>
#include <rocksdb/filter_policy.h>
#include <rocksdb/options.h>
#include <rocksdb/slice.h>
#include <rocksdb/table.h>
#include <rocksdb/write_batch.h>
#include <iostream>
#include <string>
namespace paddle {
namespace distributed {
class RocksDBHandler {
public:
RocksDBHandler() {}
~RocksDBHandler() {}
static RocksDBHandler* GetInstance() {
static RocksDBHandler handler;
return &handler;
}
int initialize(const std::string& db_path, const int colnum) {
VLOG(3) << "db path: " << db_path << " colnum: " << colnum;
rocksdb::Options options;
rocksdb::BlockBasedTableOptions bbto;
bbto.block_size = 4 * 1024;
bbto.block_cache = rocksdb::NewLRUCache(64 * 1024 * 1024);
bbto.block_cache_compressed = rocksdb::NewLRUCache(64 * 1024 * 1024);
bbto.cache_index_and_filter_blocks = false;
bbto.filter_policy.reset(rocksdb::NewBloomFilterPolicy(20, false));
bbto.whole_key_filtering = true;
options.table_factory.reset(rocksdb::NewBlockBasedTableFactory(bbto));
options.keep_log_file_num = 100;
options.max_log_file_size = 50 * 1024 * 1024; // 50MB
options.create_if_missing = true;
options.use_direct_reads = true;
options.max_background_flushes = 5;
options.max_background_compactions = 5;
options.base_background_compactions = 10;
options.write_buffer_size = 256 * 1024 * 1024; // 256MB
options.max_write_buffer_number = 8;
options.max_bytes_for_level_base =
options.max_write_buffer_number * options.write_buffer_size;
options.min_write_buffer_number_to_merge = 1;
options.target_file_size_base = 1024 * 1024 * 1024; // 1024MB
options.memtable_prefix_bloom_size_ratio = 0.02;
options.num_levels = 4;
options.max_open_files = -1;
options.compression = rocksdb::kNoCompression;
options.level0_file_num_compaction_trigger = 8;
options.level0_slowdown_writes_trigger =
1.8 * options.level0_file_num_compaction_trigger;
options.level0_stop_writes_trigger =
3.6 * options.level0_file_num_compaction_trigger;
if (!db_path.empty()) {
std::string rm_cmd = "rm -rf " + db_path;
system(rm_cmd.c_str());
}
rocksdb::Status s = rocksdb::DB::Open(options, db_path, &_db);
assert(s.ok());
_handles.resize(colnum);
for (int i = 0; i < colnum; i++) {
s = _db->CreateColumnFamily(options, "shard_" + std::to_string(i),
&_handles[i]);
assert(s.ok());
}
LOG(INFO) << "DB initialize success, colnum:" << colnum;
return 0;
}
int put(int id, const char* key, int key_len, const char* value,
int value_len) {
rocksdb::WriteOptions options;
options.disableWAL = true;
rocksdb::Status s =
_db->Put(options, _handles[id], rocksdb::Slice(key, key_len),
rocksdb::Slice(value, value_len));
assert(s.ok());
return 0;
}
int put_batch(int id, std::vector<std::pair<char*, int>>& ssd_keys,
std::vector<std::pair<char*, int>>& ssd_values, int n) {
rocksdb::WriteOptions options;
options.disableWAL = true;
rocksdb::WriteBatch batch(n * 128);
for (int i = 0; i < n; i++) {
batch.Put(_handles[id],
rocksdb::Slice(ssd_keys[i].first, ssd_keys[i].second),
rocksdb::Slice(ssd_values[i].first, ssd_values[i].second));
}
rocksdb::Status s = _db->Write(options, &batch);
assert(s.ok());
return 0;
}
int get(int id, const char* key, int key_len, std::string& value) {
rocksdb::Status s = _db->Get(rocksdb::ReadOptions(), _handles[id],
rocksdb::Slice(key, key_len), &value);
if (s.IsNotFound()) {
return 1;
}
assert(s.ok());
return 0;
}
int del_data(int id, const char* key, int key_len) {
rocksdb::WriteOptions options;
options.disableWAL = true;
rocksdb::Status s =
_db->Delete(options, _handles[id], rocksdb::Slice(key, key_len));
assert(s.ok());
return 0;
}
int flush(int id) {
rocksdb::Status s = _db->Flush(rocksdb::FlushOptions(), _handles[id]);
assert(s.ok());
return 0;
}
rocksdb::Iterator* get_iterator(int id) {
return _db->NewIterator(rocksdb::ReadOptions(), _handles[id]);
}
int get_estimate_key_num(uint64_t& num_keys) {
_db->GetAggregatedIntProperty("rocksdb.estimate-num-keys", &num_keys);
return 0;
}
private:
std::vector<rocksdb::ColumnFamilyHandle*> _handles;
rocksdb::DB* _db;
};
}  // namespace distributed
}  // namespace paddle
#endif
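A minimal usage sketch for the RocksDBHandler wrapper above (the database path, column count, and buffer contents are illustrative; the return-code convention is the wrapper's own: get() returns 0 when the key is found and 1 when it is missing, and initialize() first removes anything already at db_path):

```cpp
#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/distributed/table/depends/rocksdb_warpper.h"

// Sketch only: persist one embedding row keyed by its feasign, then read it back.
int DemoRocksDBHandler() {
  auto* db = paddle::distributed::RocksDBHandler::GetInstance();
  db->initialize("/tmp/demo_rocksdb", /*colnum=*/1);  // one column family ("shard_0")

  uint64_t feasign = 12345;
  std::vector<float> row = {0.1f, 0.2f, 0.3f};

  // put(): key is the raw 8-byte feasign, value is the raw float buffer.
  db->put(/*id=*/0, reinterpret_cast<const char*>(&feasign), sizeof(feasign),
          reinterpret_cast<const char*>(row.data()),
          static_cast<int>(row.size() * sizeof(float)));

  // get(): fills `value` with the stored bytes when the key exists.
  std::string value;
  if (db->get(/*id=*/0, reinterpret_cast<const char*>(&feasign),
              sizeof(feasign), value) == 0) {
    std::vector<float> loaded(value.size() / sizeof(float));
    std::memcpy(loaded.data(), value.data(), value.size());
  }
  db->flush(/*id=*/0);
  return 0;
}
```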
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/table/ssd_sparse_table.h"
DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
namespace paddle {
namespace distributed {
int32_t SSDSparseTable::initialize() {
_shards_task_pool.resize(task_pool_size_);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
}
sync = _config.common().sync();
VLOG(1) << "table " << _config.common().table_name() << " is sync: " << sync;
_global_lr = new float(1.0);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
size_t offset = 0;
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
value_idx_[varname] = x;
value_names_.push_back(varname);
value_dims_.push_back(dim);
value_offsets_.push_back(offset);
initializer_attrs_.push_back(common.initializers()[x]);
if (varname == "Param") {
param_dim_ = dim;
param_offset_ = offset;
}
offset += dim;
}
initialize_value();
initialize_optimizer();
initialize_recorder();
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize(FLAGS_rocksdb_path, task_pool_size_);
return 0;
}
int32_t SSDSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, shard_num, &pull_value, &pull_values]() -> int {
auto& block = shard_values_[shard_id];
std::vector<int> offsets;
pull_value.Fission(shard_id, shard_num, &offsets);
for (auto& offset : offsets) {
auto feasign = pull_value.feasigns_[offset];
auto frequencie = pull_value.frequencies_[offset];
float* embedding = nullptr;
auto iter = block->Find(feasign);
// in mem
if (iter != block->end()) {
embedding = iter->second->data_.data();
if (pull_value.is_training_) {
block->AttrUpdate(iter->second, frequencie);
}
} else {
// need create
std::string tmp_str("");
if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t),
tmp_str) > 0) {
embedding = block->Init(feasign, true, frequencie);
} else {
// in db
int data_size = tmp_str.size() / sizeof(float);
int value_size = block->value_length_;
float* db_value = (float*)const_cast<char*>(tmp_str.c_str());
VALUE* value = block->InitGet(feasign);
// copy to mem
memcpy(value->data_.data(), db_value,
value_size * sizeof(float));
embedding = value->data_.data();
// param, count, unseen_day
value->count_ = db_value[value_size];
value->unseen_days_ = db_value[value_size + 1];
value->is_entry_ = db_value[value_size + 2];
if (pull_value.is_training_) {
block->AttrUpdate(value, frequencie);
}
}
}
std::copy_n(embedding + param_offset_, param_dim_,
pull_values + param_dim_ * offset);
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t SSDSparseTable::pull_sparse_ptr(char** pull_values,
const uint64_t* keys, size_t num) {
auto shard_num = task_pool_size_;
std::vector<std::future<int>> tasks(shard_num);
std::vector<std::vector<uint64_t>> offset_bucket;
offset_bucket.resize(task_pool_size_);
for (int x = 0; x < num; ++x) {
auto y = keys[x] % task_pool_size_;
offset_bucket[y].push_back(x);
}
for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
[this, shard_id, &keys, &pull_values, &offset_bucket]() -> int {
auto& block = shard_values_[shard_id];
auto& offsets = offset_bucket[shard_id];
for (auto& offset : offsets) {
auto feasign = keys[offset];
auto iter = block->Find(feasign);
VALUE* value = nullptr;
// in mem
if (iter != block->end()) {
value = iter->second;
} else {
// need create
std::string tmp_str("");
if (_db->get(shard_id, (char*)&feasign, sizeof(uint64_t),
tmp_str) > 0) {
value = block->InitGet(feasign);
} else {
// in db
int data_size = tmp_str.size() / sizeof(float);
int value_size = block->value_length_;
float* db_value = (float*)const_cast<char*>(tmp_str.c_str());
value = block->InitGet(feasign);
// copy to mem
memcpy(value->data_.data(), db_value,
value_size * sizeof(float));
// param, count, unseen_day
value->count_ = db_value[value_size];
value->unseen_days_ = db_value[value_size + 1];
value->is_entry_ = db_value[value_size + 2];
}
}
pull_values[offset] = (char*)value;
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t SSDSparseTable::shrink(const std::string& param) { return 0; }
int32_t SSDSparseTable::update_table() {
int count = 0;
int value_size = shard_values_[0]->value_length_;
int db_size = 3 + value_size;
float tmp_value[db_size];
for (size_t i = 0; i < task_pool_size_; ++i) {
auto& block = shard_values_[i];
for (auto& table : block->values_) {
for (auto iter = table.begin(); iter != table.end();) {
VALUE* value = iter->second;
if (value->unseen_days_ >= 1) {
tmp_value[value_size] = value->count_;
tmp_value[value_size + 1] = value->unseen_days_;
tmp_value[value_size + 2] = value->is_entry_;
memcpy(tmp_value, value->data_.data(), sizeof(float) * value_size);
_db->put(i, (char*)&(iter->first), sizeof(uint64_t), (char*)tmp_value,
db_size * sizeof(float));
count++;
butil::return_object(iter->second);
iter = table.erase(iter);
} else {
++iter;
}
}
}
_db->flush(i);
}
VLOG(1) << "Table>> update count: " << count;
return 0;
}
int64_t SSDSparseTable::SaveValueToText(std::ostream* os,
std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool,
const int mode, int shard_id) {
int64_t save_num = 0;
for (auto& table : block->values_) {
for (auto& value : table) {
if (mode == SaveMode::delta && !value.second->need_save_) {
continue;
}
++save_num;
std::stringstream ss;
auto* vs = value.second->data_.data();
auto id = value.first;
ss << id << "\t" << value.second->count_ << "\t"
<< value.second->unseen_days_ << "\t" << value.second->is_entry_
<< "\t";
for (int i = 0; i < block->value_length_ - 1; i++) {
ss << std::to_string(vs[i]) << ",";
}
ss << std::to_string(vs[block->value_length_ - 1]);
ss << "\n";
os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
if (mode == SaveMode::base || mode == SaveMode::delta) {
value.second->need_save_ = false;
}
}
}
if (mode != 1) {
int value_size = block->value_length_;
auto* it = _db->get_iterator(shard_id);
for (it->SeekToFirst(); it->Valid(); it->Next()) {
float* value = (float*)const_cast<char*>(it->value().data());
std::stringstream ss;
ss << *((uint64_t*)const_cast<char*>(it->key().data())) << "\t"
<< value[value_size] << "\t" << value[value_size + 1] << "\t"
<< value[value_size + 2] << "\t";
for (int i = 0; i < block->value_length_ - 1; i++) {
ss << std::to_string(value[i]) << ",";
}
ss << std::to_string(value[block->value_length_ - 1]);
ss << "\n";
os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
}
}
return save_num;
}
int32_t SSDSparseTable::load(const std::string& path,
const std::string& param) {
rwlock_->WRLock();
VLOG(3) << "ssd sparse table load with " << path << " with meta " << param;
LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
&shard_values_);
rwlock_->UNLock();
return 0;
}
int64_t SSDSparseTable::LoadFromText(
const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks) {
Meta meta = Meta(metapath);
int num_lines = 0;
std::ifstream file(valuepath);
std::string line;
int value_size = shard_values_[0]->value_length_;
int db_size = 3 + value_size;
float tmp_value[db_size];
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
auto id = lexical_cast<uint64_t>(values[0]);
if (id % pserver_num != pserver_id) {
VLOG(3) << "will not load " << values[0] << " from " << valuepath
<< ", please check id distribution";
continue;
}
auto shard_id = id % local_shard_num;
auto block = blocks->at(shard_id);
std::vector<std::vector<float>> kvalues;
ProcessALine(values, meta, id, &kvalues);
block->Init(id, false);
VALUE* value_instant = block->GetValue(id);
if (values.size() == 5) {
value_instant->count_ = lexical_cast<int>(values[1]);
value_instant->unseen_days_ = lexical_cast<int>(values[2]);
value_instant->is_entry_ =
static_cast<bool>(lexical_cast<int>(values[3]));
}
std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
auto blas = GetBlas<float>();
for (int x = 0; x < meta.names.size(); ++x) {
blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
}
VLOG(3) << "loading: " << id
<< "unseen day: " << value_instant->unseen_days_;
if (value_instant->unseen_days_ >= 1) {
tmp_value[value_size] = value_instant->count_;
tmp_value[value_size + 1] = value_instant->unseen_days_;
tmp_value[value_size + 2] = value_instant->is_entry_;
memcpy(tmp_value, value_instant->data_.data(),
sizeof(float) * value_size);
_db->put(shard_id, (char*)&(id), sizeof(uint64_t), (char*)tmp_value,
db_size * sizeof(float));
block->erase(id);
}
}
return 0;
}
}  // namespace distributed
} // namespace paddle
#endif
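For reference, the RocksDB record written by update_table() and LoadFromText() above is a flat float buffer: value_length_ embedding/optimizer floats followed by count, unseen_days and is_entry (hence db_size = value_size + 3). A small packing sketch of that layout (the helper name is illustrative and not part of the PR):

```cpp
#include <cstring>
#include <vector>

// Sketch of the on-SSD value layout used above:
// [data_[0 .. value_size)] [count] [unseen_days] [is_entry], all stored as float.
std::vector<float> PackSsdValue(const std::vector<float>& data, float count,
                                float unseen_days, float is_entry) {
  std::vector<float> buf(data.size() + 3);
  std::memcpy(buf.data(), data.data(), data.size() * sizeof(float));
  buf[data.size()] = count;
  buf[data.size() + 1] = unseen_days;
  buf[data.size() + 2] = is_entry;
  return buf;
}
```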
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/depends/rocksdb_warpper.h"
#ifdef PADDLE_WITH_HETERPS
namespace paddle {
namespace distributed {
class SSDSparseTable : public CommonSparseTable {
public:
SSDSparseTable() {}
virtual ~SSDSparseTable() {}
virtual int32_t initialize() override;
void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
const size_t shard_idx, const int64_t total);
int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
std::shared_ptr<::ThreadPool> pool, const int mode,
int shard_id);
virtual int64_t LoadFromText(
const std::string& valuepath, const std::string& metapath,
const int pserver_id, const int pserver_num, const int local_shard_num,
std::vector<std::shared_ptr<ValueBlock>>* blocks);
virtual int32_t load(const std::string& path, const std::string& param);
// exchange data
virtual int32_t update_table();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t flush() override { return 0; }
virtual int32_t shrink(const std::string& param) override;
virtual void clear() override {}
private:
RocksDBHandler* _db;
int64_t _cache_tk_size;
};
}  // namespace distributed
} // namespace paddle
#endif
@@ -21,6 +21,9 @@
#include "paddle/fluid/distributed/table/common_graph_table.h"
#include "paddle/fluid/distributed/table/common_sparse_table.h"
#include "paddle/fluid/distributed/table/sparse_geo_table.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/table/ssd_sparse_table.h"
#endif
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "paddle/fluid/distributed/table/tensor_table.h"

@@ -29,6 +32,9 @@ namespace distributed {
REGISTER_PSCORE_CLASS(Table, GraphTable);
REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
#ifdef PADDLE_WITH_HETERPS
REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
#endif
REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable);
......
@@ -118,6 +118,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                 ") for entry attribute.")
        .SetDefault("none");

    AddAttr<std::string>("table_class",
                         "(std::string, default "
                         ") for table_class.")
        .SetDefault("none");

    AddAttr<std::vector<std::string>>(
        "table_names",
        "(string vector, the split table names that will be fetched from "
......
@@ -58,6 +58,8 @@ void BindDistFleetWrapper(py::module* m) {
      "DistFleetWrapper")
      .def(py::init([]() { return FleetWrapper::GetInstance(); }))
      .def("load_sparse", &FleetWrapper::LoadSparseOnServer)
      .def("load_model", &FleetWrapper::LoadModel)
      .def("load_one_table", &FleetWrapper::LoadModelOneTable)
      .def("init_server", &FleetWrapper::InitServer)
      .def("run_server",
           (uint64_t (FleetWrapper::*)(void)) & FleetWrapper::RunServer)
......
@@ -77,6 +77,7 @@ stop_worker = fleet.stop_worker
distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
load_model = fleet.load_model
minimize = fleet.minimize
distributed_model = fleet.distributed_model
step = fleet.step
......
@@ -540,6 +540,29 @@ class Fleet(object):
        """
        self._runtime_handle._init_server(*args, **kwargs)

    def load_model(self, path, mode):
        """
        load fleet model from path

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle.distributed.fleet as fleet
                fleet.init()

                # build net
                # fleet.distributed_optimizer(...)

                fleet.load_model("path", "mode")

        """
        self._runtime_handle.load_model(path, mode)

    @is_non_distributed_check
    @inited_runtime_handler
    def run_server(self):
......
@@ -35,6 +35,23 @@ def conv_indent(indent):
PSERVER_SAVE_SUFFIX = ".shard"


def parse_table_class(varname, o_main_program):
    from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_distributed_sparse_op
    from paddle.fluid.incubate.fleet.parameter_server.ir.public import is_sparse_op

    for op in o_main_program.global_block().ops:
        if not is_distributed_sparse_op(op) and not is_sparse_op(op):
            continue

        param_name = op.input("W")[0]

        if param_name == varname and (op.type == "lookup_table" or
                                      op.type == "lookup_table_v2"):
            if op.has_attr('table_class') and op.attr("table_class") != "none":
                return op.attr('table_class')
            else:
                return "CommonSparseTable"


class Accessor:
    def __init__(self):
        self.accessor_class = ""

@@ -723,13 +740,15 @@ class TheOnePSRuntime(RuntimeBase):
                    table.type = "PS_SPARSE_TABLE"
                    table.shard_num = 256

                    common.table_name = self.compiled_strategy.grad_name_to_param_name[
                        ctx.origin_varnames()[0]]

                    if self.compiled_strategy.is_geo_mode():
                        table.table_class = "SparseGeoTable"
                    else:
-                        table.table_class = "CommonSparseTable"
+                        table.table_class = parse_table_class(
+                            common.table_name, self.origin_main_program)

-                    common.table_name = self.compiled_strategy.grad_name_to_param_name[
-                        ctx.origin_varnames()[0]]
                else:
                    table.type = "PS_DENSE_TABLE"
                    table.table_class = "CommonDenseTable"

@@ -1044,6 +1063,9 @@ class TheOnePSRuntime(RuntimeBase):
    def _save_persistables(self, *args, **kwargs):
        self._ps_inference_save_persistables(*args, **kwargs)

    def load_model(self, path, mode):
        self._worker.load_model(path, mode)

    def _shrink(self, threshold):
        import paddle.distributed.fleet as fleet
        fleet.util.barrier()
......
@@ -967,6 +967,7 @@ def sparse_embedding(input,
                     padding_idx=None,
                     is_test=False,
                     entry=None,
                     table_class="CommonSparseTable",
                     param_attr=None,
                     dtype='float32'):
    helper = LayerHelper('sparse_embedding', **locals())

@@ -989,6 +990,10 @@ def sparse_embedding(input,
    padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
        size[0] + padding_idx)

    if table_class not in ["CommonSparseTable", "SSDSparseTable"]:
        raise ValueError(
            "table_class must be in [CommonSparseTable, SSDSparseTable]")

    entry_str = "none"

    if entry is not None:

@@ -1011,7 +1016,8 @@ def sparse_embedding(input,
            'is_distributed': True,
            'remote_prefetch': True,
            'is_test': is_test,
-            'entry': entry_str
+            'entry': entry_str,
+            'table_class': table_class
        })

    return tmp
......
@@ -365,7 +365,41 @@ def ps_gpu_pass(program):
        for name in remove_var:
            program.global_block()._remove_var(name)

    def _remove_optimizer_var(program):
        embedding_w = {}
        for idx, op in list(enumerate(program.global_block().ops)):
            if op.type == "lookup_table_grad":
                for name in op.input("W"):
                    embedding_w[name] = 1

        optimize_vars = []
        optimize_op_role_vars = []
        optimize_need_delete_vars = []

        for op in _get_optimize_ops(program):
            for name in op.input("Param"):
                if name in embedding_w:
                    optimize_op_role_vars.extend(op.attr("op_role_var"))
                    for key_name in op.input_names:
                        if key_name == "LearningRate":
                            continue
                        for var in op.input(key_name):
                            optimize_vars.append(var)

        optimize_vars = list(set(optimize_vars))
        optimize_op_role_vars = list(set(optimize_op_role_vars))

        for var in optimize_vars:
            if var not in optimize_op_role_vars:
                optimize_need_delete_vars.append(var)
        need_delete_optimize_vars = list(set(optimize_need_delete_vars))

        for name in need_delete_optimize_vars:
            if program.global_block().has_var(name):
                program.global_block()._remove_var(name)

    _add_push_box_sparse_op(program)
    _remove_optimizer_var(program)
    _remove_lookup_table_grad_op_and_var(program)
    return program
......