未验证 提交 4d0d0eca 编写于 作者: Y yaoxuefeng 提交者: GitHub

mod base (#40702)

上级 382e460b
...@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id, ...@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode}); return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
} }
// Context-based load. A non-negative table_id sends PS_LOAD_ONE_TABLE for
// that table; a negative table_id sends PS_LOAD_ALL_TABLE with table id -1.
// In both cases epoch and mode are forwarded as the command parameters.
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
  const bool single_table = load_context.table_id >= 0;
  if (single_table) {
    return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
                    {load_context.epoch, load_context.mode});
  }
  return send_cmd(-1, PS_LOAD_ALL_TABLE,
                  {load_context.epoch, load_context.mode});
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch, std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) { const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch; VLOG(1) << "BrpcPsClient::save path " << epoch;
...@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id, ...@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
} }
// Context-based save. A non-negative table_id sends PS_SAVE_ONE_TABLE for
// that table; a negative table_id sends PS_SAVE_ALL_TABLE with table id -1.
// The target path (epoch) is logged at verbosity 1 before the command is sent.
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
  const bool single_table = save_context.table_id >= 0;
  if (single_table) {
    VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
            << " table_id " << save_context.table_id;
    return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
                         {save_context.epoch, save_context.mode});
  }
  VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
  return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
                       {save_context.epoch, save_context.mode});
}
std::future<int32_t> BrpcPsClient::clear() { std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {}); return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
} }
...@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id, ...@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)}); return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
} }
// Context-based pull entry point.
//
// Dispatches on pull_context.value_type and pull_context.training_mode:
//   - Dense           -> pull_dense(dense_values, num, table)
//   - sparse + Geo    -> pull_sparse_param(sparse_values, table, keys, num,
//                                          is_training)
//   - sparse + Async  -> pull_sparse(sparse_values, table, keys, num,
//                                    is_training)
//
// Returns the future produced by the selected request so the caller can wait
// on it. Previously every path fell off the end of this non-void function,
// which is undefined behaviour and also discarded the request future.
// A sparse pull with any other training mode (e.g. Sync) has no request path
// here and yields an already-completed future holding -1.
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
  if (pull_context.value_type == Dense) {  // pull dense
    // dense_values is already a Region*, so no cast is required.
    return pull_dense(pull_context.dense_values, pull_context.num,
                      pull_context.table);
  }

  // pull sparse: the context carries type-erased key / value buffers.
  uint64_t *keys = static_cast<uint64_t *>(pull_context.keys);
  float **select_values =
      reinterpret_cast<float **>(pull_context.sparse_values);
  size_t table_id = pull_context.table;
  size_t num = pull_context.num;
  bool is_training = pull_context.is_training;

  if (pull_context.training_mode == Geo) {  // for geo
    return pull_sparse_param(select_values, table_id, keys, num, is_training);
  }
  if (pull_context.training_mode == Async) {  // for async
    return pull_sparse(select_values, table_id, keys, num, is_training);
  }

  // Unsupported training mode: report failure through a ready future instead
  // of returning nothing.
  std::promise<int32_t> unsupported;
  unsupported.set_value(-1);
  return unsupported.get_future();
}
// Context-based push entry point.
//
// Dispatches on push_context.value_type and push_context.training_mode:
//   - Dense           -> push_dense(push_dense_values, num, table)
//   - sparse + Async  -> push_sparse(table, keys, push_values, num)
//
// Returns the future produced by the selected request. Previously every path
// fell off the end of this non-void function (undefined behaviour) and the
// request future was discarded.
// Sparse push for Geo is not implemented yet, and other modes have no path
// here; both yield an already-completed future holding -1.
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
  if (push_context.value_type == Dense) {  // push dense
    const Region *dense_region = push_context.push_context.push_dense_values;
    return push_dense(dense_region, push_context.num, push_context.table);
  }

  // push sparse
  if (push_context.training_mode == Async) {  // for async
    size_t table_id = push_context.table;
    size_t num = push_context.num;
    const uint64_t *keys = push_context.push_context.keys;
    const float **update_values = push_context.push_context.push_values;
    return push_sparse(table_id, keys, update_values, num);
  }

  // Geo: TODO(zhaocaibei). Any other mode is unsupported as well.
  std::promise<int32_t> unsupported;
  unsupported.set_value(-1);
  return unsupported.get_future();
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id, std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values, std::vector<float> *values,
std::vector<uint64_t> *keys, std::vector<uint64_t> *keys,
......
...@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient { ...@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> load(uint32_t table_id, const std::string &epoch, std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
// Context-based load; overrides the PSClient interface method.
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
std::future<int32_t> save(const std::string &epoch, std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch, std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
// Context-based save; overrides the PSClient interface method.
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override; std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override; std::future<int32_t> clear(uint32_t table_id) override;
...@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient { ...@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const uint64_t *keys, const uint64_t *keys,
size_t num, bool is_training); size_t num, bool is_training);
// Context-based pull / push entry points; both override the PSClient
// interface and choose the concrete request from the context's value_type
// and training_mode.
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> print_table_stat(uint32_t table_id); virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type); virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
......
...@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer { ...@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
virtual int32_t port(); int32_t port();
private: private:
virtual int32_t initialize(); virtual int32_t initialize();
......
...@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer { ...@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server.Join(); _server.Join();
return 0; return 0;
} }
virtual int32_t port(); int32_t port();
std::condition_variable *export_cv() { return &cv_; } std::condition_variable *export_cv() { return &cv_; }
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h" #include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
namespace paddle { namespace paddle {
...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure { ...@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises; std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
}; };
// Options for the context-based Load()/Save() client calls.
struct LoadSaveContext {
// Target table. The client implementations treat a negative value as
// "all tables".
int table_id;
// Passed through unchanged to the load/save request; the existing save
// logging prints it as the path.
std::string epoch;
// Passed through unchanged to the load/save request.
std::string mode;
};
// Training mode carried in RequestContext. Note the values are not
// contiguous: Geo is 3 and there is no enumerator with value 2.
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
// Training phase carried in RequestContext. PsLocalClient::Push uses Init to
// select push_dense_param for dense pushes.
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
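// ValueType is intentionally not declared here; the commented-out enum below
// only records its values (the active definition appears to live next to
// class Table -- confirm against the table header).
// enum ValueType {
// Sparse = 0,
// Dense = 1
// };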
// Read-only input buffers for a client-side Push() request.
struct PushContext {
// Keys for a sparse push.
const uint64_t *keys;
// Per-key update values for a sparse push (forwarded to push_sparse).
const float **push_values;
// Regions for a dense push (forwarded to push_dense / push_dense_param).
const Region *push_dense_values;
};
// Parameters for the context-based PSClient::Pull() / PSClient::Push() calls.
struct RequestContext {
// Id of the table the request targets.
int table;
TrainingMode training_mode; // Async = 0, Sync = 1, Geo = 3
TrainingPhase training_phase; // Init = 0, Train = 1, Save = 2
ValueType value_type; // Sparse = 0, Dense = 1
// Sparse pull keys; the clients cast this to uint64_t*.
void *keys;
void **sparse_values; // for sparse values
Region *dense_values; // for dense values
// Input buffers used by Push().
PushContext push_context;
// Element count forwarded to the underlying call: region count for dense
// requests, key count for sparse requests.
size_t num;
// Forwarded to the sparse pull calls.
bool is_training;
// Opaque callback forwarded to push_dense_raw_gradient on the geo dense
// push path.
void *callback;
};
class PSClient { class PSClient {
public: public:
PSClient() {} PSClient() {}
...@@ -86,6 +122,9 @@ class PSClient { ...@@ -86,6 +122,9 @@ class PSClient {
// 指定table数据load // 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// Load with options supplied through a LoadSaveContext.
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// 全量table数据save value_accessor根据mode,可能有不同的save条件 // 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch, virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
...@@ -93,6 +132,8 @@ class PSClient { ...@@ -93,6 +132,8 @@ class PSClient {
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0; const std::string &mode) = 0;
// Save with options supplied through a LoadSaveContext.
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// 清空table数据 // 清空table数据
virtual std::future<int32_t> clear() = 0; virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0; virtual std::future<int32_t> clear(uint32_t table_id) = 0;
...@@ -107,6 +148,8 @@ class PSClient { ...@@ -107,6 +148,8 @@ class PSClient {
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num, virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留 size_t table_id) = 0; // 保留
// Context-based push; implementations pick the concrete push call from the
// fields of push_context.
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// firstly push dense param for parameter server // firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold // this is neccessary because dense weight initialized in trainer on cold
// start // start
...@@ -117,6 +160,9 @@ class PSClient { ...@@ -117,6 +160,9 @@ class PSClient {
virtual std::future<int32_t> push_dense(const Region *regions, virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num, size_t region_num,
size_t table_id) = 0; size_t table_id) = 0;
// Context-based pull; implementations pick the concrete pull call from the
// fields of pull_context.
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
// 使用keys进行pull请求,结果填充values // 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间 // keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用 // future结束前keys和values缓冲区不能再次使用
......
...@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() { ...@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
// Context-based load for the local client. A non-negative table_id loads
// that table directly; a negative table_id calls load() for every entry in
// _table_map. Always returns done().
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
  if (load_context.table_id >= 0) {
    auto* target = table(load_context.table_id);
    target->load(load_context.epoch, load_context.mode);
    return done();
  }
  for (auto& entry : _table_map) {
    load(entry.first, load_context.epoch, load_context.mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch, ::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) { const std::string& mode) {
// TODO // TODO
...@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() { ...@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
// Context-based save for the local client. A non-negative table_id flushes
// and then saves that table; a negative table_id calls save() for every entry
// in _table_map. Always returns done().
::std::future<int32_t> PsLocalClient::Save(
    const LoadSaveContext& save_context) {
  if (save_context.table_id >= 0) {
    auto* target = table(save_context.table_id);
    target->flush();
    target->save(save_context.epoch, save_context.mode);
    return done();
  }
  for (auto& entry : _table_map) {
    save(entry.first, save_context.epoch, save_context.mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::clear() { ::std::future<int32_t> PsLocalClient::clear() {
// TODO // TODO
return done(); return done();
...@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() { ...@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
return done(); return done();
} }
// Context-based pull for the local client.
//
//   - Dense  -> pull_dense(dense_values, num, table)
//   - sparse -> pull_sparse_ptr(sparse_values as char**, table, keys, num)
//
// Returns the future of the delegated call. Previously no path returned a
// value from this non-void function, which is undefined behaviour.
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
  if (pull_context.value_type == Dense) {  // pull dense
    // dense_values is already a Region*, so no cast is required.
    return pull_dense(pull_context.dense_values, pull_context.num,
                      pull_context.table);
  }

  // pull sparse: the context carries type-erased key / value buffers.
  uint64_t* keys = static_cast<uint64_t*>(pull_context.keys);
  char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
  size_t table_id = pull_context.table;
  size_t num = pull_context.num;
  return pull_sparse_ptr(select_values, table_id, keys, num);
}
// Context-based push for the local client.
//
// Dense:
//   - training_phase == Init -> push_dense_param(push_dense_values, num, table)
//   - training_mode == Geo   -> push_dense_raw_gradient(table, dense_values as
//                               float*, num, callback)
//   - otherwise (async/sync) -> push_dense(push_dense_values, num, table)
// Sparse:
//   - training_mode == Async -> push_sparse(table, keys, push_values, num)
//   - otherwise              -> not implemented; returns done() like the
//                               other TODO stubs in this client
//
// Returns the future of the delegated call. Previously no path returned a
// value from this non-void function, which is undefined behaviour.
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
  if (push_context.value_type == Dense) {  // push dense
    if (push_context.training_phase == Init) {
      const Region* regions = push_context.push_context.push_dense_values;
      size_t region_num = push_context.num;
      return push_dense_param(regions, region_num, push_context.table);
    }
    if (push_context.training_mode == Geo) {  // geo
      // On this path dense_values is reinterpreted as the raw float buffer
      // and num as its length.
      float* total_send_data =
          reinterpret_cast<float*>(push_context.dense_values);
      size_t total_send_data_size = push_context.num;
      return push_dense_raw_gradient(push_context.table, total_send_data,
                                     total_send_data_size,
                                     push_context.callback);
    }
    // async and sync
    const Region* regions = push_context.push_context.push_dense_values;
    size_t region_num = push_context.num;
    return push_dense(regions, region_num, push_context.table);
  }

  // push sparse
  if (push_context.training_mode == Async) {
    const uint64_t* keys = push_context.push_context.keys;
    const float** update_values = push_context.push_context.push_values;
    size_t table_id = push_context.table;
    size_t num = push_context.num;
    return push_sparse(table_id, keys, update_values, num);
  }

  // TODO: sparse push for the remaining training modes.
  return done();
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions, ::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num, size_t region_num,
size_t table_id) { size_t table_id) {
......
...@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient { ...@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> load(uint32_t table_id, virtual ::std::future<int32_t> load(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
// Context-based load; overrides the PSClient interface method.
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch, virtual ::std::future<int32_t> save(const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id, virtual ::std::future<int32_t> save(uint32_t table_id,
const std::string& epoch, const std::string& epoch,
const std::string& mode) override; const std::string& mode) override;
// Context-based save; overrides the PSClient interface method.
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override; virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override; virtual ::std::future<int32_t> clear(uint32_t table_id) override;
...@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient { ...@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num, virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id); size_t table_id);
// Context-based pull / push entry points; both override the PSClient
// interface and delegate to the local pull_* / push_* methods.
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
virtual ::std::future<int32_t> push_dense(const Region* regions, virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id); size_t region_num, size_t table_id);
......
...@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer { ...@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual uint64_t start() { return 0; } virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; } virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; } virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure( virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) { const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
......
...@@ -67,8 +67,6 @@ int32_t PSServer::configure( ...@@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config = config.server_param(); _config = config.server_param();
_rank = server_rank; _rank = server_rank;
_environment = &env; _environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size(); size_t shard_num = env.get_ps_servers().size();
const auto &downpour_param = _config.downpour_server_param(); const auto &downpour_param = _config.downpour_server_param();
......
...@@ -69,11 +69,6 @@ class PSServer { ...@@ -69,11 +69,6 @@ class PSServer {
const PSParameter &config, PSEnvironment &env, size_t server_rank, const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}); const std::vector<framework::ProgramDesc> &server_sub_program = {});
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
virtual uint64_t start(const std::string &ip, uint32_t port) = 0; virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0; virtual int32_t stop() = 0;
...@@ -94,15 +89,6 @@ class PSServer { ...@@ -94,15 +89,6 @@ class PSServer {
return &_table_map; return &_table_map;
} }
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected: protected:
virtual int32_t initialize() = 0; virtual int32_t initialize() = 0;
...@@ -111,7 +97,6 @@ class PSServer { ...@@ -111,7 +97,6 @@ class PSServer {
ServerParameter _config; ServerParameter _config;
PSEnvironment *_environment; PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map; std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected: protected:
std::shared_ptr<framework::Scope> scope_; std::shared_ptr<framework::Scope> scope_;
......
...@@ -45,6 +45,17 @@ struct DataConverter { ...@@ -45,6 +45,17 @@ struct DataConverter {
std::string deconverter; std::string deconverter;
}; };
// Size / dimension summary of a value accessor, filled in by
// ValueAccessor::GetTableInfo().
//
// Every field is zero-initialised by default. The GetTableInfo()
// implementations do not write all fields (mf_size is left untouched), so
// without defaults a freshly declared AccessorInfo would expose indeterminate
// values. The struct remains an aggregate, so brace initialisation with
// explicit values keeps working.
struct AccessorInfo {
  size_t dim = 0;          // filled from ValueAccessor::dim()
  size_t size = 0;         // filled from ValueAccessor::size()
  size_t select_size = 0;  // filled from ValueAccessor::select_size()
  size_t select_dim = 0;   // filled from ValueAccessor::select_dim()
  size_t update_size = 0;  // filled from ValueAccessor::update_size()
  size_t update_dim = 0;   // filled from ValueAccessor::update_dim()
  size_t mf_size = 0;      // not populated by the visible GetTableInfo() code
  size_t fea_dim = 0;      // filled from ValueAccessor::fea_dim()
};
class ValueAccessor { class ValueAccessor {
public: public:
ValueAccessor() {} ValueAccessor() {}
...@@ -68,6 +79,8 @@ class ValueAccessor { ...@@ -68,6 +79,8 @@ class ValueAccessor {
} }
virtual int initialize() = 0; virtual int initialize() = 0;
// Fills `info` with this accessor's dimension / size figures.
virtual void GetTableInfo(AccessorInfo& info) = 0;
// value维度 // value维度
virtual size_t dim() = 0; virtual size_t dim() = 0;
// value各个维度的size // value各个维度的size
...@@ -163,6 +176,7 @@ class ValueAccessor { ...@@ -163,6 +176,7 @@ class ValueAccessor {
TableAccessorParameter _config; TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>> std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map; _data_coverter_map;
// Accessor size / dimension record held by the accessor; not referenced by
// the GetTableInfo() implementations visible here.
AccessorInfo _accessor_info;
}; };
REGISTER_PSCORE_REGISTERER(ValueAccessor); REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed } // namespace distributed
......
...@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) { ...@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
return 0; return 0;
} }
// Context-based dense pull: requires a Dense context and forwards the output
// buffer and element count to pull_dense().
int32_t CommonDenseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Dense);
  return pull_dense(context.pull_context.values, context.num);
}
// Context-based dense push: requires a Dense context and forwards
// push_context.values to push_dense(). Returns 0 without pushing when no
// push buffer was supplied.
//
// The guard checks push_context.values, the buffer that is actually
// forwarded. It previously checked pull_context.values, so a push could be
// silently skipped, or push_dense() could be called with a null buffer,
// depending on an unrelated pull field.
int32_t CommonDenseTable::Push(TableContext& context) {
  CHECK(context.value_type == Dense);
  const float* values = context.push_context.values;
  if (values != nullptr) {
    return push_dense(values, context.num);
  }
  return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) { int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(), std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values); pull_values);
......
...@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable { ...@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
const std::string& name); const std::string& name);
virtual int32_t initialize_value(); virtual int32_t initialize_value();
virtual int32_t initialize_optimizer(); virtual int32_t initialize_optimizer();
// Context-based pull / push; forward to pull_dense() / push_dense().
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override; int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override; int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override; int32_t push_dense(const float* values, size_t num) override;
......
...@@ -454,6 +454,9 @@ class GraphTable : public SparseTable { ...@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
int32_t get_server_index_by_id(int64_t id); int32_t get_server_index_by_id(int64_t id);
Node *find_node(int64_t id); Node *find_node(int64_t id);
// Context-based pull / push are no-ops for this table; both return 0.
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values, virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) { const PullSparseValue &pull_value) {
return 0; return 0;
......
...@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() { ...@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
return 0; return 0;
} }
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// request goes through pull_sparse_ptr() (pointer outputs, explicit keys);
// otherwise it goes through pull_sparse() (float output, PullSparseValue).
int32_t CommonSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& request = context.pull_context;
  if (!context.use_ptr) {
    return pull_sparse(request.values, request.pull_value);
  }
  return pull_sparse_ptr(request.ptr_values, request.keys, context.num);
}
// Context-based sparse push: requires a Sparse context. Uses the flat float
// buffer overload of push_sparse() when push_context.values is set, and the
// pointer-array overload (push_context.ptr_values) otherwise.
//
// The overload is selected from push_context.values. It was previously
// selected from pull_context.values, an unrelated pull-side field, which
// could route the call to the wrong buffer (possibly a null one).
int32_t CommonSparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const uint64_t* keys = context.push_context.keys;
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    return push_sparse(keys, values, context.num);
  }
  const float** values = context.push_context.ptr_values;
  return push_sparse(keys, values, context.num);
}
int32_t CommonSparseTable::pull_sparse(float* pull_values, int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
......
...@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable { ...@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; } virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
// Context-based pull / push; forward to the pull_sparse* / push_sparse
// methods.
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t initialize_value();
......
...@@ -119,6 +119,9 @@ class BarrierTable : public Table { ...@@ -119,6 +119,9 @@ class BarrierTable : public Table {
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *get_shard(size_t shard_idx) { return 0; }
// Context-based pull / push are no-ops for this table; both return 0.
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() { ...@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
return 0; return 0;
} }
// Copies this accessor's dimension / size getters into `info`.
// NOTE(review): info.mf_size is not assigned here and keeps whatever value
// the caller's struct already holds.
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); } size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) { size_t CtrCommonAccessor::dim_size(size_t dim) {
......
...@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor { ...@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
virtual int initialize(); virtual int initialize();
virtual ~CtrCommonAccessor() {} virtual ~CtrCommonAccessor() {}
// Fills `info` from this accessor's dimension / size getters.
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() { ...@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
return 0; return 0;
} }
// Copies this accessor's dimension / size getters into `info`.
// NOTE(review): info.mf_size is not assigned here and keeps whatever value
// the caller's struct already holds.
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrDoubleAccessor::dim() { size_t DownpourCtrDoubleAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::dim(embedx_dim); return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
...@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor { ...@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor() {} DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {} virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize(); virtual int initialize();
// Fills `info` from this accessor's dimension / size getters.
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -58,7 +58,7 @@ struct PullSparseValue { ...@@ -58,7 +58,7 @@ struct PullSparseValue {
std::vector<int>* offset_shard) const { std::vector<int>* offset_shard) const {
offset_shard->reserve(numel_ / shard_num + 1); offset_shard->reserve(numel_ / shard_num + 1);
for (int x = 0; x < numel_; ++x) { for (int x = 0; x < numel_; ++x) {
if (feasigns_[x] % shard_num == shard_id) { if (int(feasigns_[x] % shard_num) == shard_id) {
offset_shard->push_back(x); offset_shard->push_back(x);
} }
} }
......
...@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() { ...@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
return 0; return 0;
} }
// Copies this accessor's dimension / size getters into `info`.
// NOTE(review): info.mf_size is not assigned here and keeps whatever value
// the caller's struct already holds.
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
info.dim = dim();
info.size = size();
info.select_dim = select_dim();
info.select_size = select_size();
info.update_dim = update_dim();
info.update_size = update_size();
info.fea_dim = fea_dim();
}
size_t DownpourCtrAccessor::dim() { size_t DownpourCtrAccessor::dim() {
auto embedx_dim = _config.embedx_dim(); auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim); return DownpourCtrFeatureValue::dim(embedx_dim);
......
...@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor { ...@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual ~DownpourCtrAccessor() {} virtual ~DownpourCtrAccessor() {}
virtual int initialize(); virtual int initialize();
// Fills `info` from this accessor's dimension / size getters.
virtual void GetTableInfo(AccessorInfo& info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable { ...@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
virtual int32_t save(const std::string& path, const std::string& param) { virtual int32_t save(const std::string& path, const std::string& param) {
return 0; return 0;
} }
// Context-based pull / push are no-ops for this table; both return 0.
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; } virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; } virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; } virtual void clear() { return; }
......
...@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() { ...@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
return {feasign_size, mf_size}; return {feasign_size, mf_size};
} }
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// request goes through pull_sparse_ptr() (pointer outputs, explicit keys);
// otherwise it goes through pull_sparse() (float output, PullSparseValue).
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& request = context.pull_context;
  if (!context.use_ptr) {
    return pull_sparse(request.values, request.pull_value);
  }
  return pull_sparse_ptr(request.ptr_values, request.keys, context.num);
}
// Context-based sparse push: requires a Sparse context and always forwards
// the pointer-array buffer (push_context.ptr_values) to push_sparse().
int32_t MemorySparseTable::Push(TableContext& context) {
  CHECK(context.value_type == Sparse);
  const auto& request = context.push_context;
  return push_sparse(request.keys, request.ptr_values, context.num);
}
int32_t MemorySparseTable::pull_sparse(float* pull_values, int32_t MemorySparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
CostTimer timer("pserver_sparse_select_all"); CostTimer timer("pserver_sparse_select_all");
......
...@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable { ...@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; } virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end // unused method end
// Context-based pull / push; forward to the pull_sparse* / push_sparse
// methods.
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize(); virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; } virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value(); virtual int32_t initialize_value();
......
...@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() { ...@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
return 0; return 0;
} }
// Context-based sparse pull: requires a Sparse context. With use_ptr set the
// request goes through pull_sparse_ptr() (pointer outputs, explicit keys);
// otherwise it goes through pull_sparse() (float output, PullSparseValue).
int32_t SSDSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  auto& request = context.pull_context;
  if (!context.use_ptr) {
    return pull_sparse(request.values, request.pull_value);
  }
  return pull_sparse_ptr(request.ptr_values, request.keys, context.num);
}
// Context-based push is currently a no-op for this table: the context is
// ignored and 0 is returned.
int32_t SSDSparseTable::Push(TableContext& context) {
  return 0;
}
int32_t SSDSparseTable::pull_sparse(float* pull_values, int32_t SSDSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) { const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_; auto shard_num = task_pool_size_;
......
...@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable { ...@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
// exchange data // exchange data
virtual int32_t update_table(); virtual int32_t update_table();
// Context-based pull/push entry points; both take all of their arguments
// through a single TableContext and are defined out of line.
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value); virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys, virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
......
...@@ -32,6 +32,30 @@ ...@@ -32,6 +32,30 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
// Kind of values a TableContext request operates on.
enum ValueType {
  Sparse = 0,
  Dense = 1,
};
// Arguments of a context-based pull request. Which members are read depends
// on TableContext::use_ptr (see the sparse tables' Pull implementations).
struct PullContext {
  // Keys handed to pull_sparse_ptr() on the pointer path.
  const uint64_t *keys;
  // Handed to pull_sparse() on the non-pointer path.
  const PullSparseValue pull_value;
  // Buffer handed to pull_sparse() on the non-pointer path.
  float *values;
  // Array handed to pull_sparse_ptr() on the pointer path.
  char **ptr_values;
};
// Arguments of a context-based push request.
struct TablePushContext {
  // Keys forwarded to push_sparse().
  const uint64_t *keys;
  // NOTE(review): not read by any Push implementation visible in this diff;
  // presumably the flat-buffer counterpart of ptr_values — confirm.
  const float *values;
  // Per-key value pointers forwarded to push_sparse().
  const float **ptr_values;
};
// Bundles every argument of a context-based Table::Pull / Table::Push call.
struct TableContext {
  // Sparse table implementations CHECK that this equals Sparse.
  ValueType value_type;
  // Read by Pull.
  PullContext pull_context;
  // Read by Push.
  TablePushContext push_context;
  // Forwarded as the count argument of pull_sparse_ptr() / push_sparse().
  size_t num;
  // Pull only: true selects pull_sparse_ptr(), false selects pull_sparse().
  bool use_ptr;
};
class Table { class Table {
public: public:
Table() {} Table() {}
...@@ -39,6 +63,8 @@ class Table { ...@@ -39,6 +63,8 @@ class Table {
virtual int32_t initialize(const TableParameter &config, virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
// Context-based pull/push entry points. Pure virtual: every concrete table
// must provide both.
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t pull_dense(float *values, size_t num) = 0; virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0; virtual int32_t push_dense(const float *values, size_t num) = 0;
// for push global_step // for push global_step
......
...@@ -20,6 +20,16 @@ namespace distributed { ...@@ -20,6 +20,16 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; } int CommMergeAccessor::initialize() { return 0; }
// Fills `info` with this accessor's dimension and size figures by querying
// the corresponding accessor methods one by one.
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) {
  // Whole value.
  info.dim = this->dim();
  info.size = this->size();
  // Select (pull) view.
  info.select_dim = this->select_dim();
  info.select_size = this->select_size();
  // Update (push) view.
  info.update_dim = this->update_dim();
  info.update_size = this->update_size();
  // Feature dimension.
  info.fea_dim = this->fea_dim();
}
// value 维度 // value 维度
size_t CommMergeAccessor::dim() { return 0; } size_t CommMergeAccessor::dim() { return 0; }
......
...@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor { ...@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {} CommMergeAccessor() {}
virtual ~CommMergeAccessor() {} virtual ~CommMergeAccessor() {}
virtual int initialize(); virtual int initialize();
// Writes this accessor's dim/size values into `info`.
virtual void GetTableInfo(AccessorInfo &info);
// value维度 // value维度
virtual size_t dim(); virtual size_t dim();
// value各个维度的size // value各个维度的size
......
...@@ -48,6 +48,8 @@ class TensorTable : public Table { ...@@ -48,6 +48,8 @@ class TensorTable : public Table {
TensorTable() {} TensorTable() {}
virtual ~TensorTable() {} virtual ~TensorTable() {}
// Context-based pull is a no-op for tensor tables; always returns 0.
// Marked `override` like the neighbouring pull_dense/push_dense so that a
// signature drift from Table::Pull becomes a compile error.
int32_t Pull(TableContext &context) override { return 0; }
// Context-based push is a no-op for tensor tables; always returns 0.
// Marked `override` like the neighbouring pull_dense/push_dense so that a
// signature drift from Table::Push becomes a compile error.
int32_t Push(TableContext &context) override { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; } int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; } int32_t push_dense(const float *values, size_t num) override { return 0; }
......
...@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false; ...@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL; std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
// PSWrapper::Stop implementation: simply delegates to StopServer().
void FleetWrapper::Stop() {
  StopServer();
}
// Loads model data described by `context`:
//   - negative table id           -> LoadModel(path, mode)          (all tables)
//   - table id >= 0, meta given   -> LoadSparseOnServer(path, meta, table_id)
//   - table id >= 0, meta empty   -> LoadModelOneTable(table_id, path, mode)
void FleetWrapper::Load(WrapperContext& context) {
  // WrapperContext::table_id is unsigned, so a plain `table_id < 0` test can
  // never be true and the load-all branch would be unreachable. Interpret the
  // id as signed so that a caller-supplied -1 (stored as 0xFFFFFFFF) selects
  // "all tables".
  const auto signed_table_id = static_cast<int32_t>(context.table_id);
  if (signed_table_id < 0) {  // load all
    LoadModel(context.path, context.mode);
    return;
  }
  if (!context.meta.empty()) {
    LoadSparseOnServer(context.path, context.meta, context.table_id);
    return;
  }
  // load one table
  LoadModelOneTable(context.table_id, context.path, context.mode);
}
// Saves model data described by `context`:
//   - negative table id -> SaveModel(path, mode)                    (all tables)
//   - otherwise         -> SaveModelOneTable(table_id, path, mode)
void FleetWrapper::Save(WrapperContext& context) {
  // WrapperContext::table_id is unsigned, so a plain `table_id < 0` test can
  // never be true and the save-all branch would be unreachable. Interpret the
  // id as signed so that a caller-supplied -1 (stored as 0xFFFFFFFF) selects
  // "all tables".
  const auto signed_table_id = static_cast<int32_t>(context.table_id);
  if (signed_table_id < 0) {
    SaveModel(context.path, context.mode);
  } else {
    SaveModelOneTable(context.table_id, context.path, context.mode);
  }
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms, void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms, int connect_timeout_ms,
int max_retry) { int max_retry) {
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h" #include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h" #include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h" #include "paddle/fluid/framework/io/shell.h"
...@@ -54,7 +55,7 @@ using framework::Variable; ...@@ -54,7 +55,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>; using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper { class FleetWrapper : public PSWrapper {
public: public:
virtual ~FleetWrapper() {} virtual ~FleetWrapper() {}
FleetWrapper() { FleetWrapper() {
...@@ -68,7 +69,13 @@ class FleetWrapper { ...@@ -68,7 +69,13 @@ class FleetWrapper {
// pserver request max retry // pserver request max retry
client2client_max_retry_ = 3; client2client_max_retry_ = 3;
} }
// PSWrapper::Initialize implementation: currently does nothing with `context`
// and always returns 0. Marked `override` like the other PSWrapper overrides
// in this class so a signature mismatch is diagnosed at compile time.
virtual int32_t Initialize(InitContext& context) override { return 0; }
// PSWrapper interface overrides, defined out of line.
// Stop() delegates to StopServer(); Load()/Save() dispatch on
// context.table_id to the all-tables or single-table routines.
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
// set client to client communication config // set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms, void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry); int max_retry);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ #pragma once
#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#include <atomic>
#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_ #include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
// Forward declarations of the framework types referred to below.
// NOTE(review): the alias imported further down is phi::SelectedRows, not
// framework::SelectedRows — confirm this forward declaration is still needed.
namespace paddle {
namespace framework {
class Scope;
class SelectedRows;
class Variable;
}  // namespace framework
}  // namespace paddle
namespace paddle {
namespace distributed {
// Forward declaration only; PSCore is not defined in this header.
class PSCore;
// Pull commonly used framework/phi types into paddle::distributed.
using framework::LoDTensor;
using framework::Scope;
using phi::SelectedRows;
using framework::Variable;
// Maps a string key to its CommContext.
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
// Arguments of PSWrapper::Load / PSWrapper::Save.
struct WrapperContext {
  // Table to load/save.
  // NOTE(review): the type is unsigned, so a "negative means all tables"
  // sentinel can only arrive as a wrapped value such as
  // static_cast<uint32_t>(-1) — confirm the intended convention with callers.
  uint32_t table_id;
  // Path passed through to the load/save routines.
  const std::string path;
  // Mode passed through to the load/save routines.
  const int mode;
  // When non-empty, FleetWrapper::Load uses LoadSparseOnServer(path, meta,
  // table_id) instead of the plain model-loading routines.
  const std::string meta;
};
// Arguments of PSWrapper::Initialize.
struct InitContext {
  const std::vector<int> dev_ids;  // device ids, for GPU
};
// Abstract interface shared by parameter-server wrappers. Concrete wrappers
// must implement initialization, shutdown, and model load/save.
class PSWrapper {
 public:
  PSWrapper() = default;
  virtual ~PSWrapper() = default;

  // Initializes the server from `context`; returns an int32_t status whose
  // meaning is defined by the implementation.
  virtual int32_t Initialize(InitContext& context) = 0;

  // Shuts the wrapper down.
  virtual void Stop() = 0;

  // Loads the data described by `context`.
  virtual void Load(WrapperContext& context) = 0;

  // Saves the data described by `context`.
  virtual void Save(WrapperContext& context) = 0;
};
} // end namespace distributed
} // end namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册