Unverified commit 19cb0d18, authored by zhaocaibei123, committed by GitHub

Table refine: Pull/Push(TableContext) (#41320)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix

* remove some interface

* fix

* remove

* code style

* recover

* fix

* remove code unused

* fix

* recover

* fix
Co-authored-by: esythan <esythan@126.com>
Parent: 1071bafc
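
In short, this change collapses the per-type Table virtuals (PullDense/PushDense, PullSparse/PushSparse, PushDenseParam/PushSparseParam, PullGeoParam, the global-step PushDense overload) into the two unified entry points Table::Pull(TableContext&) and Table::Push(TableContext&). A minimal sketch of the new calling convention, assembled from the call sites in the diff below; the wrapper function and buffer names are illustrative, not part of the commit:

    #include <vector>
    // Assumes the table.h from this commit (paddle::distributed::Table,
    // TableContext, and the ValueType enum) is on the include path.

    void DensePullPush(paddle::distributed::Table* table,
                       std::vector<float>* params,         // filled by Pull
                       const std::vector<float>& grads) {  // consumed by Push
      paddle::distributed::TableContext pull_ctx;
      pull_ctx.value_type = paddle::distributed::Dense;
      pull_ctx.pull_context.values = params->data();
      pull_ctx.num = params->size();
      table->Pull(pull_ctx);  // replaces table->PullDense(values, num)

      paddle::distributed::TableContext push_ctx;
      push_ctx.value_type = paddle::distributed::Dense;
      push_ctx.push_context.values = grads.data();  // is_param defaults to false
      push_ctx.num = grads.size();
      table->Push(push_ctx);  // replaces table->PushDense(values, num)
    }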
@@ -244,7 +244,14 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
   uint32_t num = *(const uint32_t *)data;
   const float *values = (const float *)(data + sizeof(uint32_t));
-  if (table->PushDenseParam(values, num) != 0) {
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.push_context.values = values;
+  table_context.push_context.is_param = true;
+  table_context.num = num;
+  // if (table->PushDenseParam(values, num) != 0) {
+  if (table->Push(table_context) != 0) {
     set_response_code(response, -1, "PushDenseParam failed");
   }
   return 0;
@@ -330,7 +337,15 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
   const uint64_t *keys = (const uint64_t *)push_data.data();
   const float *values =
       (const float *)(push_data.data() + sizeof(uint64_t) * num);
-  if (table->PushSparseParam(keys, values, num) != 0) {
+
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.push_context.keys = keys;
+  table_context.push_context.values = values;
+  table_context.push_context.is_param = true;
+  table_context.num = num;
+  // if (table->PushSparseParam(keys, values, num) != 0) {
+  if (table->Push(table_context) != 0) {
     set_response_code(response, -1, "PushSparseParam error");
   }
   return 0;
@@ -349,7 +364,14 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
   std::vector<float> values;
   std::vector<uint64_t> ids;
-  table->PullGeoParam(trainer_id, &values, &ids);
+
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.pull_context.geo_pull_keys = &ids;
+  table_context.pull_context.geo_pull_values = &values;
+  table_context.trainer_id = trainer_id;
+  table->Pull(table_context);
+  // table->PullGeoParam(trainer_id, &values, &ids);
 
   uint32_t num = ids.size();
   cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
@@ -625,7 +647,13 @@ int32_t BrpcPsService::PushGlobalStep(Table *table,
   const int64_t *values =
       (const int64_t *)(request.data().data() + sizeof(uint32_t));
   auto trainer_id = request.client_id();
-  if (table->PushDense(values, trainer_id) != 0) {
+
+  TableContext context;
+  context.trainer_id = trainer_id;
+  context.push_context.push_steps = values;
+  // if (table->PushDense(values, trainer_id) != 0) {
+  if (table->Push(context) != 0) {
     set_response_code(response, -1, "run_program failed");
   }
......
@@ -104,7 +104,13 @@ int32_t PsLocalClient::Initialize() {
   std::vector<float> region_buffer;
   region_buffer.resize(num_per_shard);
-  table_ptr->PullDense(region_buffer.data(), region_buffer.size());
+
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.pull_context.values = region_buffer.data();
+  table_context.num = region_buffer.size();
+  table_ptr->Pull(table_context);
+  // table_ptr->PullDense(region_buffer.data(), region_buffer.size());
 
   size_t region_idx = 0;
   size_t region_data_idx = 0;
@@ -154,6 +160,13 @@ int32_t PsLocalClient::Initialize() {
     offset += data_num;
   }
+
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.push_context.values = region_buffer.data();
+  table_context.push_context.is_param = true;
+  table_context.num = region_buffer.size();
+  table_ptr->Push(table_context);
   // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
 
   return done();
@@ -168,7 +181,13 @@ int32_t PsLocalClient::Initialize() {
   auto* table_ptr = GetTable(table_id);
-  table_ptr->PushDense(total_send_data, total_send_data_size);
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.push_context.values = total_send_data;
+  table_context.num = total_send_data_size;
+  // table_ptr->PushDense(total_send_data, total_send_data_size);
+  table_ptr->Push(table_context);
 
   delete closure;
   return done();
 }
@@ -194,7 +213,12 @@ int32_t PsLocalClient::Initialize() {
     offset += data_num;
   }
-  table_ptr->PushDense(region_buffer.data(), region_buffer.size());
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.push_context.values = region_buffer.data();
+  table_context.num = region_buffer.size();
+  // table_ptr->PushDense(total_send_data, total_send_data_size);
+  table_ptr->Push(table_context);
 
   return done();
 }
@@ -241,7 +265,15 @@ int32_t PsLocalClient::Initialize() {
   // Split keys into per-shard requests and record the original value pointers
   auto* table_ptr = GetTable(table_id);
-  table_ptr->PullSparsePtr(select_values, keys, num);
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.pull_context.keys = keys;
+  table_context.pull_context.ptr_values = select_values;
+  table_context.use_ptr = true;
+  table_context.num = num;
+  // table_ptr->PullSparsePtr(select_values, keys, num);
+  table_ptr->Pull(table_context);
 
   return done();
 }
@@ -253,7 +285,15 @@ int32_t PsLocalClient::Initialize() {
   auto* accessor = GetTableAccessor(table_id);
   auto* table_ptr = GetTable(table_id);
-  table_ptr->PushSparse(keys, update_values, num);
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.push_context.keys = keys;
+  table_context.push_context.ptr_values = update_values;
+  table_context.num = num;
+  table_context.use_ptr = true;
+  // table_ptr->PushSparse(keys, update_values, num);
+  table_ptr->Push(table_context);
 
   delete closure;
   return done();
 }
@@ -265,7 +305,15 @@ int32_t PsLocalClient::Initialize() {
   auto* accessor = GetTableAccessor(table_id);
   auto* table_ptr = GetTable(table_id);
-  table_ptr->PushSparse(keys, update_values, num);
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.push_context.keys = keys;
+  table_context.push_context.ptr_values = update_values;
+  table_context.num = num;
+  table_context.use_ptr = true;
+  // table_ptr->PushSparse(keys, update_values, num);
+  table_ptr->Push(table_context);
 
   return done();
 }
 }
......
@@ -139,8 +139,11 @@ int32_t CommonDenseTable::Pull(TableContext& context) {
 int32_t CommonDenseTable::Push(TableContext& context) {
   CHECK(context.value_type == Dense);
   if (context.push_context.values != nullptr) {
-    const float* values = context.push_context.values;
-    return PushDense(values, context.num);
+    if (!context.push_context.is_param) {
+      return PushDense(context.push_context.values, context.num);
+    } else {
+      return PushDenseParam(context.push_context.values, context.num);
+    }
   }
   return 0;
 }
......
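
Note the dispatch rule the hunk above introduces: for a Dense push, push_context.is_param selects between applying a gradient (PushDense) and overwriting parameters (PushDenseParam). A caller seeding a table with initial weights therefore sets the flag, as in this hedged fragment (init_params is an illustrative caller-owned buffer, not a name from the commit):

    paddle::distributed::TableContext ctx;
    ctx.value_type = paddle::distributed::Dense;
    ctx.push_context.values = init_params.data();
    ctx.push_context.is_param = true;  // route to PushDenseParam, not PushDense
    ctx.num = init_params.size();
    table->Push(ctx);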
@@ -30,21 +30,22 @@ namespace distributed {
 class DenseOptimizer;
 
-class CommonDenseTable : public DenseTable {
+class CommonDenseTable : public Table {
  public:
   CommonDenseTable() {}
   virtual ~CommonDenseTable() {}
   int32_t Initialize() override;
   int32_t InitializeShard() override { return 0; }
-  virtual void CreateInitializer(const std::string& attr,
-                                 const std::string& name);
-  virtual int32_t InitializeValue();
-  virtual int32_t InitializeOptimizer();
-  virtual int32_t Pull(TableContext& context);
-  virtual int32_t Push(TableContext& context);
-  int32_t PullDense(float* pull_values, size_t num) override;
-  int32_t PushDenseParam(const float* values, size_t num) override;
-  int32_t PushDense(const float* values, size_t num) override;
+  void CreateInitializer(const std::string& attr, const std::string& name);
+  int32_t InitializeValue();
+  int32_t InitializeOptimizer();
+
+  int32_t Pull(TableContext& context) override;
+  int32_t Push(TableContext& context) override;
+
+  int32_t PullDense(float* pull_values, size_t num);
+  int32_t PushDenseParam(const float* values, size_t num);
+  int32_t PushDense(const float* values, size_t num);
 
   int32_t Pour() override;
   int32_t SetGlobalLR(float* lr) override;
@@ -54,6 +55,7 @@ class CommonDenseTable : public DenseTable {
   int32_t Flush() override { return 0; }
   int32_t Shrink(const std::string& param) override { return 0; }
   void Clear() override { return; }
+  void* GetShard(size_t shard_idx) override { return 0; }
 
  protected:
   int32_t _PushDense(const float* values, size_t num);
......
@@ -404,7 +404,7 @@ class GraphSampler {
 };
 #endif
 
-class GraphTable : public SparseTable {
+class GraphTable : public Table {
  public:
   GraphTable() {
     use_cache = false;
@@ -415,6 +415,23 @@ class GraphTable : public SparseTable {
     rw_lock.reset(new pthread_rwlock_t());
   }
   virtual ~GraphTable();
+
+  virtual void *GetShard(size_t shard_idx) { return 0; }
+
+  static int32_t sparse_local_shard_num(uint32_t shard_num,
+                                        uint32_t server_num) {
+    if (shard_num % server_num == 0) {
+      return shard_num / server_num;
+    }
+    size_t local_shard_num = shard_num / server_num + 1;
+    return local_shard_num;
+  }
+
+  static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
+                                 uint64_t key) {
+    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
+  }
+
   virtual int32_t pull_graph_list(int start, int size,
                                   std::unique_ptr<char[]> &buffer,
                                   int &actual_size, bool need_feature,
@@ -452,15 +469,6 @@ class GraphTable : public SparseTable {
   virtual int32_t Pull(TableContext &context) { return 0; }
   virtual int32_t Push(TableContext &context) { return 0; }
 
-  virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) {
-    return 0;
-  }
-
-  virtual int32_t PushSparse(const uint64_t *keys, const float *values,
-                             size_t num) {
-    return 0;
-  }
-
   virtual int32_t clear_nodes();
   virtual void Clear() {}
   virtual int32_t Flush() { return 0; }
......
@@ -108,15 +108,16 @@ struct Meta {
   }
 };
 
-class CommonSparseTable : public SparseTable {
+class CommonSparseTable : public Table {
  public:
   CommonSparseTable() { rwlock_.reset(new phi::RWLock); }
   virtual ~CommonSparseTable() {}
 
   // unused method begin
-  virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
-  virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
-  virtual int32_t PushDense(const float* values, size_t num) { return 0; }
+  // virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
+  // virtual int32_t PushDenseParam(const float* values, size_t num) { return
+  // 0; }
+  // virtual int32_t PushDense(const float* values, size_t num) { return 0; }
   // unused method end
 
   virtual int32_t Pull(TableContext& context);
@@ -163,14 +164,15 @@ class CommonSparseTable : public SparseTable {
   // only for sparse geo table
   virtual int32_t PushSparseParam(const uint64_t* keys, const float* values,
                                   size_t num);
-  virtual int32_t SetGlobalLR(float* lr);
+  virtual int32_t SetGlobalLR(float* lr) override;
 
   virtual int32_t Pour();
   virtual int32_t Flush();
   virtual int32_t Shrink(const std::string& param);
   virtual void Clear();
+  virtual void* GetShard(size_t shard_idx) { return 0; }
 
  protected:
   virtual int32_t _PushSparse(const uint64_t* keys, const float* values,
                               size_t num);
......
@@ -66,50 +66,6 @@ struct ReservoirValue {
   }
 };
 
-class SparseTable : public Table {
- public:
-  SparseTable() {}
-  virtual ~SparseTable() {}
-
-  virtual void *GetShard(size_t shard_idx) { return 0; }
-
-  int32_t PullDense(float *values, size_t num) override { return 0; }
-  int32_t PushDense(const float *values, size_t num) override { return 0; }
-
-  static int32_t sparse_local_shard_num(uint32_t shard_num,
-                                        uint32_t server_num) {
-    if (shard_num % server_num == 0) {
-      return shard_num / server_num;
-    }
-    size_t local_shard_num = shard_num / server_num + 1;
-    return local_shard_num;
-  }
-
-  static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
-                                 uint64_t key) {
-    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
-  }
-};
-
-class DenseTable : public Table {
- public:
-  DenseTable() {}
-  virtual ~DenseTable() {}
-
-  virtual void *GetShard(size_t shard_idx) { return 0; }
-
-  int32_t PullSparse(float *values,
-                     const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t PushSparse(const uint64_t *keys, const float *values,
-                     size_t num) override {
-    return 0;
-  }
-  int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
-  int32_t Shrink(const std::string &param) override { return 0; }
-};
-
 class BarrierTable : public Table {
  public:
   BarrierTable() {}
@@ -120,19 +76,6 @@ class BarrierTable : public Table {
   virtual int32_t Pull(TableContext &context) { return 0; }
   virtual int32_t Push(TableContext &context) { return 0; }
 
-  int32_t PullDense(float *values, size_t num) override { return 0; }
-  int32_t PushDense(const float *values, size_t num) override { return 0; }
-
-  int32_t PullSparse(float *values,
-                     const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t PushSparse(const uint64_t *keys, const float *values,
-                     size_t num) override {
-    return 0;
-  }
-  int32_t PushDenseParam(const float *values, size_t num) override { return 0; }
-
   int32_t Shrink(const std::string &param) override { return 0; }
   virtual void Clear() {}
   virtual int32_t Flush() { return 0; }
......
@@ -17,6 +17,29 @@
 namespace paddle {
 namespace distributed {
 
+int32_t MemorySparseGeoTable::Pull(TableContext& context) {
+  CHECK(context.value_type == Sparse);
+  if (context.pull_context.geo_pull_keys != nullptr) {
+    return PullGeoParam(context.trainer_id,
+                        context.pull_context.geo_pull_values,
+                        context.pull_context.geo_pull_keys);
+  } else {
+    return PullSparse(context.pull_context.values,
+                      context.pull_context.pull_value);
+  }
+}
+
+int32_t MemorySparseGeoTable::Push(TableContext& context) {
+  CHECK(context.value_type == Sparse);
+  if (!context.push_context.is_param) {
+    return PushSparse(context.push_context.keys, context.push_context.values,
+                      context.num);
+  } else {
+    return PushSparseParam(context.push_context.keys,
+                           context.push_context.values, context.num);
+  }
+}
+
 int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
                                               const float* values, size_t num) {
   VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin "
@@ -117,6 +140,7 @@ int32_t MemorySparseGeoTable::Initialize() {
   return 0;
 }
 
+// hash different from MemorySparseTable
 int32_t MemorySparseGeoTable::PullSparse(float* pull_values,
                                          const PullSparseValue& pull_value) {
   auto shard_num = _task_pool_size;
......
@@ -34,40 +34,44 @@ namespace distributed {
 class GeoRecorder;
 
-class MemorySparseGeoTable : public SparseTable {
+class MemorySparseGeoTable : public Table {
  public:
   typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
   MemorySparseGeoTable() { _geo_recorder = nullptr; }
   virtual ~MemorySparseGeoTable() {}
 
-  virtual int32_t Initialize();
-  virtual int32_t InitializeShard() { return 0; }
-  virtual int32_t Load(const std::string& path, const std::string& param) {
+  int32_t Initialize() override;
+  int32_t InitializeShard() override { return 0; }
+  int32_t Load(const std::string& path, const std::string& param) override {
     return 0;
   }
-  virtual int32_t Save(const std::string& path, const std::string& param) {
+  int32_t Save(const std::string& path, const std::string& param) override {
     return 0;
   }
-  virtual int32_t Pull(TableContext& context) { return 0; }
-  virtual int32_t Push(TableContext& context) { return 0; }
-  virtual int32_t Flush() { return 0; }
-  virtual int32_t Shrink(const std::string& param) { return 0; }
-  virtual void Clear() { return; }
-  virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
+  int32_t Pull(TableContext& context) override;
+  int32_t Push(TableContext& context) override;
+  int32_t Flush() override { return 0; }
+  int32_t Shrink(const std::string& param) override { return 0; }
+  void Clear() override { return; }
+
+  int32_t PullSparse(float* values, const PullSparseValue& pull_value);
 
   int32_t PushSparseParam(const uint64_t* keys, const float* values,
                           size_t num);
+  // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse
   int32_t PullGeoParam(const uint32_t trainer_id, std::vector<float>* values,
                        std::vector<uint64_t>* keys);
 
-  int32_t PushSparse(const uint64_t* keys, const float* values,
-                     size_t num) override;
+  int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
 
   int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num);
   // int32_t _pull_sparse(float* pull_values, const PullSparseValue&
   // pull_value);
+
+  void* GetShard(size_t shard_idx) override {
+    return &_local_shards[shard_idx];
+  }
 
  private:
   std::shared_ptr<GeoRecorder> _geo_recorder;
   const int _task_pool_size = 10;
......
@@ -47,7 +47,7 @@ int32_t MemorySparseTable::Initialize() {
 int32_t MemorySparseTable::InitializeValue() {
   _sparse_table_shard_num = static_cast<int>(_config.shard_num());
   _avg_local_shard_num =
-      SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
+      sparse_local_shard_num(_sparse_table_shard_num, _shard_num);
   _real_local_shard_num = _avg_local_shard_num;
   if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) {
     _real_local_shard_num =
@@ -405,9 +405,13 @@ int32_t MemorySparseTable::Pull(TableContext& context) {
 int32_t MemorySparseTable::Push(TableContext& context) {
   CHECK(context.value_type == Sparse);
 
-  const uint64_t* keys = context.push_context.keys;
-  return PushSparse(keys, context.push_context.values, context.num);
+  if (!context.use_ptr) {
+    return PushSparse(context.push_context.keys, context.push_context.values,
+                      context.num);
+  } else {
+    return PushSparse(context.push_context.keys,
+                      context.push_context.ptr_values, context.num);
+  }
 }
 
 int32_t MemorySparseTable::PullSparse(float* pull_values,
@@ -610,12 +614,6 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
 int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                                       const float** values, size_t num) {
-  _PushSparse(keys, values, num);
-  return 0;
-}
-
-int32_t MemorySparseTable::_PushSparse(const uint64_t* keys,
-                                       const float** values, size_t num) {
   std::vector<std::future<int>> tasks(_real_local_shard_num);
   std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
       _real_local_shard_num);
......
@@ -34,28 +34,37 @@
 namespace paddle {
 namespace distributed {
 
-class MemorySparseTable : public SparseTable {
+class MemorySparseTable : public Table {
  public:
   typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
   MemorySparseTable() {}
   virtual ~MemorySparseTable() {}
 
-  // unused method begin
-  virtual int32_t PullDense(float* pull_values, size_t num) { return 0; }
-  virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; }
-  virtual int32_t PushDense(const float* values, size_t num) { return 0; }
   // unused method end
+  static int32_t sparse_local_shard_num(uint32_t shard_num,
+                                        uint32_t server_num) {
+    if (shard_num % server_num == 0) {
+      return shard_num / server_num;
+    }
+    size_t local_shard_num = shard_num / server_num + 1;
+    return local_shard_num;
+  }
 
-  virtual int32_t Pull(TableContext& context);
-  virtual int32_t Push(TableContext& context);
+  static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
+                                 uint64_t key) {
+    return (key % shard_num) / sparse_local_shard_num(shard_num, server_num);
+  }
 
-  virtual int32_t Initialize();
-  virtual int32_t InitializeShard() { return 0; }
-  virtual int32_t InitializeValue();
-  virtual int32_t Load(const std::string& path, const std::string& param);
+  int32_t Pull(TableContext& context) override;
+  int32_t Push(TableContext& context) override;
 
-  virtual int32_t Save(const std::string& path, const std::string& param);
+  int32_t Initialize() override;
+  int32_t InitializeShard() override { return 0; }
+  int32_t InitializeValue();
+
+  int32_t Load(const std::string& path, const std::string& param) override;
+  int32_t Save(const std::string& path, const std::string& param) override;
 
   int32_t LoadLocalFS(const std::string& path, const std::string& param);
   int32_t SaveLocalFS(const std::string& path, const std::string& param,
@@ -64,25 +73,22 @@ class MemorySparseTable : public SparseTable {
   int64_t LocalSize();
   int64_t LocalMFSize();
 
-  virtual std::pair<int64_t, int64_t> PrintTableStat();
-  virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value);
+  std::pair<int64_t, int64_t> PrintTableStat() override;
 
-  virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys,
-                                size_t num);
+  int32_t PullSparse(float* values, const PullSparseValue& pull_value);
+  int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num);
 
-  virtual int32_t PushSparse(const uint64_t* keys, const float* values,
-                             size_t num);
-  virtual int32_t PushSparse(const uint64_t* keys, const float** values,
-                             size_t num);
+  int32_t PushSparse(const uint64_t* keys, const float* values, size_t num);
+  int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
 
-  virtual int32_t Flush();
-  virtual int32_t Shrink(const std::string& param);
-  virtual void Clear();
+  int32_t Flush() override;
+  int32_t Shrink(const std::string& param) override;
+  void Clear() override;
 
- protected:
-  virtual int32_t _PushSparse(const uint64_t* keys, const float** values,
-                              size_t num);
+  void* GetShard(size_t shard_idx) override {
+    return &_local_shards[shard_idx];
+  }
 
  protected:
   const int _task_pool_size = 24;
......
@@ -35,25 +35,30 @@ namespace distributed {
 enum ValueType { Sparse = 0, Dense = 1 };
 
-struct PullContext {
-  const uint64_t *keys;
+struct TablePullContext {
+  const uint64_t *keys = nullptr;
   PullSparseValue pull_value;
-  float *values;
-  char **ptr_values;
+  float *values = nullptr;
+  char **ptr_values = nullptr;
+  std::vector<uint64_t> *geo_pull_keys = nullptr;  // for GEO
+  std::vector<float> *geo_pull_values = nullptr;   // for GEO
 };
 
 struct TablePushContext {
-  const uint64_t *keys;
-  const float *values;
-  const float **ptr_values;
+  const uint64_t *keys = nullptr;
+  const float *values = nullptr;
+  const float **ptr_values = nullptr;
+  const int64_t *push_steps = nullptr;  // for global step
+  bool is_param = false;  // true: push param, false: push gradient
 };
 
 struct TableContext {
   ValueType value_type;
-  PullContext pull_context;
+  TablePullContext pull_context;
   TablePushContext push_context;
   size_t num;
   bool use_ptr = false;
+  uint32_t trainer_id;  // for GEO and global step
 };
 class Table {
@@ -65,38 +70,6 @@ class Table {
   virtual int32_t Pull(TableContext &context) = 0;
   virtual int32_t Push(TableContext &context) = 0;
 
-  virtual int32_t PullDense(float *values, size_t num) = 0;
-  virtual int32_t PushDense(const float *values, size_t num) = 0;
-  // for push global_step
-  virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
-    return 0;
-  }
-  virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; }
-
-  virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys,
-                                size_t num) {
-    VLOG(0) << "NOT IMPLEMENT";
-    return 0;
-  }
-  virtual int32_t PullSparse(float *values,
-                             const PullSparseValue &pull_value) = 0;
-  virtual int32_t PushSparse(const uint64_t *keys, const float *values,
-                             size_t num) = 0;
-  virtual int32_t PushSparse(const uint64_t *keys, const float **values,
-                             size_t num) {
-    return 0;
-  }
-  virtual int32_t PushSparseParam(const uint64_t *keys, const float *values,
-                                  size_t num) {
-    return 0;
-  }
-
-  // only for sparse geo table
-  virtual int32_t PullGeoParam(const uint32_t trainer_id,
-                               std::vector<float> *values,
-                               std::vector<uint64_t> *keys) {
-    return 0;
-  }
-
   // only for barrier
   virtual int32_t Barrier(const uint32_t trainer_id,
......
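
With the specialized virtuals gone from Table, the GEO and global-step paths travel through the same two entry points via the new TablePullContext/TablePushContext fields. A hedged sketch mirroring the BrpcPsService call sites earlier in this diff; the helper functions themselves are illustrative, not part of the commit:

    #include <cstdint>
    #include <vector>
    // Assumes paddle::distributed::Table and TableContext from this commit.

    // GEO pull: the table fills the id/value vectors for the given trainer.
    int PullGeo(paddle::distributed::Table* table, uint32_t trainer_id,
                std::vector<uint64_t>* ids, std::vector<float>* values) {
      paddle::distributed::TableContext ctx;
      ctx.value_type = paddle::distributed::Sparse;
      ctx.pull_context.geo_pull_keys = ids;       // was PullGeoParam(...)
      ctx.pull_context.geo_pull_values = values;
      ctx.trainer_id = trainer_id;
      return table->Pull(ctx);
    }

    // Global-step push: push_steps carries the int64 step counters.
    int PushGlobalStep(paddle::distributed::Table* table, uint32_t trainer_id,
                       const int64_t* steps) {
      paddle::distributed::TableContext ctx;
      ctx.trainer_id = trainer_id;
      ctx.push_context.push_steps = steps;  // was PushDense(values, trainer_id)
      return table->Push(ctx);
    }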
@@ -50,43 +50,28 @@ class TensorTable : public Table {
   TensorTable() {}
   virtual ~TensorTable() {}
 
-  virtual int32_t Pull(TableContext &context) { return 0; }
-  virtual int32_t Push(TableContext &context) { return 0; }
-
-  int32_t PullDense(float *values, size_t num) override { return 0; }
-
-  int32_t PushDense(const float *values, size_t num) override { return 0; }
-
-  int32_t PullSparse(float *values,
-                     const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t PushSparse(const uint64_t *keys, const float *values,
-                     size_t num) override {
-    return 0;
-  }
+  int32_t Pull(TableContext &context) override { return 0; }
+  int32_t Push(TableContext &context) override { return 0; }
 
   int32_t Shrink(const std::string &param) override { return 0; }
 
-  virtual void *GetShard(size_t shard_idx) { return 0; }
+  void *GetShard(size_t shard_idx) override { return 0; }
 
-  virtual int32_t InitializeShard() { return 0; }
+  int32_t InitializeShard() override { return 0; }
 
-  virtual int32_t Flush() { return 0; }
+  int32_t Flush() override { return 0; }
 
-  virtual int32_t Load(const std::string &path, const std::string &param) {
+  int32_t Load(const std::string &path, const std::string &param) override {
     return 0;
   }
-  virtual int32_t Save(const std::string &path, const std::string &param) {
+  int32_t Save(const std::string &path, const std::string &param) override {
    return 0;
   }
 
-  virtual void Clear() {}
+  void Clear() override {}
 
   int32_t Initialize() override { return 0; }
 
-  int32_t PushDense(const int64_t *values, const int32_t trainer_id) override {
-    return 0;
-  }
-
   int32_t SetProgramEnv(
       framework::Scope *scope, platform::Place place,
       const std::vector<framework::ProgramDesc> *sub_program) override {
@@ -111,45 +96,28 @@ class DenseTensorTable : public TensorTable {
   DenseTensorTable() {}
   virtual ~DenseTensorTable() {}
 
-  int32_t PullSparse(float *values,
-                     const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t PushSparse(const uint64_t *keys, const float *values,
-                     size_t num) override {
-    return 0;
-  }
   int32_t Shrink(const std::string &param) override { return 0; }
 
-  virtual void *GetShard(size_t shard_idx) { return 0; }
+  void *GetShard(size_t shard_idx) override { return 0; }
 
-  virtual int32_t InitializeShard() { return 0; }
+  int32_t InitializeShard() override { return 0; }
 
-  virtual int32_t Flush() { return 0; }
+  int32_t Flush() override { return 0; }
 
-  virtual void Clear() {}
+  void Clear() override {}
 
   // Todo: Support program Load & Save
-  virtual int32_t Load(const std::string &path, const std::string &param) {
+  int32_t Load(const std::string &path, const std::string &param) override {
     return 0;
   }
-  virtual int32_t Save(const std::string &path, const std::string &param) {
+  int32_t Save(const std::string &path, const std::string &param) override {
     return 0;
   }
 
-  // Todo: Support pull dense
-  int32_t PullDense(float *values, size_t num) override { return 0; }
-
   /*----------------------------------------------------------------------*/
 
   int32_t Initialize() override { return 0; }
 
-  int32_t PushDense(const float *values, size_t num) override { return 0; }
-
-  int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
-    return 0;
-  }
-
  protected:
   virtual int32_t _RunProgram(const float *values, size_t num,
                               const uint32_t trainer_id) {
@@ -167,33 +135,23 @@ class GlobalStepTable : public DenseTensorTable {
   GlobalStepTable() {}
   virtual ~GlobalStepTable() {}
 
-  int32_t PullSparse(float *values,
-                     const PullSparseValue &pull_value) override {
-    return 0;
-  }
-  int32_t PushSparse(const uint64_t *keys, const float *values,
-                     size_t num) override {
-    return 0;
-  }
   int32_t Shrink(const std::string &param) override { return 0; }
 
-  virtual void *GetShard(size_t shard_idx) { return 0; }
+  void *GetShard(size_t shard_idx) override { return 0; }
 
-  virtual int32_t InitializeShard() { return 0; }
+  int32_t InitializeShard() override { return 0; }
 
-  virtual int32_t Flush() { return 0; }
+  int32_t Flush() override { return 0; }
 
-  virtual void Clear() {}
+  void Clear() override {}
 
-  virtual int32_t Load(const std::string &path, const std::string &param) {
+  int32_t Load(const std::string &path, const std::string &param) override {
     return 0;
   }
-  virtual int32_t Save(const std::string &path, const std::string &param) {
+  int32_t Save(const std::string &path, const std::string &param) override {
     return 0;
   }
 
-  int32_t PullDense(float *values, size_t num) override { return 0; }
-
   /*----------------------------------------------------------------------*/
 
   int32_t Initialize() override {
@@ -235,12 +193,13 @@ class GlobalStepTable : public DenseTensorTable {
       decay_counters_[i] = 0;
     }
   }
+    return 0;
   }
 
-  int32_t PushDense(const float *values, size_t num) override { return 0; }
-
-  int32_t PushDense(const int64_t *values, const int32_t trainer_id) {
-    return _RunProgram(values, trainer_id);
+  // int32_t PushDense(const float *values, size_t num) override { return 0; }
+
+  virtual int32_t Push(TableContext context) {
+    return _RunProgram(context.push_context.push_steps, context.trainer_id);
   }
 
   int32_t SetTableMap(std::unordered_map<uint32_t, std::shared_ptr<Table>>
......
@@ -49,6 +49,8 @@ namespace distributed = paddle::distributed;
 void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
   auto x_var = scope->Var("x");
   x_var->GetMutable<framework::LoDTensor>();
+  auto x_g_var = scope->Var("x@GRAD");
+  x_g_var->GetMutable<framework::LoDTensor>();
 }
 
 void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
@@ -59,34 +61,49 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
   float* x_ptr =
       x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
   for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
+
+  auto g_size = rows_numel +
+                30;  // hard code here: key_num * (fea_dim + 3), show/clk/slot
+  auto x_g_var = scope->Var("x@GRAD")->GetMutable<framework::LoDTensor>();
+  float* x_g_ptr =
+      x_g_var->mutable_data<float>(framework::DDim({1, g_size}), *place);
+  for (int64_t i = 0; i < g_size; ++i) x_g_ptr[i] = 1.0;
 }
 void GetDownpourSparseTableProto(
     ::paddle::distributed::TableParameter* sparse_table_proto) {
   sparse_table_proto->set_table_id(0);
-  sparse_table_proto->set_table_class("CommonSparseTable");
-  sparse_table_proto->set_shard_num(256);
-  sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
-  ::paddle::distributed::TableAccessorParameter* accessor_proto =
+  sparse_table_proto->set_table_class("MemorySparseTable");
+  sparse_table_proto->set_shard_num(10);
+  ::paddle::distributed::TableAccessorParameter* accessor_config =
       sparse_table_proto->mutable_accessor();
-  ::paddle::distributed::CommonAccessorParameter* common_proto =
-      sparse_table_proto->mutable_common();
 
-  accessor_proto->set_accessor_class("CommMergeAccessor");
-  accessor_proto->set_fea_dim(0);
-  accessor_proto->set_embedx_dim(10);
-
-  common_proto->set_name("sgd");
-  common_proto->set_table_name("MergedDense");
-  common_proto->set_trainer_num(1);
-  common_proto->set_sync(false);
-  common_proto->set_entry("none");
-  common_proto->add_params("Param");
-  common_proto->add_dims(10);
-  common_proto->add_initializers("uniform_random&0&-1.0&1.0");
-  common_proto->add_params("LearningRate");
-  common_proto->add_dims(1);
-  common_proto->add_initializers("fill_constant&1.0");
+  accessor_config->set_accessor_class("SparseAccessor");
+  accessor_config->set_fea_dim(10);
+  accessor_config->set_embedx_dim(9);
+  accessor_config->set_embedx_threshold(0);
+  accessor_config->mutable_ctr_accessor_param()->set_nonclk_coeff(0.2);
+  accessor_config->mutable_ctr_accessor_param()->set_click_coeff(1);
+  accessor_config->mutable_ctr_accessor_param()->set_base_threshold(0.5);
+  accessor_config->mutable_ctr_accessor_param()->set_delta_threshold(0.2);
+  accessor_config->mutable_ctr_accessor_param()->set_delta_keep_days(16);
+  accessor_config->mutable_ctr_accessor_param()->set_show_click_decay_rate(
+      0.99);
+
+  accessor_config->mutable_embed_sgd_param()->set_name("SparseNaiveSGDRule");
+  auto* naive_param =
+      accessor_config->mutable_embed_sgd_param()->mutable_naive();
+  naive_param->set_learning_rate(1.0);
+  naive_param->set_initial_range(0.3);
+  naive_param->add_weight_bounds(-10.0);
+  naive_param->add_weight_bounds(10.0);
+
+  accessor_config->mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule");
+  naive_param = accessor_config->mutable_embedx_sgd_param()->mutable_naive();
+  naive_param->set_learning_rate(1.0);
+  naive_param->set_initial_range(0.3);
+  naive_param->add_weight_bounds(-10.0);
+  naive_param->add_weight_bounds(10.0);
 }
 ::paddle::distributed::PSParameter GetServerProto() {
@@ -217,42 +234,42 @@ void RunBrpcPushSparse() {
   auto pull_status = worker_ptr_->PullSparse(
       fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
   pull_status.wait();
-  for (size_t idx = 0; idx < tensor->numel(); ++idx) {
-    fea_values.data()[idx] *= 2.0;
-  }
-
-  /*-----------------------Test Push Param----------------------------------*/
 
-  LOG(INFO) << "Run push_sparse_param";
-  paddle::distributed::DownpourBrpcClosure* closure_push_param =
+  /*-----------------------Test Push Grad----------------------------------*/
+  // first to expand embedx, init
+  paddle::distributed::DownpourBrpcClosure* closure_push_grad =
       new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
         int ret = 0;
         auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
         for (size_t i = 0; i < 1; ++i) {
           if (closure->check_response(
-                  i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) {
+                  i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) {
             ret = -1;
             break;
           }
         }
        closure->set_promise_value(ret);
       });
-  auto push_status = worker_ptr_->PushSparseParam(
-      0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(),
-      closure_push_param);
-  push_status.wait();
 
-  auto pull_param_status = worker_ptr_->PullSparse(
-      fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
-  pull_param_status.wait();
+  framework::Variable* g_var = client_scope.FindVar("x@GRAD");
+  framework::LoDTensor* g_tensor = g_var->GetMutable<framework::LoDTensor>();
 
-  for (size_t idx = 0; idx < tensor->numel(); ++idx) {
-    EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]);
+  LOG(INFO) << "Run push_sparse_grad";
+  std::vector<float*> push_g_vec;
+  for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
+    push_g_vec.push_back(g_tensor->data<float>() + i * 13);
   }
+  auto push_grad_status = worker_ptr_->PushSparseRawGradient(
+      0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
+      closure_push_grad);
+  push_grad_status.wait();
 
-  /*-----------------------Test Push Grad----------------------------------*/
+  // pull
+  pull_status = worker_ptr_->PullSparse(fea_value_ptr.data(), 0,
+                                        fea_keys.data(), fea_keys.size(), true);
+  pull_status.wait();
 
-  paddle::distributed::DownpourBrpcClosure* closure_push_grad =
+  paddle::distributed::DownpourBrpcClosure* closure_push_grad1 =
       new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) {
         int ret = 0;
         auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
@@ -266,16 +283,13 @@ void RunBrpcPushSparse() {
         closure->set_promise_value(ret);
       });
 
-  LOG(INFO) << "Run pull_sparse_grad";
-  std::vector<float*> push_g_vec;
-  for (auto i = 0; i < static_cast<int>(fea_keys.size()); ++i) {
-    push_g_vec.push_back(tensor->data<float>() + i * 10);
-  }
-  auto push_grad_status = worker_ptr_->PushSparseRawGradient(
+  // push again, embedx update this time
+  push_grad_status = worker_ptr_->PushSparseRawGradient(
       0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(),
-      closure_push_grad);
+      closure_push_grad1);
   push_grad_status.wait();
 
+  // pull update
   auto pull_update_status = worker_ptr_->PullSparse(
       fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true);
   pull_update_status.wait();
......
@@ -69,7 +69,13 @@ TEST(CommonDenseTable, Adam) {
   // pull parameters for create and check
   std::vector<float> init_values;
   init_values.resize(fea_dim);
-  table->PullDense(init_values.data(), fea_dim);
+
+  TableContext table_context1;
+  table_context1.value_type = Dense;
+  table_context1.pull_context.values = init_values.data();
+  table_context1.num = fea_dim;
+  table->Pull(table_context1);
+  // table->PullDense(init_values.data(), fea_dim);
 
   // push gradient
   std::vector<std::vector<float>> trainer_gradient_values;
@@ -85,12 +91,24 @@ TEST(CommonDenseTable, Adam) {
   // for adam
   for (int i = 0; i < trainers; i++) {
     auto &push_values = trainer_gradient_values[i];
-    table->PushDense(push_values.data(), push_values.size());
+
+    TableContext table_context;
+    table_context.value_type = Dense;
+    table_context.push_context.values = push_values.data();
+    table_context.num = push_values.size();
+    table->Push(table_context);
+    // table->PushDense(push_values.data(), push_values.size());
   }
 
   std::vector<float> pull_values;
   pull_values.resize(fea_dim);
-  table->PullDense(pull_values.data(), fea_dim);
+
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.pull_context.values = pull_values.data();
+  table_context.num = fea_dim;
+  table->Pull(table_context);
+  // table->PullDense(pull_values.data(), fea_dim);
 
   float mom_rate = 0.99;
   float decay_rate = 0.9999;
@@ -150,7 +168,13 @@ TEST(CommonDenseTable, SGD) {
   // pull parameters for create and check
   std::vector<float> init_values;
   init_values.resize(fea_dim);
-  table->PullDense(init_values.data(), fea_dim);
+
+  TableContext table_context1;
+  table_context1.value_type = Dense;
+  table_context1.pull_context.values = init_values.data();
+  table_context1.num = fea_dim;
+  table->Pull(table_context1);
+  // table->PullDense(init_values.data(), fea_dim);
 
   std::vector<float> total_gradients;
   total_gradients.resize(fea_dim);
@@ -173,7 +197,12 @@ TEST(CommonDenseTable, SGD) {
   for (int i = 0; i < trainers; i++) {
     auto &push_values = trainer_gradient_values[i];
     auto task = [table, &push_values] {
-      table->PushDense(push_values.data(), push_values.size());
+      TableContext table_context;
+      table_context.value_type = Dense;
+      table_context.push_context.values = push_values.data();
+      table_context.num = push_values.size();
+      table->Push(table_context);
+      // table->PushDense(push_values.data(), push_values.size());
     };
     task_status.push_back(pool_->enqueue(std::move(task)));
   }
@@ -183,7 +212,13 @@ TEST(CommonDenseTable, SGD) {
   std::vector<float> pull_values;
   pull_values.resize(fea_dim);
-  table->PullDense(pull_values.data(), fea_dim);
+
+  TableContext table_context;
+  table_context.value_type = Dense;
+  table_context.pull_context.values = pull_values.data();
+  table_context.num = fea_dim;
+  table->Pull(table_context);
+  // table->PullDense(pull_values.data(), fea_dim);
 
   for (int j = 0; j < fea_dim; j++) {
     auto update_val = init_values[j] - 1.0 * total_gradients[j];
     ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5);
......
@@ -58,12 +58,26 @@ TEST(MemorySparseGeoTable, SSUM) {
   for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
     init_values.push_back(0.0);
   }
-  table->PushSparseParam(init_keys.data(), init_values.data(),
-                         init_keys.size());
+
+  TableContext table_context1;
+  table_context1.value_type = Sparse;
+  table_context1.push_context.keys = init_keys.data();
+  table_context1.push_context.values = init_values.data();
+  table_context1.push_context.is_param = true;
+  table_context1.num = init_keys.size();
+  table->Push(table_context1);
+  // table->PushSparseParam(init_keys.data(), init_values.data(),
+  //                        init_keys.size());
 
   std::vector<float> pull_values(init_values.size());
   auto value = PullSparseValue(init_keys, init_fres, emb_dim);
-  table->PullSparse(pull_values.data(), value);
+
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.pull_context.pull_value = value;
+  table_context.pull_context.values = pull_values.data();
+  table->Pull(table_context);
+  // table->PullSparse(pull_values.data(), value);
 
   for (size_t i = 0; i < init_keys.size() * emb_dim; i++) {
     ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5);
@@ -93,7 +107,14 @@ TEST(MemorySparseGeoTable, SSUM) {
     auto &push_keys = trainer_keys[i];
     auto &push_values = trainer_values[i];
     auto task = [table, &push_keys, &push_values] {
-      table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
+      TableContext table_context;
+      table_context.value_type = Sparse;
+      table_context.push_context.keys = push_keys.data();
+      table_context.push_context.values = push_values.data();
+      table_context.num = push_keys.size();
+      table->Push(table_context);
+      // table->PushSparse(push_keys.data(), push_values.data(),
+      //                   push_keys.size());
     };
     task_status.push_back(pool_->enqueue(std::move(task)));
   }
@@ -106,7 +127,13 @@ TEST(MemorySparseGeoTable, SSUM) {
   geo_pull_ids.resize(trainers);
   geo_pull_values.resize(trainers);
   for (int i = 0; i < trainers; i++) {
-    table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]);
+    TableContext table_context;
+    table_context.value_type = Sparse;
+    table_context.pull_context.geo_pull_keys = &geo_pull_ids[i];
+    table_context.pull_context.geo_pull_values = &geo_pull_values[i];
+    table_context.trainer_id = i;
+    table->Pull(table_context);
+    // table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]);
     ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim);
     for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) {
       auto id = geo_pull_ids[i][j];
......
@@ -76,7 +76,13 @@ TEST(MemorySparseTable, SGD) {
   std::vector<float> init_values;
   init_values.resize(init_keys.size() * (emb_dim + 3));
   auto value = PullSparseValue(init_keys, init_fres, emb_dim);
-  table->PullSparse(init_values.data(), value);
+
+  TableContext table_context;
+  table_context.value_type = Sparse;
+  table_context.pull_context.pull_value = value;
+  table_context.pull_context.values = init_values.data();
+  table->Pull(table_context);
+  // table->PullSparse(init_values.data(), value);
 
   // for check
   std::vector<float> total_gradients;
@@ -109,7 +115,14 @@ TEST(MemorySparseTable, SGD) {
     auto &push_keys = trainer_keys[i];
     auto &push_values = trainer_gradient_values[i];
     auto task = [table, &push_keys, &push_values] {
-      table->PushSparse(push_keys.data(), push_values.data(), push_keys.size());
+      TableContext table_context;
+      table_context.value_type = Sparse;
+      table_context.push_context.keys = push_keys.data();
+      table_context.push_context.values = push_values.data();
+      table_context.num = push_keys.size();
+      table->Push(table_context);
+      // table->PushSparse(push_keys.data(), push_values.data(),
+      //                   push_keys.size());
     };
     task_status.push_back(pool_->enqueue(std::move(task)));
   }
@@ -119,7 +132,13 @@ TEST(MemorySparseTable, SGD) {
   std::vector<float> pull_values;
   pull_values.resize(init_keys.size() * (emb_dim + 3));
-  table->PullSparse(pull_values.data(), value);
+
+  TableContext table_context1;
+  table_context1.value_type = Sparse;
+  table_context1.pull_context.pull_value = value;
+  table_context1.pull_context.values = pull_values.data();
+  table->Pull(table_context1);
+  // table->PullSparse(pull_values.data(), value);
 
   for (size_t i = 0; i < init_keys.size(); ++i) {
     for (size_t j = 2; j < emb_dim + 3; ++j) {
......
@@ -621,7 +621,7 @@ class SparseTable(Table):
 class GeoSparseTable(SparseTable):
     def __init__(self, context, send_ctx):
         super(GeoSparseTable, self).__init__(context, send_ctx)
-        self.table_class = "SparseGeoTable"
+        self.table_class = "MemorySparseGeoTable"
         if self.context['ps_mode'] != DistributedMode.GEO:
             raise ValueError("not geo sparse table!")
......