From 19cb0d189f53e41e12829da360cd8e605d5c4758 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Mon, 4 Apr 2022 21:29:18 +0800 Subject: [PATCH] Table refine: Pull/Push(TableContext) (#41320) * update name * update name * fix test * fix fleet bind * update name * update name * fix test * fix gpups wrapper * remove Push/Pull/Load/Save with context in client and wrapper base class * fix * fix * remove some interface * fix * remove * code style * recover * fix * remove code unused * fix * recover * fix Co-authored-by: esythan --- .../distributed/ps/service/brpc_ps_server.cc | 36 +++++- .../distributed/ps/service/ps_local_client.cc | 60 +++++++++- .../ps/table/common_dense_table.cc | 7 +- .../distributed/ps/table/common_dense_table.h | 22 ++-- .../distributed/ps/table/common_graph_table.h | 28 +++-- .../ps/table/common_sparse_table.h | 14 ++- .../fluid/distributed/ps/table/common_table.h | 57 --------- .../ps/table/memory_sparse_geo_table.cc | 24 ++++ .../ps/table/memory_sparse_geo_table.h | 32 ++--- .../ps/table/memory_sparse_table.cc | 18 ++- .../ps/table/memory_sparse_table.h | 58 ++++----- paddle/fluid/distributed/ps/table/table.h | 53 +++------ .../fluid/distributed/ps/table/tensor_table.h | 89 ++++---------- .../test/brpc_service_sparse_sgd_test.cc | 110 ++++++++++-------- .../distributed/test/dense_table_test.cc | 47 +++++++- .../distributed/test/memory_geo_table_test.cc | 37 +++++- .../test/memory_sparse_table_test.cc | 25 +++- python/paddle/distributed/ps/the_one_ps.py | 2 +- 18 files changed, 406 insertions(+), 313 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index a1690cbb935..d22cca91f78 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -244,7 +244,14 @@ int32_t BrpcPsService::PushDenseParam(Table *table, uint32_t num = *(const uint32_t *)data; const float *values = (const float *)(data + sizeof(uint32_t)); - if (table->PushDenseParam(values, num) != 0) { + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = values; + table_context.push_context.is_param = true; + table_context.num = num; + + // if (table->PushDenseParam(values, num) != 0) { + if (table->Push(table_context) != 0) { set_response_code(response, -1, "PushDenseParam failed"); } return 0; @@ -330,7 +337,15 @@ int32_t BrpcPsService::PushSparseParam(Table *table, const uint64_t *keys = (const uint64_t *)push_data.data(); const float *values = (const float *)(push_data.data() + sizeof(uint64_t) * num); - if (table->PushSparseParam(keys, values, num) != 0) { + + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = keys; + table_context.push_context.values = values; + table_context.push_context.is_param = true; + table_context.num = num; + // if (table->PushSparseParam(keys, values, num) != 0) { + if (table->Push(table_context) != 0) { set_response_code(response, -1, "PushSparseParam error"); } return 0; @@ -349,7 +364,14 @@ int32_t BrpcPsService::PullGeoParam(Table *table, std::vector values; std::vector ids; - table->PullGeoParam(trainer_id, &values, &ids); + + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.geo_pull_keys = &ids; + table_context.pull_context.geo_pull_values = &values; + table_context.trainer_id = trainer_id; + table->Pull(table_context); + // 
table->PullGeoParam(trainer_id, &values, &ids); uint32_t num = ids.size(); cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); @@ -625,7 +647,13 @@ int32_t BrpcPsService::PushGlobalStep(Table *table, const int64_t *values = (const int64_t *)(request.data().data() + sizeof(uint32_t)); auto trainer_id = request.client_id(); - if (table->PushDense(values, trainer_id) != 0) { + + TableContext context; + context.trainer_id = trainer_id; + context.push_context.push_steps = values; + + // if (table->PushDense(values, trainer_id) != 0) { + if (table->Push(context) != 0) { set_response_code(response, -1, "run_program failed"); } diff --git a/paddle/fluid/distributed/ps/service/ps_local_client.cc b/paddle/fluid/distributed/ps/service/ps_local_client.cc index 3e93f861d4e..bc024ed3175 100644 --- a/paddle/fluid/distributed/ps/service/ps_local_client.cc +++ b/paddle/fluid/distributed/ps/service/ps_local_client.cc @@ -104,7 +104,13 @@ int32_t PsLocalClient::Initialize() { std::vector region_buffer; region_buffer.resize(num_per_shard); - table_ptr->PullDense(region_buffer.data(), region_buffer.size()); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = region_buffer.data(); + table_context.num = region_buffer.size(); + table_ptr->Pull(table_context); + // table_ptr->PullDense(region_buffer.data(), region_buffer.size()); size_t region_idx = 0; size_t region_data_idx = 0; @@ -154,6 +160,13 @@ int32_t PsLocalClient::Initialize() { offset += data_num; } + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = region_buffer.data(); + table_context.push_context.is_param = true; + table_context.num = region_buffer.size(); + + table_ptr->Push(table_context); // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size()); return done(); @@ -168,7 +181,13 @@ int32_t PsLocalClient::Initialize() { auto* table_ptr = GetTable(table_id); - table_ptr->PushDense(total_send_data, total_send_data_size); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = total_send_data; + table_context.num = total_send_data_size; + // table_ptr->PushDense(total_send_data, total_send_data_size); + table_ptr->Push(table_context); + delete closure; return done(); } @@ -194,7 +213,12 @@ int32_t PsLocalClient::Initialize() { offset += data_num; } - table_ptr->PushDense(region_buffer.data(), region_buffer.size()); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = region_buffer.data(); + table_context.num = region_buffer.size(); + // table_ptr->PushDense(total_send_data, total_send_data_size); + table_ptr->Push(table_context); return done(); } @@ -241,7 +265,15 @@ int32_t PsLocalClient::Initialize() { //将key拆分到各shard请求,并记录原始对应value指针 auto* table_ptr = GetTable(table_id); - table_ptr->PullSparsePtr(select_values, keys, num); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.keys = keys; + table_context.pull_context.ptr_values = select_values; + table_context.use_ptr = true; + table_context.num = num; + + // table_ptr->PullSparsePtr(select_values, keys, num); + table_ptr->Pull(table_context); return done(); } @@ -253,7 +285,15 @@ int32_t PsLocalClient::Initialize() { auto* accessor = GetTableAccessor(table_id); auto* table_ptr = GetTable(table_id); - table_ptr->PushSparse(keys, update_values, num); + TableContext table_context; + table_context.value_type = Sparse; + 
table_context.push_context.keys = keys; + table_context.push_context.ptr_values = update_values; + table_context.num = num; + table_context.use_ptr = true; + + // table_ptr->PushSparse(keys, update_values, num); + table_ptr->Push(table_context); delete closure; return done(); } @@ -265,7 +305,15 @@ int32_t PsLocalClient::Initialize() { auto* accessor = GetTableAccessor(table_id); auto* table_ptr = GetTable(table_id); - table_ptr->PushSparse(keys, update_values, num); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = keys; + table_context.push_context.ptr_values = update_values; + table_context.num = num; + table_context.use_ptr = true; + + // table_ptr->PushSparse(keys, update_values, num); + table_ptr->Push(table_context); return done(); } } diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.cc b/paddle/fluid/distributed/ps/table/common_dense_table.cc index 4242b65dea0..45208670f9d 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/common_dense_table.cc @@ -139,8 +139,11 @@ int32_t CommonDenseTable::Pull(TableContext& context) { int32_t CommonDenseTable::Push(TableContext& context) { CHECK(context.value_type == Dense); if (context.push_context.values != nullptr) { - const float* values = context.push_context.values; - return PushDense(values, context.num); + if (!context.push_context.is_param) { + return PushDense(context.push_context.values, context.num); + } else { + return PushDenseParam(context.push_context.values, context.num); + } } return 0; } diff --git a/paddle/fluid/distributed/ps/table/common_dense_table.h b/paddle/fluid/distributed/ps/table/common_dense_table.h index 8e4ff1ecaf4..acda009d024 100644 --- a/paddle/fluid/distributed/ps/table/common_dense_table.h +++ b/paddle/fluid/distributed/ps/table/common_dense_table.h @@ -30,21 +30,22 @@ namespace distributed { class DenseOptimizer; -class CommonDenseTable : public DenseTable { +class CommonDenseTable : public Table { public: CommonDenseTable() {} virtual ~CommonDenseTable() {} int32_t Initialize() override; int32_t InitializeShard() override { return 0; } - virtual void CreateInitializer(const std::string& attr, - const std::string& name); - virtual int32_t InitializeValue(); - virtual int32_t InitializeOptimizer(); - virtual int32_t Pull(TableContext& context); - virtual int32_t Push(TableContext& context); - int32_t PullDense(float* pull_values, size_t num) override; - int32_t PushDenseParam(const float* values, size_t num) override; - int32_t PushDense(const float* values, size_t num) override; + void CreateInitializer(const std::string& attr, const std::string& name); + int32_t InitializeValue(); + int32_t InitializeOptimizer(); + + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; + + int32_t PullDense(float* pull_values, size_t num); + int32_t PushDenseParam(const float* values, size_t num); + int32_t PushDense(const float* values, size_t num); int32_t Pour() override; int32_t SetGlobalLR(float* lr) override; @@ -54,6 +55,7 @@ class CommonDenseTable : public DenseTable { int32_t Flush() override { return 0; } int32_t Shrink(const std::string& param) override { return 0; } void Clear() override { return; } + void* GetShard(size_t shard_idx) override { return 0; } protected: int32_t _PushDense(const float* values, size_t num); diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 
035a3de3eba..acc484e6098 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -404,7 +404,7 @@ class GraphSampler { }; #endif -class GraphTable : public SparseTable { +class GraphTable : public Table { public: GraphTable() { use_cache = false; @@ -415,6 +415,23 @@ class GraphTable : public SparseTable { rw_lock.reset(new pthread_rwlock_t()); } virtual ~GraphTable(); + + virtual void *GetShard(size_t shard_idx) { return 0; } + + static int32_t sparse_local_shard_num(uint32_t shard_num, + uint32_t server_num) { + if (shard_num % server_num == 0) { + return shard_num / server_num; + } + size_t local_shard_num = shard_num / server_num + 1; + return local_shard_num; + } + + static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); + } + virtual int32_t pull_graph_list(int start, int size, std::unique_ptr &buffer, int &actual_size, bool need_feature, @@ -452,15 +469,6 @@ class GraphTable : public SparseTable { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - virtual int32_t PullSparse(float *values, const PullSparseValue &pull_value) { - return 0; - } - - virtual int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) { - return 0; - } - virtual int32_t clear_nodes(); virtual void Clear() {} virtual int32_t Flush() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/common_sparse_table.h b/paddle/fluid/distributed/ps/table/common_sparse_table.h index f6deaf0a82b..2673e8dfae3 100644 --- a/paddle/fluid/distributed/ps/table/common_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/common_sparse_table.h @@ -108,15 +108,16 @@ struct Meta { } }; -class CommonSparseTable : public SparseTable { +class CommonSparseTable : public Table { public: CommonSparseTable() { rwlock_.reset(new phi::RWLock); } virtual ~CommonSparseTable() {} // unused method begin - virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } - virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } - virtual int32_t PushDense(const float* values, size_t num) { return 0; } + // virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } + // virtual int32_t PushDenseParam(const float* values, size_t num) { return + // 0; } + // virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end virtual int32_t Pull(TableContext& context); @@ -163,14 +164,15 @@ class CommonSparseTable : public SparseTable { // only for sparse geo table virtual int32_t PushSparseParam(const uint64_t* keys, const float* values, size_t num); - - virtual int32_t SetGlobalLR(float* lr) override; + virtual int32_t SetGlobalLR(float* lr); virtual int32_t Pour(); virtual int32_t Flush(); virtual int32_t Shrink(const std::string& param); virtual void Clear(); + virtual void* GetShard(size_t shard_idx) { return 0; } + protected: virtual int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); diff --git a/paddle/fluid/distributed/ps/table/common_table.h b/paddle/fluid/distributed/ps/table/common_table.h index f5e263e8e71..f69d9ccbf14 100644 --- a/paddle/fluid/distributed/ps/table/common_table.h +++ b/paddle/fluid/distributed/ps/table/common_table.h @@ -66,50 +66,6 @@ struct ReservoirValue { } }; -class SparseTable : public Table { - public: - SparseTable() {} - virtual ~SparseTable() {} 
- - virtual void *GetShard(size_t shard_idx) { return 0; } - - int32_t PullDense(float *values, size_t num) override { return 0; } - - int32_t PushDense(const float *values, size_t num) override { return 0; } - - static int32_t sparse_local_shard_num(uint32_t shard_num, - uint32_t server_num) { - if (shard_num % server_num == 0) { - return shard_num / server_num; - } - size_t local_shard_num = shard_num / server_num + 1; - return local_shard_num; - } - - static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, - uint64_t key) { - return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); - } -}; - -class DenseTable : public Table { - public: - DenseTable() {} - virtual ~DenseTable() {} - - virtual void *GetShard(size_t shard_idx) { return 0; } - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } - int32_t PushDenseParam(const float *values, size_t num) override { return 0; } - int32_t Shrink(const std::string ¶m) override { return 0; } -}; - class BarrierTable : public Table { public: BarrierTable() {} @@ -120,19 +76,6 @@ class BarrierTable : public Table { virtual int32_t Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } - - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } - int32_t PushDenseParam(const float *values, size_t num) override { return 0; } int32_t Shrink(const std::string ¶m) override { return 0; } virtual void Clear() {} virtual int32_t Flush() { return 0; } diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc index 979e1c48254..1567d31d0f3 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.cc @@ -17,6 +17,29 @@ namespace paddle { namespace distributed { +int32_t MemorySparseGeoTable::Pull(TableContext& context) { + CHECK(context.value_type == Sparse); + if (context.pull_context.geo_pull_keys != nullptr) { + return PullGeoParam(context.trainer_id, + context.pull_context.geo_pull_values, + context.pull_context.geo_pull_keys); + } else { + return PullSparse(context.pull_context.values, + context.pull_context.pull_value); + } +} + +int32_t MemorySparseGeoTable::Push(TableContext& context) { + CHECK(context.value_type == Sparse); + if (!context.push_context.is_param) { + return PushSparse(context.push_context.keys, context.push_context.values, + context.num); + } else { + return PushSparseParam(context.push_context.keys, + context.push_context.values, context.num); + } +} + int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys, const float* values, size_t num) { VLOG(5) << "DEBUG MemorySparseGeoTable::PushSparseParam begin " @@ -117,6 +140,7 @@ int32_t MemorySparseGeoTable::Initialize() { return 0; } +// hash different from MemorySparseTable int32_t MemorySparseGeoTable::PullSparse(float* pull_values, const PullSparseValue& pull_value) { auto shard_num = _task_pool_size; diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h 
b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 1a74df32db8..60ba5d9602e 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -34,40 +34,44 @@ namespace distributed { class GeoRecorder; -class MemorySparseGeoTable : public SparseTable { +class MemorySparseGeoTable : public Table { public: typedef SparseTableShard shard_type; MemorySparseGeoTable() { _geo_recorder = nullptr; } virtual ~MemorySparseGeoTable() {} - virtual int32_t Initialize(); - virtual int32_t InitializeShard() { return 0; } - virtual int32_t Load(const std::string& path, const std::string& param) { + int32_t Initialize() override; + int32_t InitializeShard() override { return 0; } + int32_t Load(const std::string& path, const std::string& param) override { return 0; } - virtual int32_t Save(const std::string& path, const std::string& param) { + int32_t Save(const std::string& path, const std::string& param) override { return 0; } - virtual int32_t Pull(TableContext& context) { return 0; } - virtual int32_t Push(TableContext& context) { return 0; } - virtual int32_t Flush() { return 0; } - virtual int32_t Shrink(const std::string& param) { return 0; } - virtual void Clear() { return; } - virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; + int32_t Flush() override { return 0; } + int32_t Shrink(const std::string& param) override { return 0; } + void Clear() override { return; } + + int32_t PullSparse(float* values, const PullSparseValue& pull_value); int32_t PushSparseParam(const uint64_t* keys, const float* values, size_t num); - // TODO(zhaocaibei123): change to pull_sparse, and rename pull_sparse + int32_t PullGeoParam(const uint32_t trainer_id, std::vector* values, std::vector* keys); - int32_t PushSparse(const uint64_t* keys, const float* values, - size_t num) override; + int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); int32_t _PushSparse(const uint64_t* keys, const float* values, size_t num); // int32_t _pull_sparse(float* pull_values, const PullSparseValue& // pull_value); + void* GetShard(size_t shard_idx) override { + return &_local_shards[shard_idx]; + } + private: std::shared_ptr _geo_recorder; const int _task_pool_size = 10; diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index b4b2263ed77..e6c52e0b9b0 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -47,7 +47,7 @@ int32_t MemorySparseTable::Initialize() { int32_t MemorySparseTable::InitializeValue() { _sparse_table_shard_num = static_cast(_config.shard_num()); _avg_local_shard_num = - SparseTable::sparse_local_shard_num(_sparse_table_shard_num, _shard_num); + sparse_local_shard_num(_sparse_table_shard_num, _shard_num); _real_local_shard_num = _avg_local_shard_num; if (_real_local_shard_num * (_shard_idx + 1) > _sparse_table_shard_num) { _real_local_shard_num = @@ -405,9 +405,13 @@ int32_t MemorySparseTable::Pull(TableContext& context) { int32_t MemorySparseTable::Push(TableContext& context) { CHECK(context.value_type == Sparse); - - const uint64_t* keys = context.push_context.keys; - return PushSparse(keys, context.push_context.values, context.num); + if (!context.use_ptr) { + return PushSparse(context.push_context.keys, 
context.push_context.values, + context.num); + } else { + return PushSparse(context.push_context.keys, + context.push_context.ptr_values, context.num); + } } int32_t MemorySparseTable::PullSparse(float* pull_values, @@ -610,12 +614,6 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values, int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float** values, size_t num) { - _PushSparse(keys, values, num); - return 0; -} - -int32_t MemorySparseTable::_PushSparse(const uint64_t* keys, - const float** values, size_t num) { std::vector> tasks(_real_local_shard_num); std::vector>> task_keys( _real_local_shard_num); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index a4af4caa472..87a73bd22fa 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -34,28 +34,37 @@ namespace paddle { namespace distributed { -class MemorySparseTable : public SparseTable { +class MemorySparseTable : public Table { public: typedef SparseTableShard shard_type; MemorySparseTable() {} virtual ~MemorySparseTable() {} - // unused method begin - virtual int32_t PullDense(float* pull_values, size_t num) { return 0; } - virtual int32_t PushDenseParam(const float* values, size_t num) { return 0; } - virtual int32_t PushDense(const float* values, size_t num) { return 0; } // unused method end + static int32_t sparse_local_shard_num(uint32_t shard_num, + uint32_t server_num) { + if (shard_num % server_num == 0) { + return shard_num / server_num; + } + size_t local_shard_num = shard_num / server_num + 1; + return local_shard_num; + } - virtual int32_t Pull(TableContext& context); - virtual int32_t Push(TableContext& context); + static size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, + uint64_t key) { + return (key % shard_num) / sparse_local_shard_num(shard_num, server_num); + } - virtual int32_t Initialize(); - virtual int32_t InitializeShard() { return 0; } - virtual int32_t InitializeValue(); + int32_t Pull(TableContext& context) override; + int32_t Push(TableContext& context) override; - virtual int32_t Load(const std::string& path, const std::string& param); + int32_t Initialize() override; + int32_t InitializeShard() override { return 0; } + int32_t InitializeValue(); - virtual int32_t Save(const std::string& path, const std::string& param); + int32_t Load(const std::string& path, const std::string& param) override; + + int32_t Save(const std::string& path, const std::string& param) override; int32_t LoadLocalFS(const std::string& path, const std::string& param); int32_t SaveLocalFS(const std::string& path, const std::string& param, @@ -64,25 +73,22 @@ class MemorySparseTable : public SparseTable { int64_t LocalSize(); int64_t LocalMFSize(); - virtual std::pair PrintTableStat(); - virtual int32_t PullSparse(float* values, const PullSparseValue& pull_value); + std::pair PrintTableStat() override; + int32_t PullSparse(float* values, const PullSparseValue& pull_value); - virtual int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, - size_t num); + int32_t PullSparsePtr(char** pull_values, const uint64_t* keys, size_t num); - virtual int32_t PushSparse(const uint64_t* keys, const float* values, - size_t num); + int32_t PushSparse(const uint64_t* keys, const float* values, size_t num); - virtual int32_t PushSparse(const uint64_t* keys, const float** values, - size_t num); + int32_t PushSparse(const 
uint64_t* keys, const float** values, size_t num); - virtual int32_t Flush(); - virtual int32_t Shrink(const std::string& param); - virtual void Clear(); + int32_t Flush() override; + int32_t Shrink(const std::string& param) override; + void Clear() override; - protected: - virtual int32_t _PushSparse(const uint64_t* keys, const float** values, - size_t num); + void* GetShard(size_t shard_idx) override { + return &_local_shards[shard_idx]; + } protected: const int _task_pool_size = 24; diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index f55c30b7740..c515e03e3fa 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -35,25 +35,30 @@ namespace distributed { enum ValueType { Sparse = 0, Dense = 1 }; -struct PullContext { - const uint64_t *keys; +struct TablePullContext { + const uint64_t *keys = nullptr; PullSparseValue pull_value; - float *values; - char **ptr_values; + float *values = nullptr; + char **ptr_values = nullptr; + std::vector *geo_pull_keys = nullptr; // for GEO + std::vector *geo_pull_values = nullptr; // for GEO }; struct TablePushContext { - const uint64_t *keys; - const float *values; - const float **ptr_values; + const uint64_t *keys = nullptr; + const float *values = nullptr; + const float **ptr_values = nullptr; + const int64_t *push_steps = nullptr; // for global step + bool is_param = false; // true: push param, false: push gradient }; struct TableContext { ValueType value_type; - PullContext pull_context; + TablePullContext pull_context; TablePushContext push_context; size_t num; bool use_ptr = false; + uint32_t trainer_id; // for GEO and global step }; class Table { @@ -65,38 +70,6 @@ class Table { virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; - virtual int32_t PullDense(float *values, size_t num) = 0; - virtual int32_t PushDense(const float *values, size_t num) = 0; - // for push global_step - virtual int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - virtual int32_t PushDenseParam(const float *values, size_t num) { return 0; } - - virtual int32_t PullSparsePtr(char **pull_values, const uint64_t *keys, - size_t num) { - VLOG(0) << "NOT IMPLEMENT"; - return 0; - } - virtual int32_t PullSparse(float *values, - const PullSparseValue &pull_value) = 0; - virtual int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) = 0; - virtual int32_t PushSparse(const uint64_t *keys, const float **values, - size_t num) { - return 0; - } - virtual int32_t PushSparseParam(const uint64_t *keys, const float *values, - size_t num) { - return 0; - } - - // only for sparse geo table - virtual int32_t PullGeoParam(const uint32_t trainer_id, - std::vector *values, - std::vector *keys) { - return 0; - } // only for barrier virtual int32_t Barrier(const uint32_t trainer_id, diff --git a/paddle/fluid/distributed/ps/table/tensor_table.h b/paddle/fluid/distributed/ps/table/tensor_table.h index 175aa194fb8..7bb236d02c9 100644 --- a/paddle/fluid/distributed/ps/table/tensor_table.h +++ b/paddle/fluid/distributed/ps/table/tensor_table.h @@ -50,43 +50,28 @@ class TensorTable : public Table { TensorTable() {} virtual ~TensorTable() {} - virtual int32_t Pull(TableContext &context) { return 0; } - virtual int32_t Push(TableContext &context) { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } + int32_t Pull(TableContext &context) override { return 0; } + int32_t 
Push(TableContext &context) override { return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() override { return 0; } - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - virtual void Clear() {} + void Clear() override {} int32_t Initialize() override { return 0; } - int32_t PushDense(const int64_t *values, const int32_t trainer_id) override { - return 0; - } - int32_t SetProgramEnv( framework::Scope *scope, platform::Place place, const std::vector *sub_program) override { @@ -111,45 +96,28 @@ class DenseTensorTable : public TensorTable { DenseTensorTable() {} virtual ~DenseTensorTable() {} - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() override { return 0; } - virtual void Clear() {} + void Clear() override {} // Todo: Support program Load & Save - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - // Todo: Support pull dense - int32_t PullDense(float *values, size_t num) override { return 0; } - /*----------------------------------------------------------------------*/ int32_t Initialize() override { return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } - - int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return 0; - } - protected: virtual int32_t _RunProgram(const float *values, size_t num, const uint32_t trainer_id) { @@ -167,33 +135,23 @@ class GlobalStepTable : public DenseTensorTable { GlobalStepTable() {} virtual ~GlobalStepTable() {} - int32_t PullSparse(float *values, - const PullSparseValue &pull_value) override { - return 0; - } - int32_t PushSparse(const uint64_t *keys, const float *values, - size_t num) override { - return 0; - } int32_t Shrink(const std::string ¶m) override { return 0; } - virtual void *GetShard(size_t shard_idx) { return 0; } + void *GetShard(size_t shard_idx) override { return 0; } - virtual int32_t InitializeShard() { return 0; } + int32_t InitializeShard() override { return 0; } - virtual int32_t Flush() { return 0; } + int32_t Flush() 
override { return 0; } - virtual void Clear() {} + void Clear() override {} - virtual int32_t Load(const std::string &path, const std::string ¶m) { + int32_t Load(const std::string &path, const std::string ¶m) override { return 0; } - virtual int32_t Save(const std::string &path, const std::string ¶m) { + int32_t Save(const std::string &path, const std::string ¶m) override { return 0; } - int32_t PullDense(float *values, size_t num) override { return 0; } - /*----------------------------------------------------------------------*/ int32_t Initialize() override { @@ -235,12 +193,13 @@ class GlobalStepTable : public DenseTensorTable { decay_counters_[i] = 0; } } + return 0; } - int32_t PushDense(const float *values, size_t num) override { return 0; } + // int32_t PushDense(const float *values, size_t num) override { return 0; } - int32_t PushDense(const int64_t *values, const int32_t trainer_id) { - return _RunProgram(values, trainer_id); + virtual int32_t Push(TableContext context) { + return _RunProgram(context.push_context.push_steps, context.trainer_id); } int32_t SetTableMap(std::unordered_map> diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc index f7d287af844..29195d99857 100644 --- a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -49,6 +49,8 @@ namespace distributed = paddle::distributed; void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto x_var = scope->Var("x"); x_var->GetMutable(); + auto x_g_var = scope->Var("x@GRAD"); + x_g_var->GetMutable(); } void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, @@ -59,34 +61,49 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, float* x_ptr = x_var->mutable_data(framework::DDim({1, rows_numel}), *place); for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + + auto g_size = rows_numel + + 30; // hard code here: key_num * (fea_dim + 3), show/clk/slot + auto x_g_var = scope->Var("x@GRAD")->GetMutable(); + float* x_g_ptr = + x_g_var->mutable_data(framework::DDim({1, g_size}), *place); + for (int64_t i = 0; i < g_size; ++i) x_g_ptr[i] = 1.0; } void GetDownpourSparseTableProto( ::paddle::distributed::TableParameter* sparse_table_proto) { sparse_table_proto->set_table_id(0); - sparse_table_proto->set_table_class("CommonSparseTable"); - sparse_table_proto->set_shard_num(256); - sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE); - ::paddle::distributed::TableAccessorParameter* accessor_proto = + sparse_table_proto->set_table_class("MemorySparseTable"); + sparse_table_proto->set_shard_num(10); + ::paddle::distributed::TableAccessorParameter* accessor_config = sparse_table_proto->mutable_accessor(); - ::paddle::distributed::CommonAccessorParameter* common_proto = - sparse_table_proto->mutable_common(); - - accessor_proto->set_accessor_class("CommMergeAccessor"); - accessor_proto->set_fea_dim(0); - accessor_proto->set_embedx_dim(10); - - common_proto->set_name("sgd"); - common_proto->set_table_name("MergedDense"); - common_proto->set_trainer_num(1); - common_proto->set_sync(false); - common_proto->set_entry("none"); - common_proto->add_params("Param"); - common_proto->add_dims(10); - common_proto->add_initializers("uniform_random&0&-1.0&1.0"); - common_proto->add_params("LearningRate"); - common_proto->add_dims(1); - common_proto->add_initializers("fill_constant&1.0"); + + 
accessor_config->set_accessor_class("SparseAccessor"); + accessor_config->set_fea_dim(10); + accessor_config->set_embedx_dim(9); + accessor_config->set_embedx_threshold(0); + accessor_config->mutable_ctr_accessor_param()->set_nonclk_coeff(0.2); + accessor_config->mutable_ctr_accessor_param()->set_click_coeff(1); + accessor_config->mutable_ctr_accessor_param()->set_base_threshold(0.5); + accessor_config->mutable_ctr_accessor_param()->set_delta_threshold(0.2); + accessor_config->mutable_ctr_accessor_param()->set_delta_keep_days(16); + accessor_config->mutable_ctr_accessor_param()->set_show_click_decay_rate( + 0.99); + + accessor_config->mutable_embed_sgd_param()->set_name("SparseNaiveSGDRule"); + auto* naive_param = + accessor_config->mutable_embed_sgd_param()->mutable_naive(); + naive_param->set_learning_rate(1.0); + naive_param->set_initial_range(0.3); + naive_param->add_weight_bounds(-10.0); + naive_param->add_weight_bounds(10.0); + + accessor_config->mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule"); + naive_param = accessor_config->mutable_embedx_sgd_param()->mutable_naive(); + naive_param->set_learning_rate(1.0); + naive_param->set_initial_range(0.3); + naive_param->add_weight_bounds(-10.0); + naive_param->add_weight_bounds(10.0); } ::paddle::distributed::PSParameter GetServerProto() { @@ -217,42 +234,42 @@ void RunBrpcPushSparse() { auto pull_status = worker_ptr_->PullSparse( fea_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_status.wait(); - for (size_t idx = 0; idx < tensor->numel(); ++idx) { - fea_values.data()[idx] *= 2.0; - } - - /*-----------------------Test Push Param----------------------------------*/ - LOG(INFO) << "Run push_sparse_param"; - paddle::distributed::DownpourBrpcClosure* closure_push_param = + /*-----------------------Test Push Grad----------------------------------*/ + // first to expand embedx, init + paddle::distributed::DownpourBrpcClosure* closure_push_grad = new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; for (size_t i = 0; i < 1; ++i) { if (closure->check_response( - i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) { + i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) { ret = -1; break; } } closure->set_promise_value(ret); }); - auto push_status = worker_ptr_->PushSparseParam( - 0, fea_keys.data(), (const float**)fea_value_ptr.data(), fea_keys.size(), - closure_push_param); - push_status.wait(); - auto pull_param_status = worker_ptr_->PullSparse( - fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); - pull_param_status.wait(); + framework::Variable* g_var = client_scope.FindVar("x@GRAD"); + framework::LoDTensor* g_tensor = g_var->GetMutable(); - for (size_t idx = 0; idx < tensor->numel(); ++idx) { - EXPECT_FLOAT_EQ(fea_temp_values[idx], fea_values[idx]); + LOG(INFO) << "Run push_sparse_grad"; + std::vector push_g_vec; + for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { + push_g_vec.push_back(g_tensor->data() + i * 13); } + auto push_grad_status = worker_ptr_->PushSparseRawGradient( + 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), + closure_push_grad); + push_grad_status.wait(); - /*-----------------------Test Push Grad----------------------------------*/ + // pull + pull_status = worker_ptr_->PullSparse(fea_value_ptr.data(), 0, + fea_keys.data(), fea_keys.size(), true); + pull_status.wait(); - paddle::distributed::DownpourBrpcClosure* closure_push_grad = + 
paddle::distributed::DownpourBrpcClosure* closure_push_grad1 = new paddle::distributed::DownpourBrpcClosure(1, [&](void* done) { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; @@ -266,16 +283,13 @@ void RunBrpcPushSparse() { closure->set_promise_value(ret); }); - LOG(INFO) << "Run pull_sparse_grad"; - std::vector push_g_vec; - for (auto i = 0; i < static_cast(fea_keys.size()); ++i) { - push_g_vec.push_back(tensor->data() + i * 10); - } - auto push_grad_status = worker_ptr_->PushSparseRawGradient( + // push again, embedx update this time + push_grad_status = worker_ptr_->PushSparseRawGradient( 0, fea_keys.data(), (const float**)push_g_vec.data(), fea_keys.size(), - closure_push_grad); + closure_push_grad1); push_grad_status.wait(); + // pull update auto pull_update_status = worker_ptr_->PullSparse( fea_temp_value_ptr.data(), 0, fea_keys.data(), fea_keys.size(), true); pull_update_status.wait(); diff --git a/paddle/fluid/distributed/test/dense_table_test.cc b/paddle/fluid/distributed/test/dense_table_test.cc index 49346c2898f..40992b1b53b 100644 --- a/paddle/fluid/distributed/test/dense_table_test.cc +++ b/paddle/fluid/distributed/test/dense_table_test.cc @@ -69,7 +69,13 @@ TEST(CommonDenseTable, Adam) { // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->PullDense(init_values.data(), fea_dim); + + TableContext table_context1; + table_context1.value_type = Dense; + table_context1.pull_context.values = init_values.data(); + table_context1.num = fea_dim; + table->Pull(table_context1); + // table->PullDense(init_values.data(), fea_dim); // push gradient std::vector> trainer_gradient_values; @@ -85,12 +91,24 @@ TEST(CommonDenseTable, Adam) { // for adam for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; - table->PushDense(push_values.data(), push_values.size()); + + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = push_values.data(); + table_context.num = push_values.size(); + table->Push(table_context); + // table->PushDense(push_values.data(), push_values.size()); } std::vector pull_values; pull_values.resize(fea_dim); - table->PullDense(pull_values.data(), fea_dim); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = pull_values.data(); + table_context.num = fea_dim; + table->Pull(table_context); + // table->PullDense(pull_values.data(), fea_dim); float mom_rate = 0.99; float decay_rate = 0.9999; @@ -150,7 +168,13 @@ TEST(CommonDenseTable, SGD) { // pull parameters for create and check std::vector init_values; init_values.resize(fea_dim); - table->PullDense(init_values.data(), fea_dim); + + TableContext table_context1; + table_context1.value_type = Dense; + table_context1.pull_context.values = init_values.data(); + table_context1.num = fea_dim; + table->Pull(table_context1); + // table->PullDense(init_values.data(), fea_dim); std::vector total_gradients; total_gradients.resize(fea_dim); @@ -173,7 +197,12 @@ TEST(CommonDenseTable, SGD) { for (int i = 0; i < trainers; i++) { auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_values] { - table->PushDense(push_values.data(), push_values.size()); + TableContext table_context; + table_context.value_type = Dense; + table_context.push_context.values = push_values.data(); + table_context.num = push_values.size(); + table->Push(table_context); + // table->PushDense(push_values.data(), push_values.size()); }; 
task_status.push_back(pool_->enqueue(std::move(task))); } @@ -183,7 +212,13 @@ TEST(CommonDenseTable, SGD) { std::vector pull_values; pull_values.resize(fea_dim); - table->PullDense(pull_values.data(), fea_dim); + + TableContext table_context; + table_context.value_type = Dense; + table_context.pull_context.values = pull_values.data(); + table_context.num = fea_dim; + table->Pull(table_context); + // table->PullDense(pull_values.data(), fea_dim); for (int j = 0; j < fea_dim; j++) { auto update_val = init_values[j] - 1.0 * total_gradients[j]; ASSERT_TRUE(abs(update_val - pull_values[j]) < 1e-5); diff --git a/paddle/fluid/distributed/test/memory_geo_table_test.cc b/paddle/fluid/distributed/test/memory_geo_table_test.cc index 965f67992d0..ca3b51fade1 100644 --- a/paddle/fluid/distributed/test/memory_geo_table_test.cc +++ b/paddle/fluid/distributed/test/memory_geo_table_test.cc @@ -58,12 +58,26 @@ TEST(MemorySparseGeoTable, SSUM) { for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { init_values.push_back(0.0); } - table->PushSparseParam(init_keys.data(), init_values.data(), - init_keys.size()); + + TableContext table_context1; + table_context1.value_type = Sparse; + table_context1.push_context.keys = init_keys.data(); + table_context1.push_context.values = init_values.data(); + table_context1.push_context.is_param = true; + table_context1.num = init_keys.size(); + + table->Push(table_context1); + // table->PushSparseParam(init_keys.data(), init_values.data(), + // init_keys.size()); std::vector pull_values(init_values.size()); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->PullSparse(pull_values.data(), value); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.pull_value = value; + table_context.pull_context.values = pull_values.data(); + table->Pull(table_context); + // table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size() * emb_dim; i++) { ASSERT_TRUE(abs(pull_values[i] - init_values[i]) < 1e-5); @@ -93,7 +107,14 @@ TEST(MemorySparseGeoTable, SSUM) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_values[i]; auto task = [table, &push_keys, &push_values] { - table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = push_keys.data(); + table_context.push_context.values = push_values.data(); + table_context.num = push_keys.size(); + table->Push(table_context); + // table->PushSparse(push_keys.data(), push_values.data(), + // push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -106,7 +127,13 @@ TEST(MemorySparseGeoTable, SSUM) { geo_pull_ids.resize(trainers); geo_pull_values.resize(trainers); for (int i = 0; i < trainers; i++) { - table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]); + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.geo_pull_keys = &geo_pull_ids[i]; + table_context.pull_context.geo_pull_values = &geo_pull_values[i]; + table_context.trainer_id = i; + table->Pull(table_context); + // table->PullGeoParam(i, &geo_pull_values[i], &geo_pull_ids[i]); ASSERT_EQ(geo_pull_values[i].size(), geo_pull_ids[i].size() * emb_dim); for (size_t j = 0; j < geo_pull_ids[i].size(); ++j) { auto id = geo_pull_ids[i][j]; diff --git a/paddle/fluid/distributed/test/memory_sparse_table_test.cc b/paddle/fluid/distributed/test/memory_sparse_table_test.cc index 73fa7272280..68bc50373ff 
100644 --- a/paddle/fluid/distributed/test/memory_sparse_table_test.cc +++ b/paddle/fluid/distributed/test/memory_sparse_table_test.cc @@ -76,7 +76,13 @@ TEST(MemorySparseTable, SGD) { std::vector init_values; init_values.resize(init_keys.size() * (emb_dim + 3)); auto value = PullSparseValue(init_keys, init_fres, emb_dim); - table->PullSparse(init_values.data(), value); + + TableContext table_context; + table_context.value_type = Sparse; + table_context.pull_context.pull_value = value; + table_context.pull_context.values = init_values.data(); + table->Pull(table_context); + // table->PullSparse(init_values.data(), value); // for check std::vector total_gradients; @@ -109,7 +115,14 @@ TEST(MemorySparseTable, SGD) { auto &push_keys = trainer_keys[i]; auto &push_values = trainer_gradient_values[i]; auto task = [table, &push_keys, &push_values] { - table->PushSparse(push_keys.data(), push_values.data(), push_keys.size()); + TableContext table_context; + table_context.value_type = Sparse; + table_context.push_context.keys = push_keys.data(); + table_context.push_context.values = push_values.data(); + table_context.num = push_keys.size(); + table->Push(table_context); + // table->PushSparse(push_keys.data(), push_values.data(), + // push_keys.size()); }; task_status.push_back(pool_->enqueue(std::move(task))); } @@ -119,7 +132,13 @@ TEST(MemorySparseTable, SGD) { std::vector pull_values; pull_values.resize(init_keys.size() * (emb_dim + 3)); - table->PullSparse(pull_values.data(), value); + + TableContext table_context1; + table_context1.value_type = Sparse; + table_context1.pull_context.pull_value = value; + table_context1.pull_context.values = pull_values.data(); + table->Pull(table_context1); + // table->PullSparse(pull_values.data(), value); for (size_t i = 0; i < init_keys.size(); ++i) { for (size_t j = 2; j < emb_dim + 3; ++j) { diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 007aaeb4fed..1fd435cca11 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -621,7 +621,7 @@ class SparseTable(Table): class GeoSparseTable(SparseTable): def __init__(self, context, send_ctx): super(GeoSparseTable, self).__init__(context, send_ctx) - self.table_class = "SparseGeoTable" + self.table_class = "MemorySparseGeoTable" if self.context['ps_mode'] != DistributedMode.GEO: raise ValueError("not geo sparse table!") -- GitLab
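
Note on the new interface: the hunks above collapse the old per-method Table API (PullDense, PushDense, PullSparse, PushSparseParam, PullGeoParam, ...) into the single Pull(TableContext&)/Push(TableContext&) pair, with TableContext carrying the routing information (value_type, is_param, use_ptr, trainer_id). Below is a minimal sketch of the dense call path under the new interface, mirroring the updated dense_table_test.cc; the function name DensePullPush is illustrative only, and a fully configured table instance is assumed:

    #include <cstdint>
    #include <vector>
    #include "paddle/fluid/distributed/ps/table/table.h"

    void DensePullPush(paddle::distributed::Table* table, size_t fea_dim) {
      namespace dist = paddle::distributed;

      // Pull the current parameters (was: table->PullDense(buf, fea_dim)).
      std::vector<float> params(fea_dim);
      dist::TableContext pull_ctx;
      pull_ctx.value_type = dist::Dense;
      pull_ctx.pull_context.values = params.data();
      pull_ctx.num = fea_dim;
      table->Pull(pull_ctx);

      // Push gradients (was: table->PushDense(grads, num)). is_param stays
      // false, so CommonDenseTable::Push dispatches to PushDense; setting it
      // to true routes to PushDenseParam and overwrites the parameters.
      std::vector<float> grads(fea_dim, 0.1f);
      dist::TableContext push_ctx;
      push_ctx.value_type = dist::Dense;
      push_ctx.push_context.values = grads.data();
      push_ctx.push_context.is_param = false;
      push_ctx.num = grads.size();
      table->Push(push_ctx);
    }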
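
Sparse gradient pushes that hand over one float* per key (the PushSparseRawGradient path in ps_local_client.cc) now set use_ptr = true and fill push_context.ptr_values, which MemorySparseTable::Push routes to the float** overload of PushSparse. A sketch with the same includes as above; SparsePushPtr is an illustrative name:

    void SparsePushPtr(paddle::distributed::Table* table, const uint64_t* keys,
                       const float** grads, size_t num) {
      namespace dist = paddle::distributed;
      dist::TableContext ctx;
      ctx.value_type = dist::Sparse;
      ctx.push_context.keys = keys;
      ctx.push_context.ptr_values = grads;  // one gradient pointer per key
      ctx.num = num;
      ctx.use_ptr = true;  // selects PushSparse(keys, const float**, num)
      table->Push(ctx);    // was: table->PushSparse(keys, grads, num)
    }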
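
The GEO pull path keeps its output vectors but now passes them through pull_context.geo_pull_keys/geo_pull_values together with trainer_id, as at the updated PullGeoParam call sites in brpc_ps_server.cc and memory_geo_table_test.cc. A sketch under the same assumptions (GeoPull is an illustrative name; emb_dim comes from the table config):

    void GeoPull(paddle::distributed::Table* table, uint32_t trainer_id) {
      namespace dist = paddle::distributed;
      std::vector<uint64_t> ids;
      std::vector<float> values;
      dist::TableContext ctx;
      ctx.value_type = dist::Sparse;
      ctx.pull_context.geo_pull_keys = &ids;
      ctx.pull_context.geo_pull_values = &values;
      ctx.trainer_id = trainer_id;
      table->Pull(ctx);  // was: table->PullGeoParam(trainer_id, &values, &ids)
      // ids now holds the keys updated since this trainer's last pull, and
      // values holds ids.size() * emb_dim floats.
    }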