未验证 提交 4d0d0eca 编写于 作者: Y yaoxuefeng 提交者: GitHub

mod base (#40702)

上级 382e460b
......@@ -414,6 +414,16 @@ std::future<int32_t> BrpcPsClient::load(uint32_t table_id,
return send_cmd(table_id, PS_LOAD_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Load(const LoadSaveContext &load_context) {
  // Dispatch a load request: a non-negative table_id loads a single table,
  // while the sentinel value (< 0) loads every table.
  if (load_context.table_id >= 0) {
    return send_cmd(load_context.table_id, PS_LOAD_ONE_TABLE,
                    {load_context.epoch, load_context.mode});
  }
  return send_cmd(-1, PS_LOAD_ALL_TABLE,
                  {load_context.epoch, load_context.mode});
}
std::future<int32_t> BrpcPsClient::save(const std::string &epoch,
const std::string &mode) {
VLOG(1) << "BrpcPsClient::save path " << epoch;
......@@ -427,6 +437,19 @@ std::future<int32_t> BrpcPsClient::save(uint32_t table_id,
return send_save_cmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::Save(const LoadSaveContext &save_context) {
  // Dispatch a save request: a non-negative table_id saves a single table,
  // while the sentinel value (< 0) saves the whole model.
  if (save_context.table_id >= 0) {
    VLOG(1) << "BrpcPsClient::save one table path " << save_context.epoch
            << " table_id " << save_context.table_id;
    return send_save_cmd(save_context.table_id, PS_SAVE_ONE_TABLE,
                         {save_context.epoch, save_context.mode});
  }
  VLOG(1) << "BrpcPsClient::save path " << save_context.epoch;
  return send_save_cmd(-1, PS_SAVE_ALL_TABLE,
                       {save_context.epoch, save_context.mode});
}
std::future<int32_t> BrpcPsClient::clear() {
return send_cmd(-1, PS_CLEAR_ALL_TABLE, {});
}
......@@ -505,6 +528,44 @@ std::future<int32_t> BrpcPsClient::barrier(size_t table_id,
return send_cmd(table_id, PS_BARRIER, {std::to_string(barrier_type)});
}
std::future<int32_t> BrpcPsClient::Pull(RequestContext &pull_context) {
  // Routes a pull request by value type (dense vs sparse) and, for sparse,
  // by training mode (geo vs async). Returns the future of the issued call.
  //
  // Fix: the original fell off the end of this value-returning function on
  // every path (undefined behavior) and discarded the inner futures; each
  // branch now propagates the underlying call's future to the caller.
  if (pull_context.value_type == Dense) {  // pull dense
    Region *dense_region =
        reinterpret_cast<Region *>(pull_context.dense_values);
    return pull_dense(dense_region, pull_context.num, pull_context.table);
  } else {  // pull sparse
    uint64_t *keys = reinterpret_cast<uint64_t *>(pull_context.keys);
    float **select_values =
        reinterpret_cast<float **>(pull_context.sparse_values);
    size_t table_id = pull_context.table;
    size_t num = pull_context.num;
    bool is_training = pull_context.is_training;
    if (pull_context.training_mode == Geo) {  // for geo
      return pull_sparse_param(select_values, table_id, keys, num,
                               is_training);
    } else if (pull_context.training_mode == Async) {  // for async
      return pull_sparse(select_values, table_id, keys, num, is_training);
    }
  }
  // Unsupported training mode: fail fast with an error future instead of
  // returning garbage.
  std::promise<int32_t> promise;
  promise.set_value(-1);
  return promise.get_future();
}
std::future<int32_t> BrpcPsClient::Push(RequestContext &push_context) {
  // Routes a push request by value type (dense vs sparse) and, for sparse,
  // by training mode. Returns the future of the issued call.
  //
  // Fix: the original fell off the end of this value-returning function on
  // every path (undefined behavior) and discarded the inner futures; each
  // implemented branch now propagates the underlying call's future. The
  // unused `is_training` local was dropped.
  if (push_context.value_type == Dense) {  // push dense
    const Region *dense_region = push_context.push_context.push_dense_values;
    return push_dense(dense_region, push_context.num, push_context.table);
  } else {  // push sparse
    size_t table_id = push_context.table;
    size_t num = push_context.num;
    if (push_context.training_mode == Async) {  // for async
      const uint64_t *keys = push_context.push_context.keys;
      const float **update_values = push_context.push_context.push_values;
      return push_sparse(table_id, keys, update_values, num);
    }
    // TODO(zhaocaibei): geo-mode sparse push is not implemented yet.
  }
  // Unimplemented/unsupported mode: report failure via the future instead of
  // returning garbage.
  std::promise<int32_t> promise;
  promise.set_value(-1);
  return promise.get_future();
}
std::future<int32_t> BrpcPsClient::pull_geo_param(size_t table_id,
std::vector<float> *values,
std::vector<uint64_t> *keys,
......
......@@ -163,12 +163,17 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> Load(const LoadSaveContext &load_context) override;
std::future<int32_t> save(const std::string &epoch,
const std::string &mode) override;
std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext &save_context) override;
std::future<int32_t> clear() override;
std::future<int32_t> clear(uint32_t table_id) override;
......@@ -199,6 +204,10 @@ class BrpcPsClient : public PSClient {
const uint64_t *keys,
size_t num, bool is_training);
virtual std::future<int32_t> Pull(RequestContext &pull_context) override;
virtual std::future<int32_t> Push(RequestContext &push_context) override;
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
......
......@@ -51,7 +51,7 @@ class BrpcPsServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();
private:
virtual int32_t initialize();
......
......@@ -43,7 +43,7 @@ class GraphBrpcServer : public PSServer {
_server.Join();
return 0;
}
virtual int32_t port();
int32_t port();
std::condition_variable *export_cv() { return &cv_; }
......
......@@ -26,6 +26,7 @@
#include "paddle/fluid/distributed/ps/service/sendrecv.pb.h"
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/distributed/ps/table/table.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
......@@ -59,6 +60,41 @@ class PSClientClosure : public google::protobuf::Closure {
std::vector<std::shared_ptr<std::promise<int32_t>>> _promises;
};
// Options for a model load/save request.
// table_id < 0 is a sentinel meaning "apply to all tables".
struct LoadSaveContext {
  int table_id;
  std::string epoch;  // checkpoint path / epoch tag passed to the table
  std::string mode;   // load/save mode string interpreted by the accessor
};
enum TrainingMode { Async = 0, Sync = 1, Geo = 3 };
enum TrainingPhase { Init = 0, Train = 1, Save = 2 };
// enum ValueType {
//   Sparse = 0,
//   Dense = 1
// };
// Payload for a push request; which fields are used depends on the
// value type of the enclosing RequestContext.
struct PushContext {
  const uint64_t *keys;            // sparse: feature keys
  const float **push_values;       // sparse: per-key gradient pointers
  const Region *push_dense_values; // dense: gradient regions
};
struct RequestContext {
  int table;
  TrainingMode training_mode;   // Async(0), Sync(1) or Geo(3) — see TrainingMode
  TrainingPhase training_phase;  // Init(0), Train(1) or Save(2) — see TrainingPhase
  ValueType value_type;         // Sparse(0) or Dense(1) — see ValueType
  void *keys;
  void **sparse_values;  // for sparse values
  Region *dense_values;  // for dense values
  PushContext push_context;
  size_t num;
  bool is_training;
  void *callback;  // NOTE(review): opaque; actual callable type depends on caller — confirm
};
class PSClient {
public:
PSClient() {}
......@@ -86,6 +122,9 @@ class PSClient {
// 指定table数据load
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
// context配置load选项
virtual std::future<int32_t> Load(const LoadSaveContext &load_context) = 0;
// 全量table数据save value_accessor根据mode,可能有不同的save条件
virtual std::future<int32_t> save(const std::string &epoch,
const std::string &mode) = 0;
......@@ -93,6 +132,8 @@ class PSClient {
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
const std::string &mode) = 0;
virtual std::future<int32_t> Save(const LoadSaveContext &save_context) = 0;
// 清空table数据
virtual std::future<int32_t> clear() = 0;
virtual std::future<int32_t> clear(uint32_t table_id) = 0;
......@@ -107,6 +148,8 @@ class PSClient {
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
size_t table_id) = 0; // 保留
virtual std::future<int32_t> Push(RequestContext &push_context) = 0;
// firstly push dense param for parameter server
// this is neccessary because dense weight initialized in trainer on cold
// start
......@@ -117,6 +160,9 @@ class PSClient {
virtual std::future<int32_t> push_dense(const Region *regions,
size_t region_num,
size_t table_id) = 0;
virtual std::future<int32_t> Pull(RequestContext &pull_context) = 0;
// 使用keys进行pull请求,结果填充values
// keys和values的个数均为num个,每个value占用select_size空间
// future结束前keys和values缓冲区不能再次使用
......
......@@ -56,6 +56,19 @@ int32_t PsLocalClient::initialize() {
return done();
}
std::future<int32_t> PsLocalClient::Load(const LoadSaveContext& load_context) {
  // table_id < 0 is the "load every table" sentinel; otherwise load one.
  if (load_context.table_id >= 0) {
    auto* table_ptr = table(load_context.table_id);
    table_ptr->load(load_context.epoch, load_context.mode);
    return done();
  }
  for (auto& kv : _table_map) {
    load(kv.first, load_context.epoch, load_context.mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::save(const std::string& epoch,
const std::string& mode) {
// TODO
......@@ -74,6 +87,21 @@ int32_t PsLocalClient::initialize() {
return done();
}
::std::future<int32_t> PsLocalClient::Save(
    const LoadSaveContext& save_context) {
  // table_id < 0 is the "save every table" sentinel; otherwise save one.
  if (save_context.table_id >= 0) {
    auto* table_ptr = table(save_context.table_id);
    // Flush pending updates before persisting this table.
    table_ptr->flush();
    table_ptr->save(save_context.epoch, save_context.mode);
    return done();
  }
  for (auto& kv : _table_map) {
    save(kv.first, save_context.epoch, save_context.mode);
  }
  return done();
}
::std::future<int32_t> PsLocalClient::clear() {
// TODO
return done();
......@@ -93,6 +121,51 @@ int32_t PsLocalClient::initialize() {
return done();
}
::std::future<int32_t> PsLocalClient::Pull(RequestContext& pull_context) {
  // Routes a local pull by value type: dense pulls copy into regions,
  // sparse pulls return raw value pointers.
  //
  // Fix: the original fell off the end of this value-returning function on
  // every path (undefined behavior); each branch now propagates the inner
  // call's future.
  if (pull_context.value_type == Dense) {  // pull dense
    Region* dense_region = reinterpret_cast<Region*>(pull_context.dense_values);
    return pull_dense(dense_region, pull_context.num, pull_context.table);
  } else {  // pull sparse
    uint64_t* keys = reinterpret_cast<uint64_t*>(pull_context.keys);
    char** select_values = reinterpret_cast<char**>(pull_context.sparse_values);
    size_t table_id = pull_context.table;
    size_t num = pull_context.num;
    return pull_sparse_ptr(select_values, table_id, keys, num);
  }
}
::std::future<int32_t> PsLocalClient::Push(RequestContext& push_context) {
  // Routes a local push by value type, training phase and training mode.
  //
  // Fix: the original fell off the end of this value-returning function on
  // every path (undefined behavior); each branch now propagates the inner
  // call's future, and unimplemented branches report completion via done().
  if (push_context.value_type == Dense) {  // push dense
    if (push_context.training_phase == Init) {
      // Cold start: publish the initial dense parameters to the server.
      const Region* regions = push_context.push_context.push_dense_values;
      return push_dense_param(regions, push_context.num, push_context.table);
    }
    if (push_context.training_mode == Geo) {  // geo
      float* total_send_data =
          reinterpret_cast<float*>(push_context.dense_values);
      size_t total_send_data_size = push_context.num;
      return push_dense_raw_gradient(push_context.table, total_send_data,
                                     total_send_data_size,
                                     push_context.callback);
    }
    // async and sync
    const Region* regions = push_context.push_context.push_dense_values;
    return push_dense(regions, push_context.num, push_context.table);
  }
  // push sparse
  if (push_context.training_mode == Async) {
    const uint64_t* keys = push_context.push_context.keys;
    const float** update_values = push_context.push_context.push_values;
    return push_sparse(push_context.table, keys, update_values,
                       push_context.num);
  }
  // TODO: geo/sync sparse push is not implemented yet.
  return done();
}
::std::future<int32_t> PsLocalClient::pull_dense(Region* regions,
size_t region_num,
size_t table_id) {
......
......@@ -39,12 +39,16 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> load(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Load(
const LoadSaveContext& load_context) override;
virtual ::std::future<int32_t> save(const std::string& epoch,
const std::string& mode) override;
virtual ::std::future<int32_t> save(uint32_t table_id,
const std::string& epoch,
const std::string& mode) override;
virtual std::future<int32_t> Save(
const LoadSaveContext& save_context) override;
virtual ::std::future<int32_t> clear() override;
virtual ::std::future<int32_t> clear(uint32_t table_id) override;
......@@ -55,6 +59,10 @@ class PsLocalClient : public PSClient {
virtual ::std::future<int32_t> pull_dense(Region* regions, size_t region_num,
size_t table_id);
virtual ::std::future<int32_t> Pull(RequestContext& pull_context) override;
virtual ::std::future<int32_t> Push(RequestContext& push_context) override;
virtual ::std::future<int32_t> push_dense(const Region* regions,
size_t region_num, size_t table_id);
......
......@@ -28,7 +28,6 @@ class PsLocalServer : public PSServer {
virtual uint64_t start() { return 0; }
virtual uint64_t start(const std::string &ip, uint32_t port) { return 0; }
virtual int32_t stop() { return 0; }
virtual int32_t port() { return 0; }
virtual int32_t configure(
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {}) {
......
......@@ -67,8 +67,6 @@ int32_t PSServer::configure(
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.get_ps_servers().size();
const auto &downpour_param = _config.downpour_server_param();
......
......@@ -69,11 +69,6 @@ class PSServer {
const PSParameter &config, PSEnvironment &env, size_t server_rank,
const std::vector<framework::ProgramDesc> &server_sub_program = {});
// return server_ip
virtual std::string ip() { return butil::my_ip_cstr(); }
// return server_port
virtual int32_t port() = 0;
virtual uint64_t start(const std::string &ip, uint32_t port) = 0;
virtual int32_t stop() = 0;
......@@ -94,15 +89,6 @@ class PSServer {
return &_table_map;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int registe_pserver2pserver_msg_handler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t initialize() = 0;
......@@ -111,7 +97,6 @@ class PSServer {
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected:
std::shared_ptr<framework::Scope> scope_;
......
......@@ -45,6 +45,17 @@ struct DataConverter {
std::string deconverter;
};
// Size/dimension metadata describing a ValueAccessor's value layout,
// filled in by each accessor's GetTableInfo().
struct AccessorInfo {
  size_t dim;          // total value dimension
  size_t size;         // total value byte size
  size_t select_size;  // byte size of a pulled (selected) value
  size_t select_dim;   // dimension of a pulled (selected) value
  size_t update_size;  // byte size of a pushed (update) value
  size_t update_dim;   // dimension of a pushed (update) value
  size_t mf_size;      // NOTE(review): not populated by the GetTableInfo
                       // implementations visible here — confirm before use
  size_t fea_dim;      // feature dimension
};
class ValueAccessor {
public:
ValueAccessor() {}
......@@ -68,6 +79,8 @@ class ValueAccessor {
}
virtual int initialize() = 0;
virtual void GetTableInfo(AccessorInfo& info) = 0;
// value维度
virtual size_t dim() = 0;
// value各个维度的size
......@@ -163,6 +176,7 @@ class ValueAccessor {
TableAccessorParameter _config;
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
AccessorInfo _accessor_info;
};
REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed
......
......@@ -128,6 +128,21 @@ int32_t CommonDenseTable::set_global_lr(float* lr) {
return 0;
}
int32_t CommonDenseTable::Pull(TableContext& context) {
  // Dense tables only serve dense pulls; copy values out via pull_dense.
  CHECK(context.value_type == Dense);
  return pull_dense(context.pull_context.values, context.num);
}
int32_t CommonDenseTable::Push(TableContext& context) {
  // Applies a dense gradient push if gradient values were supplied.
  CHECK(context.value_type == Dense);
  // Fix: the guard inspected context.pull_context.values while the branch
  // reads context.push_context.values — test the field actually consumed.
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    return push_dense(values, context.num);
  }
  return 0;
}
int32_t CommonDenseTable::pull_dense(float* pull_values, size_t num) {
std::copy(values_[param_idx_].begin(), values_[param_idx_].end(),
pull_values);
......
......@@ -40,6 +40,8 @@ class CommonDenseTable : public DenseTable {
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override;
......
......@@ -454,6 +454,9 @@ class GraphTable : public SparseTable {
int32_t get_server_index_by_id(int64_t id);
Node *find_node(int64_t id);
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
virtual int32_t pull_sparse(float *values,
const PullSparseValue &pull_value) {
return 0;
......
......@@ -355,6 +355,32 @@ int32_t CommonSparseTable::pour() {
return 0;
}
int32_t CommonSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  // use_ptr selects pointer-style results; otherwise values are copied
  // into a flat float buffer described by pull_value.
  if (!context.use_ptr) {
    return pull_sparse(context.pull_context.values,
                       context.pull_context.pull_value);
  }
  return pull_sparse_ptr(context.pull_context.ptr_values,
                         context.pull_context.keys, context.num);
}
int32_t CommonSparseTable::Push(TableContext& context) {
  // Applies a sparse push, choosing between flat values and pointer-style
  // values depending on which payload was supplied.
  CHECK(context.value_type == Sparse);
  // Fix: the guard inspected context.pull_context.values while both
  // branches read context.push_context — test the field actually consumed.
  if (context.push_context.values != nullptr) {
    const float* values = context.push_context.values;
    const uint64_t* keys = context.push_context.keys;
    return push_sparse(keys, values, context.num);
  } else {
    const float** values = context.push_context.ptr_values;
    const uint64_t* keys = context.push_context.keys;
    return push_sparse(keys, values, context.num);
  }
}
int32_t CommonSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
......
......@@ -121,6 +121,9 @@ class CommonSparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
......
......@@ -119,6 +119,9 @@ class BarrierTable : public Table {
virtual void *get_shard(size_t shard_idx) { return 0; }
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
......
......@@ -38,6 +38,16 @@ int CtrCommonAccessor::initialize() {
return 0;
}
// Populates `info` with this accessor's value-layout sizes and dimensions.
// NOTE(review): info.mf_size is left unset here, so readers of that field
// see an indeterminate value — confirm whether it should be filled in.
void CtrCommonAccessor::GetTableInfo(AccessorInfo& info) {
  info.dim = dim();
  info.size = size();
  info.select_dim = select_dim();
  info.select_size = select_size();
  info.update_dim = update_dim();
  info.update_size = update_size();
  info.fea_dim = fea_dim();
}
size_t CtrCommonAccessor::dim() { return common_feature_value.dim(); }
size_t CtrCommonAccessor::dim_size(size_t dim) {
......
......@@ -126,6 +126,7 @@ class CtrCommonAccessor : public ValueAccessor {
virtual int initialize();
virtual ~CtrCommonAccessor() {}
virtual void GetTableInfo(AccessorInfo& info);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -37,6 +37,16 @@ int DownpourCtrDoubleAccessor::initialize() {
return 0;
}
// Populates `info` with this accessor's value-layout sizes and dimensions.
// NOTE(review): info.mf_size is left unset here, so readers of that field
// see an indeterminate value — confirm whether it should be filled in.
void DownpourCtrDoubleAccessor::GetTableInfo(AccessorInfo& info) {
  info.dim = dim();
  info.size = size();
  info.select_dim = select_dim();
  info.select_size = select_size();
  info.update_dim = update_dim();
  info.update_size = update_size();
  info.fea_dim = fea_dim();
}
size_t DownpourCtrDoubleAccessor::dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrDoubleFeatureValue::dim(embedx_dim);
......
......@@ -168,6 +168,7 @@ class DownpourCtrDoubleAccessor : public ValueAccessor {
DownpourCtrDoubleAccessor() {}
virtual ~DownpourCtrDoubleAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -58,7 +58,7 @@ struct PullSparseValue {
std::vector<int>* offset_shard) const {
offset_shard->reserve(numel_ / shard_num + 1);
for (int x = 0; x < numel_; ++x) {
if (feasigns_[x] % shard_num == shard_id) {
if (int(feasigns_[x] % shard_num) == shard_id) {
offset_shard->push_back(x);
}
}
......
......@@ -37,6 +37,16 @@ int DownpourCtrAccessor::initialize() {
return 0;
}
// Populates `info` with this accessor's value-layout sizes and dimensions.
// NOTE(review): info.mf_size is left unset here, so readers of that field
// see an indeterminate value — confirm whether it should be filled in.
void DownpourCtrAccessor::GetTableInfo(AccessorInfo& info) {
  info.dim = dim();
  info.size = size();
  info.select_dim = select_dim();
  info.select_size = select_size();
  info.update_dim = update_dim();
  info.update_size = update_size();
  info.fea_dim = fea_dim();
}
size_t DownpourCtrAccessor::dim() {
auto embedx_dim = _config.embedx_dim();
return DownpourCtrFeatureValue::dim(embedx_dim);
......
......@@ -160,6 +160,7 @@ class DownpourCtrAccessor : public ValueAccessor {
virtual ~DownpourCtrAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo& info);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -48,6 +48,8 @@ class MemorySparseGeoTable : public SparseTable {
virtual int32_t save(const std::string& path, const std::string& param) {
return 0;
}
virtual int32_t Pull(TableContext& context) { return 0; }
virtual int32_t Push(TableContext& context) { return 0; }
virtual int32_t flush() { return 0; }
virtual int32_t shrink(const std::string& param) { return 0; }
virtual void clear() { return; }
......
......@@ -390,6 +390,26 @@ std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
return {feasign_size, mf_size};
}
int32_t MemorySparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  // use_ptr selects pointer-style results; otherwise values are copied
  // into a flat float buffer described by pull_value.
  if (!context.use_ptr) {
    return pull_sparse(context.pull_context.values,
                       context.pull_context.pull_value);
  }
  return pull_sparse_ptr(context.pull_context.ptr_values,
                         context.pull_context.keys, context.num);
}
int32_t MemorySparseTable::Push(TableContext& context) {
  // Memory sparse tables accept only pointer-style gradient payloads here.
  CHECK(context.value_type == Sparse);
  const uint64_t* feature_keys = context.push_context.keys;
  const float** gradient_ptrs = context.push_context.ptr_values;
  return push_sparse(feature_keys, gradient_ptrs, context.num);
}
int32_t MemorySparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
CostTimer timer("pserver_sparse_select_all");
......
......@@ -48,6 +48,9 @@ class MemorySparseTable : public SparseTable {
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
......
......@@ -61,6 +61,21 @@ int32_t SSDSparseTable::initialize() {
return 0;
}
int32_t SSDSparseTable::Pull(TableContext& context) {
  CHECK(context.value_type == Sparse);
  // use_ptr selects pointer-style results; otherwise values are copied
  // into a flat float buffer described by pull_value.
  if (!context.use_ptr) {
    return pull_sparse(context.pull_context.values,
                       context.pull_context.pull_value);
  }
  return pull_sparse_ptr(context.pull_context.ptr_values,
                         context.pull_context.keys, context.num);
}
int32_t SSDSparseTable::Push(TableContext& context) { return 0; }
int32_t SSDSparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
auto shard_num = task_pool_size_;
......
......@@ -42,6 +42,9 @@ class SSDSparseTable : public CommonSparseTable {
// exchange data
virtual int32_t update_table();
virtual int32_t Pull(TableContext& context);
virtual int32_t Push(TableContext& context);
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
......
......@@ -32,6 +32,30 @@
namespace paddle {
namespace distributed {
// Discriminates sparse (keyed) vs dense (flat-buffer) table payloads.
enum ValueType { Sparse = 0, Dense = 1 };
// Inputs/outputs for a table pull. Which fields are used depends on the
// table type and on TableContext::use_ptr (flat `values` vs `ptr_values`).
struct PullContext {
  const uint64_t *keys;
  // NOTE(review): const member makes PullContext (and TableContext)
  // non-copy-assignable — confirm this is intentional.
  const PullSparseValue pull_value;
  float *values;      // copy-out buffer for value-style pulls
  char **ptr_values;  // result slots for pointer-style pulls
};
// Inputs for a table push: keys plus either flat or pointer-style values.
struct TablePushContext {
  const uint64_t *keys;
  const float *values;       // flat gradient buffer
  const float **ptr_values;  // per-key gradient pointers
};
struct TableContext {
  ValueType value_type;
  PullContext pull_context;
  TablePushContext push_context;
  size_t num;    // number of keys / elements in the request
  bool use_ptr;  // true: pointer-style values; false: flat buffers
};
class Table {
public:
Table() {}
......@@ -39,6 +63,8 @@ class Table {
virtual int32_t initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t pull_dense(float *values, size_t num) = 0;
virtual int32_t push_dense(const float *values, size_t num) = 0;
// for push global_step
......
......@@ -20,6 +20,16 @@ namespace distributed {
int CommMergeAccessor::initialize() { return 0; }
// Populates `info` with this accessor's value-layout sizes and dimensions.
// NOTE(review): info.mf_size is left unset here, so readers of that field
// see an indeterminate value — confirm whether it should be filled in.
void CommMergeAccessor::GetTableInfo(AccessorInfo &info) {
  info.dim = dim();
  info.size = size();
  info.select_dim = select_dim();
  info.select_size = select_size();
  info.update_dim = update_dim();
  info.update_size = update_size();
  info.fea_dim = fea_dim();
}
// value 维度
size_t CommMergeAccessor::dim() { return 0; }
......
......@@ -30,6 +30,7 @@ class CommMergeAccessor : public ValueAccessor {
CommMergeAccessor() {}
virtual ~CommMergeAccessor() {}
virtual int initialize();
virtual void GetTableInfo(AccessorInfo &info);
// value维度
virtual size_t dim();
// value各个维度的size
......
......@@ -48,6 +48,8 @@ class TensorTable : public Table {
TensorTable() {}
virtual ~TensorTable() {}
virtual int32_t Pull(TableContext &context) { return 0; }
virtual int32_t Push(TableContext &context) { return 0; }
int32_t pull_dense(float *values, size_t num) override { return 0; }
int32_t push_dense(const float *values, size_t num) override { return 0; }
......
......@@ -30,6 +30,32 @@ bool FleetWrapper::is_initialized_ = false;
std::shared_ptr<paddle::distributed::PSCore> FleetWrapper::pserver_ptr_ = NULL;
void FleetWrapper::Stop() { StopServer(); }
void FleetWrapper::Load(WrapperContext& context) {
  // Loads a model: sparse-on-server when meta is given, the whole model
  // when table_id is the negative sentinel, otherwise a single table.
  //
  // Fix: WrapperContext::table_id is uint32_t, so the original
  // `table_id < 0` test could never be true and the "load all" branch was
  // unreachable; cast to int so the -1 sentinel compares correctly.
  int table_id = static_cast<int>(context.table_id);
  if (table_id >= 0 && context.meta != "") {
    LoadSparseOnServer(context.path, context.meta, context.table_id);
    return;
  }
  if (table_id < 0) {  // load all
    LoadModel(context.path, context.mode);
  } else {  // load one table
    LoadModelOneTable(table_id, context.path, context.mode);
  }
  return;
}
void FleetWrapper::Save(WrapperContext& context) {
  // Saves either the whole model (sentinel table_id < 0) or one table.
  //
  // Fix: WrapperContext::table_id is uint32_t, so the original
  // `table_id < 0` test could never be true and SaveModel (save-all) was
  // unreachable; cast to int so the -1 sentinel compares correctly.
  int table_id = static_cast<int>(context.table_id);
  if (table_id < 0) {
    SaveModel(context.path, context.mode);
  } else {
    SaveModelOneTable(table_id, context.path, context.mode);
  }
  return;
}
void FleetWrapper::SetClient2ClientConfig(int request_timeout_ms,
int connect_timeout_ms,
int max_retry) {
......
......@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/distributed/ps/wrapper/ps_wrapper.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
......@@ -54,7 +55,7 @@ using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
class FleetWrapper {
class FleetWrapper : public PSWrapper {
public:
virtual ~FleetWrapper() {}
FleetWrapper() {
......@@ -68,7 +69,13 @@ class FleetWrapper {
// pserver request max retry
client2client_max_retry_ = 3;
}
virtual int32_t Initialize(InitContext& context) { return 0; }
virtual void Stop() override;
virtual void Load(WrapperContext& context) override;
virtual void Save(WrapperContext& context) override;
// set client to client communication config
void SetClient2ClientConfig(int request_timeout_ms, int connect_timeout_ms,
int max_retry);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#define PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
#endif // PADDLE_FLUID_DISTRIBUTED_PS_WRAPPER_PS_WRAPPER_H_
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/ps/service/communicator/communicator_common.h"
#include "paddle/fluid/distributed/ps/service/ps_service/service.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace framework {
class Scope;
class SelectedRows;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace distributed {
class PSCore;
using framework::LoDTensor;
using framework::Scope;
using phi::SelectedRows;
using framework::Variable;
using RpcCtxMap = std::unordered_map<std::string, CommContext>;
// Options for a wrapper-level model load/save request.
struct WrapperContext {
  // NOTE(review): declared unsigned yet callers appear to use -1 as an
  // "all tables" sentinel, which cannot compare < 0 here — confirm whether
  // this should be a signed int.
  uint32_t table_id;
  const std::string path;  // checkpoint path
  const int mode;          // load/save mode
  const std::string meta;  // sparse meta; non-empty selects sparse-on-server load
};

// Options for wrapper initialization.
struct InitContext {
  const std::vector<int> dev_ids;  // for gpu
};

// Abstract interface over a parameter-server backend: lifecycle control
// plus model load/save.
class PSWrapper {
 public:
  virtual ~PSWrapper() {}
  PSWrapper() {}

  // init server
  virtual int32_t Initialize(InitContext& context) = 0;

  virtual void Stop() = 0;

  virtual void Load(WrapperContext& context) = 0;

  virtual void Save(WrapperContext& context) = 0;
};
} // end namespace distributed
} // end namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册