未验证 提交 bad312c3 编写于 作者: C Chengmo 提交者: GitHub

[cherry-pick][release/1.6]Speed GEO-SGD (#20177)

* Speed GEO-SGD (#20158)

* delete debug vlog & add rpc function & fix word2vec bug & speed GEO-SGD
上级 2691de18
...@@ -27,7 +27,10 @@ limitations under the License. */ ...@@ -27,7 +27,10 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -268,7 +271,7 @@ class Communicator { ...@@ -268,7 +271,7 @@ class Communicator {
}; };
using SparseIdsMap = using SparseIdsMap =
std::unordered_map<std::string, std::unordered_set<int64_t>>; std::unordered_map<std::string, std::vector<std::unordered_set<int64_t>>>;
class AsyncCommunicator : public Communicator { class AsyncCommunicator : public Communicator {
public: public:
...@@ -348,15 +351,18 @@ class GeoSgdCommunicator : public Communicator { ...@@ -348,15 +351,18 @@ class GeoSgdCommunicator : public Communicator {
private: private:
void SendThread(); void SendThread();
void RecvAll();
std::unordered_set<int64_t> SparseIdsMerge( std::unordered_set<int64_t> SparseIdsMerge(
const std::vector<SparseIdsMap>& ids_send_vec, const std::vector<SparseIdsMap>& ids_send_vec,
const std::string& var_name); const std::string& var_name, const std::string& splited_var_name);
void SendUpdateDenseVars(const std::string& var_name); void SendUpdateDenseVars(const std::string& var_name);
void SendUpdateSparseVars(const std::string& var_name, void SendUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name,
const std::unordered_set<int64_t>& ids_table); const std::unordered_set<int64_t>& ids_table);
void RecvUpdateVars(const std::string& var_name);
void RecvUpdateDenseVars(const std::string& var_name);
void RecvUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name);
void GeoSgdDenseParamInit(framework::Scope* scope_x, void GeoSgdDenseParamInit(framework::Scope* scope_x,
framework::Scope* scope_y, framework::Scope* scope_y,
...@@ -366,6 +372,14 @@ class GeoSgdCommunicator : public Communicator { ...@@ -366,6 +372,14 @@ class GeoSgdCommunicator : public Communicator {
framework::Scope* scope_y, framework::Scope* scope_y,
const std::string var_name); const std::string var_name);
void RpcSend(const std::string& origin_var_name,
const std::string& splited_var_name,
const size_t& splited_var_index);
void RpcRecv(const std::string& origin_var_name,
const std::string& splited_var_name,
const size_t& splited_var_index);
const std::string VarToDeltaVar(const std::string var_name) { const std::string VarToDeltaVar(const std::string var_name) {
std::string delta_name = var_name; std::string delta_name = var_name;
const std::string send_name = delta_name.append(".delta"); const std::string send_name = delta_name.append(".delta");
...@@ -379,6 +393,20 @@ class GeoSgdCommunicator : public Communicator { ...@@ -379,6 +393,20 @@ class GeoSgdCommunicator : public Communicator {
return param_name; return param_name;
} }
size_t GetSplitedVarIndex(const std::string var_name,
const std::string splited_var_name) {
size_t index = 0;
for (size_t i = 0;
i < send_varname_to_ctx_[var_name].splited_var_names.size(); i++) {
if (send_varname_to_ctx_[var_name].splited_var_names[i] ==
splited_var_name) {
index = i;
break;
}
}
return index;
}
private: private:
int trainer_nums_ = 1; int trainer_nums_ = 1;
int geo_need_push_nums_ = 100; int geo_need_push_nums_ = 100;
...@@ -390,8 +418,6 @@ class GeoSgdCommunicator : public Communicator { ...@@ -390,8 +418,6 @@ class GeoSgdCommunicator : public Communicator {
std::shared_ptr<Scope> pserver_scope_; // parameter on pserver,gloabl scope std::shared_ptr<Scope> pserver_scope_; // parameter on pserver,gloabl scope
RpcCtxMap send_varname_to_ctx_; RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_; RpcCtxMap recv_varname_to_ctx_;
std::atomic_uint have_push_{0};
std::unordered_map<std::string, bool> std::unordered_map<std::string, bool>
var_list_; // if var is sparse, using selected rows, bool=true var_list_; // if var is sparse, using selected rows, bool=true
...@@ -399,9 +425,12 @@ class GeoSgdCommunicator : public Communicator { ...@@ -399,9 +425,12 @@ class GeoSgdCommunicator : public Communicator {
need_push_queue_; need_push_queue_;
std::vector<SparseIdsMap> ids_send_vec_; std::vector<SparseIdsMap> ids_send_vec_;
std::unordered_map<std::string, std::vector<int64_t>> absolute_section_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::unique_ptr<std::thread> send_thread_{nullptr}; std::unique_ptr<std::thread> send_thread_{nullptr};
size_t need_thread_nums_{0};
}; };
} // namespace distributed } // namespace distributed
......
...@@ -179,8 +179,8 @@ class GeoSgdTranspiler(DistributeTranspiler): ...@@ -179,8 +179,8 @@ class GeoSgdTranspiler(DistributeTranspiler):
return self.vars_info return self.vars_info
def get_trainer_program(self, wait_port=True): def get_trainer_program(self, wait_port=True):
# if wait_port: if wait_port:
# wait_server_ready(self.pserver_endpoints) wait_server_ready(self.pserver_endpoints)
return self.origin_program return self.origin_program
def get_pserver_programs(self, endpoint): def get_pserver_programs(self, endpoint):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册