提交 613df707 编写于 作者: M malin10

test=develop, bug fix

上级 78b603b2
...@@ -425,6 +425,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -425,6 +425,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
} else { } else {
auto &send_ctx = iter.second; auto &send_ctx = iter.second;
send_var_nums_ += send_ctx.splited_varnames.size();
if (!send_ctx.is_sparse) { if (!send_ctx.is_sparse) {
continue; continue;
} }
...@@ -462,16 +463,17 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -462,16 +463,17 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
for (size_t i = 0; i < var_tables.size(); i++) { for (size_t i = 0; i < var_tables.size(); i++) {
auto table_name = var_tables[i]; auto table_name = var_tables[i];
if (table_name == STEP_COUNTER) { if (table_name == STEP_COUNTER) {
auto &queue = send_varname_to_queue_.at(table_name); continue;
// auto &queue = send_varname_to_queue_.at(table_name);
auto tmp_var = std::make_shared<Variable>();
auto *tensor = tmp_var->GetMutable<framework::LoDTensor>(); // auto tmp_var = std::make_shared<Variable>();
tensor->Resize(framework::make_ddim({1})); // auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace()); // tensor->Resize(framework::make_ddim({1}));
out_d[0] = 1; // auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
VLOG(3) << "send to " << table_name << " with queue size " // out_d[0] = 1;
<< queue->Size(); // VLOG(3) << "send to " << table_name << " with queue size "
queue->Push(tmp_var); // << queue->Size();
// queue->Push(tmp_var);
} else { } else {
auto splited_var_nums = auto splited_var_nums =
recv_varname_to_ctx_[table_name].splited_varnames.size(); recv_varname_to_ctx_[table_name].splited_varnames.size();
...@@ -506,18 +508,22 @@ void GeoCommunicator::MainThread() { ...@@ -506,18 +508,22 @@ void GeoCommunicator::MainThread() {
while (running_) { while (running_) {
// int meet = Meet(); // int meet = Meet();
VLOG(1) << "async_meet: " << meet; // VLOG(1) << "async_meet: " << meet;
// SendGlobalStep(meet); // SendGlobalStep(meet);
auto before = GetCurrentUS();
SendByCommunicator(0); SendByCommunicator(0);
auto after = GetCurrentUS();
VLOG(0) << "finish one SendByCommunicator using " << (after - before);
} }
VLOG(1) << "geo-communicator stopped, send thread exit"; VLOG(1) << "geo-communicator stopped, send thread exit";
} }
void GeoCommunicator::SendByCommunicator(int batches) { void GeoCommunicator::SendByCommunicator(int batches) {
std::vector<std::future<void>> tasks; std::vector<std::future<void>> tasks;
tasks.reserve(send_varname_to_ctx_.size()); tasks.reserve(send_var_nums_);
auto before_send_by_communicator = GetCurrentUS();
size_t wait_times = 0; size_t wait_times = 0;
while (ids_send_vec_.size() < static_cast<size_t>(max_merge_var_num_)) { while (ids_send_vec_.size() < static_cast<size_t>(max_merge_var_num_)) {
VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size(); VLOG(1) << "ids_send_vec_ Size: " << ids_send_vec_.size();
...@@ -537,6 +543,13 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -537,6 +543,13 @@ void GeoCommunicator::SendByCommunicator(int batches) {
} }
if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) { if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
auto before_send_global_step = GetCurrentUS();
VLOG(0) << "finish ins_send_vec using time "
<< (before_send_global_step - before_send_by_communicator);
SendGlobalStep(max_merge_var_num_);
auto after_send_global_step = GetCurrentUS();
VLOG(0) << "finish send global_step using "
<< (after_send_global_step - before_send_global_step);
for (auto &iter : send_varname_to_ctx_) { for (auto &iter : send_varname_to_ctx_) {
VLOG(1) << "debug " << iter.first; VLOG(1) << "debug " << iter.first;
auto &var_name = iter.first; auto &var_name = iter.first;
...@@ -550,11 +563,20 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -550,11 +563,20 @@ void GeoCommunicator::SendByCommunicator(int batches) {
for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) { for (int ep_idx = 0; ep_idx < pserver_num; ep_idx++) {
auto send_recv_task = [this, ep_idx, &var_name] { auto send_recv_task = [this, ep_idx, &var_name] {
auto before_send_sparse = GetCurrentUS();
if (var_name == STEP_COUNTER) { if (var_name == STEP_COUNTER) {
return; return;
} }
SendSparse(var_name, ep_idx); SendSparse(var_name, ep_idx);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx); RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(0)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using " << (after_send_sparse - before_send_sparse)
<< " and " << (after_recv_sparse - after_send_sparse)
<< "; total = " << (after_recv_sparse - before_send_sparse);
}; };
tasks.emplace_back( tasks.emplace_back(
send_threadpool_->enqueue(std::move(send_recv_task))); send_threadpool_->enqueue(std::move(send_recv_task)));
...@@ -562,6 +584,7 @@ void GeoCommunicator::SendByCommunicator(int batches) { ...@@ -562,6 +584,7 @@ void GeoCommunicator::SendByCommunicator(int batches) {
} }
} else { } else {
auto send_recv_task = [this, &var_name, &send_ctx] { auto send_recv_task = [this, &var_name, &send_ctx] {
return;
if (var_name == STEP_COUNTER) { if (var_name == STEP_COUNTER) {
return; return;
} }
......
...@@ -457,6 +457,7 @@ class GeoCommunicator : public AsyncCommunicator { ...@@ -457,6 +457,7 @@ class GeoCommunicator : public AsyncCommunicator {
// parameter on pserver // parameter on pserver
std::shared_ptr<Scope> pserver_scope_; std::shared_ptr<Scope> pserver_scope_;
int send_var_nums_ = 0;
std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_; std::unordered_map<std::string, std::shared_ptr<SparseValue>> old_sparses_;
std::shared_ptr<BlockingQueue<std::shared_ptr<SparseIdsMap>>> std::shared_ptr<BlockingQueue<std::shared_ptr<SparseIdsMap>>>
need_push_queue_; need_push_queue_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册