未验证 提交 eb05db71 编写于 作者: C Chengmo 提交者: GitHub

Speed GEO-SGD (#20158)

* delete debug vlog & add rpc function & fix word2vec bug & speed GEO-SGD
上级 b28d4a82
......@@ -27,7 +27,10 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -268,7 +271,7 @@ class Communicator {
};
using SparseIdsMap =
std::unordered_map<std::string, std::unordered_set<int64_t>>;
std::unordered_map<std::string, std::vector<std::unordered_set<int64_t>>>;
class AsyncCommunicator : public Communicator {
public:
......@@ -348,15 +351,18 @@ class GeoSgdCommunicator : public Communicator {
private:
void SendThread();
void RecvAll();
std::unordered_set<int64_t> SparseIdsMerge(
const std::vector<SparseIdsMap>& ids_send_vec,
const std::string& var_name);
const std::string& var_name, const std::string& splited_var_name);
void SendUpdateDenseVars(const std::string& var_name);
void SendUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name,
const std::unordered_set<int64_t>& ids_table);
void RecvUpdateVars(const std::string& var_name);
void RecvUpdateDenseVars(const std::string& var_name);
void RecvUpdateSparseVars(const std::string& var_name,
const std::string& splited_var_name);
void GeoSgdDenseParamInit(framework::Scope* scope_x,
framework::Scope* scope_y,
......@@ -366,6 +372,14 @@ class GeoSgdCommunicator : public Communicator {
framework::Scope* scope_y,
const std::string var_name);
void RpcSend(const std::string& origin_var_name,
const std::string& splited_var_name,
const size_t& splited_var_index);
void RpcRecv(const std::string& origin_var_name,
const std::string& splited_var_name,
const size_t& splited_var_index);
const std::string VarToDeltaVar(const std::string var_name) {
std::string delta_name = var_name;
const std::string send_name = delta_name.append(".delta");
......@@ -379,6 +393,20 @@ class GeoSgdCommunicator : public Communicator {
return param_name;
}
size_t GetSplitedVarIndex(const std::string var_name,
const std::string splited_var_name) {
size_t index = 0;
for (size_t i = 0;
i < send_varname_to_ctx_[var_name].splited_var_names.size(); i++) {
if (send_varname_to_ctx_[var_name].splited_var_names[i] ==
splited_var_name) {
index = i;
break;
}
}
return index;
}
private:
int trainer_nums_ = 1;
int geo_need_push_nums_ = 100;
......@@ -390,8 +418,6 @@ class GeoSgdCommunicator : public Communicator {
std::shared_ptr<Scope> pserver_scope_; // parameter on pserver,gloabl scope
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::atomic_uint have_push_{0};
std::unordered_map<std::string, bool>
var_list_; // if var is sparse, using selected rows, bool=true
......@@ -399,9 +425,12 @@ class GeoSgdCommunicator : public Communicator {
need_push_queue_;
std::vector<SparseIdsMap> ids_send_vec_;
std::unordered_map<std::string, std::vector<int64_t>> absolute_section_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::unique_ptr<std::thread> send_thread_{nullptr};
size_t need_thread_nums_{0};
};
} // namespace distributed
......
......@@ -179,8 +179,8 @@ class GeoSgdTranspiler(DistributeTranspiler):
return self.vars_info
def get_trainer_program(self, wait_port=True):
# if wait_port:
# wait_server_ready(self.pserver_endpoints)
if wait_port:
wait_server_ready(self.pserver_endpoints)
return self.origin_program
def get_pserver_programs(self, endpoint):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册