提交 b20f528b 编写于 作者: M malin10

test=develop, bug fix

上级 b3526fb4
......@@ -452,6 +452,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) {
return;
waiting_ = false;
// PADDLE_ENFORCE_EQ(
......@@ -475,6 +476,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
// << queue->Size();
// queue->Push(tmp_var);
} else {
auto p1 = GetCurrentUS();
auto splited_var_nums =
recv_varname_to_ctx_[table_name].splited_varnames.size();
if (ids_table->find(table_name) == ids_table->end()) {
......@@ -484,18 +486,28 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
table_name,
std::vector<std::unordered_set<int64_t>>{splited_var_nums}));
}
auto p2 = GetCurrentUS();
auto *var = scope.FindVar(var_names[i]);
auto &rows = var->Get<framework::SelectedRows>().rows();
// split rows index into output sparse vars
for (size_t i = 0; i < rows.size(); ++i) {
auto ep_idx = rows[i] % splited_var_nums;
ids_table->at(table_name)[ep_idx].insert(rows[i]);
auto var_tensor = var->Get<framework::LoDTensor>();
int element_number = var_tensor.numel();
int *var_mutable_data = var_tensor.mutable_data<int>(var_tensor.place());
auto p3 = GetCurrentUS();
// insert ids which has not been record
for (int j = 0; j < element_number; j++) {
auto ep_idx = var_mutable_data[j] % splited_var_nums;
ids_table->at(table_name)[ep_idx].insert(var_mutable_data[j]);
}
auto p4 = GetCurrentUS();
VLOG(1) << "table_name: " << table_name << "; p1-2: " << (p2 - p1)
<< "; p2-3: " << (p3 - p2) << "; p3-4: " << (p4 - p3);
}
}
auto before_push = GetCurrentUS();
need_push_queue_->Push(ids_table);
auto after_send = GetCurrentUS();
VLOG(0) << "run send_op finish. using " << (after_send - before_send);
VLOG(1) << "run send_op finish. using " << (before_push - before_send) << "; "
<< (after_send - before_push);
}
void GeoCommunicator::MainThread() {
......@@ -532,15 +544,15 @@ void GeoCommunicator::MainThread() {
if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
auto before_send_global_step = GetCurrentUS();
VLOG(0) << "finish ins_send_vec using time "
VLOG(1) << "finish ins_send_vec using time "
<< (before_send_global_step - before_send_by_communicator)
<< "; send_var_nums_ = " << send_var_nums_;
SendGlobalStep(max_merge_var_num_);
auto after_send_global_step = GetCurrentUS();
VLOG(0) << "finish send global_step using "
VLOG(1) << "finish send global_step using "
<< (after_send_global_step - before_send_global_step);
for (auto &iter : send_varname_to_ctx_) {
VLOG(1) << "debug " << iter.first;
VLOG(2) << "debug " << iter.first;
auto &var_name = iter.first;
auto &send_ctx = iter.second;
int pserver_num = static_cast<int>(send_ctx.epmap.size());
......@@ -556,11 +568,11 @@ void GeoCommunicator::MainThread() {
if (var_name == STEP_COUNTER) {
return;
}
SendSparse(var_name, ep_idx);
// SendSparse(var_name, ep_idx, ids_send_vec_);
auto after_send_sparse = GetCurrentUS();
RecvSparse(var_name, ep_idx);
auto after_recv_sparse = GetCurrentUS();
VLOG(0)
VLOG(1)
<< "send recv "
<< send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
<< " finish, using "
......@@ -596,57 +608,60 @@ void GeoCommunicator::MainThread() {
ids_send_vec_.clear();
auto finish_one_comm = GetCurrentUS();
VLOG(0) << "Finish SendByCommunicator "
VLOG(1) << "Finish SendByCommunicator "
<< (finish_one_comm - after_send_global_step);
}
}
}
void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
std::vector<int64_t> ids;
void GeoCommunicator::SendSparse(
const std::string &varname, int ep_idx,
const std::vector<SparseIdsMap> &ids_send_vec) {
std::unordered_set<int64_t> ids_set;
auto debug1 = GetCurrentUS();
auto &rpc_ctx = send_varname_to_ctx_.at(varname);
VLOG(1) << rpc_ctx.print();
VLOG(2) << rpc_ctx.print();
auto send_varname = rpc_ctx.splited_varnames[ep_idx];
auto trainer_id = rpc_ctx.trainer_id;
auto endpoint = rpc_ctx.epmap[ep_idx];
auto pserver_num = rpc_ctx.epmap.size();
for (auto ids_map : ids_send_vec_) {
std::copy(ids_map[varname][ep_idx].begin(), ids_map[varname][ep_idx].end(),
back_inserter(ids));
int64_t vector_size = 0;
for (auto ids_map : ids_send_vec) {
for (auto id : ids_map[varname][ep_idx]) {
ids_set.insert(id);
vector_size += 1;
if (vector_size > 10) {
break;
}
}
if (vector_size > 10) {
break;
}
}
VLOG(1) << "ids_vector_size: " << ids.size();
auto size = ids.size();
auto debug2 = GetCurrentUS();
VLOG(1) << "vector_size: " << vector_size
<< "; ids_set_size: " << ids_set.size() << "; using time "
<< (debug2 - debug1);
std::set<int64_t> st(ids.begin(), ids.end());
ids.assign(st.begin(), st.end());
std::stringstream list_str;
for (uint64_t i = 0; i < ids.size(); i++) {
list_str << ids[i] << ",";
}
VLOG(1) << "SendSparse receive var: " << send_varname << " unset: " << size
<< " set: " << ids.size() << ": " << list_str.str();
auto size = ids_set.size();
if (ids.empty()) {
if (size == 0) {
LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
return;
}
std::vector<size_t> outs_rows_idx;
std::vector<int64_t> new_rows;
new_rows.insert(new_rows.begin(), ids_set.begin(), ids_set.end());
if (!rpc_ctx.is_distributed) {
for (size_t i = 0; i < ids.size(); ++i) {
auto id = ids[i] / pserver_num;
outs_rows_idx.push_back(id);
}
} else {
for (size_t i = 0; i < ids.size(); ++i) {
outs_rows_idx.push_back(ids[i]);
}
}
// std::stringstream list_str;
// for (uint64_t i = 0; i < ids.size(); i++) {
// list_str << ids[i] << ",";
// }
auto debug3 = GetCurrentUS();
VLOG(1) << "SendSparse receive var: " << send_varname
<< " set: " << ids_set.size() << ", using time " << (debug3 - debug1);
auto *var_latest = recv_scope_->FindVar(varname);
......@@ -661,30 +676,35 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto *var_delta = delta_scope_->Var(send_varname);
auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
t_delta->mutable_rows()->assign(outs_rows_idx.begin(), outs_rows_idx.end());
auto *t_value = t_delta->mutable_value();
t_value->mutable_data<float>(
framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
framework::make_ddim({static_cast<int64_t>(new_rows.size()), dims1}),
cpu_ctx.GetPlace());
std::vector<std::vector<std::vector<float> *>> values;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Get(ids, {"Param"}, &values);
ins->Get(varname)->Get(new_rows, {"Param"}, &values);
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
float coefficient = 1.0 / static_cast<float>(trainers_);
for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + ids[j] * dims1,
for (auto j = 0; j < static_cast<int>(new_rows.size()); ++j) {
blas.VSUB(dims1, t_latest.data<float>() + new_rows[j] * dims1,
values[j][0]->data(), t_value->data<float>() + j * dims1);
blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
values[j][0]->data());
}
VLOG(1) << "begin to real send " << send_varname;
std::vector<int64_t> send_rows;
send_rows.reserve(new_rows.size());
for (auto idx : new_rows) {
send_rows.push_back(idx / pserver_num);
}
t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
t_delta->set_rows(send_rows);
VLOG(2) << "begin to real send " << send_varname;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client =
......@@ -692,9 +712,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
*delta_scope_.get(), send_varname);
VLOG(1) << "need to wait for send " << send_varname;
VLOG(2) << "need to wait for send " << send_varname;
ret->Wait();
VLOG(1) << "finish to send " << send_varname;
VLOG(2) << "finish to send " << send_varname;
}
void GeoCommunicator::SendDense(const std::string &varname) {
......@@ -740,7 +760,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();
VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
VLOG(2) << "Begin to RecvSparse receive var: " << splited_var_name;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
......@@ -753,7 +773,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
splited_var_name, splited_var_name);
handle->Wait();
VLOG(1) << "Finish to RecvSparse receive var: " << splited_var_name;
VLOG(2) << "Finish to RecvSparse receive var: " << splited_var_name;
auto *var_latest = recv_scope_->FindVar(varname);
......@@ -766,7 +786,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
var_psrever->Get<framework::SelectedRows>().rows().end());
VLOG(1) << "RecvSparse receive var: " << splited_var_name
VLOG(2) << "RecvSparse receive var: " << splited_var_name
<< " ids Size: " << ids.size();
auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();
......@@ -796,7 +816,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
old_values[j][0]->data());
}
VLOG(1) << "receive finish";
VLOG(2) << "receive finish";
}
void GeoCommunicator::RecvDense(const std::string &varname) {
......
......@@ -426,7 +426,8 @@ class GeoCommunicator : public AsyncCommunicator {
// void SendByCommunicator(int batches) override;
void SendSparse(const std::string &varname, int ep_idx);
void SendSparse(const std::string &varname, int ep_idx,
const std::vector<SparseIdsMap> &ids_send_vec);
void SendDense(const std::string &varname);
......
......@@ -169,16 +169,24 @@ def append_send_ops_pass(program, config, merge=False):
trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints()
def _append_send_op(union_vars, queue):
def _append_send_op():
send_input_vars = []
assert (len(queue) == len(union_vars))
for i in range(len(queue)):
if queue[i] == STEP_COUNTER:
send_input_vars.append("")
else:
send_input_vars.append(program.global_block().vars[union_vars[
i]])
sparse_var = []
sparse_tables = []
unique_sparse_var = {}
for op in program.global_block().ops:
if "is_sparse" in op.all_attrs():
if op.type == "lookup_table":
op._set_attr('remote_prefetch', False)
for input_var_name, sparse_var_name in zip(
op.input("Ids"), op.input("W")):
if input_var_name in unique_sparse_var:
if unique_sparse_var[input_var_name] == sparse_var_name:
continue
input_var = program.global_block().var(input_var_name)
sparse_var.append(input_var)
sparse_tables.append(sparse_var_name)
unique_sparse_var[input_var_name] = sparse_var_name
dummy_output = []
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
......@@ -187,10 +195,10 @@ def append_send_ops_pass(program, config, merge=False):
program.global_block().append_op(
type="send",
inputs={"X": send_input_vars},
inputs={"X": sparse_var},
outputs={"Out": dummy_output},
attrs={
"send_varnames": queue,
"send_varnames": sparse_tables,
"merge_add": True,
"use_send_handler": False,
"endpoints": pserver_endpoints,
......@@ -216,17 +224,10 @@ def append_send_ops_pass(program, config, merge=False):
sends = config.get_trainer_send_context()
if merge:
origin_varnames = []
merged_names = []
for merged_name, send in sends.items():
for var in send.origin_varnames():
origin_varnames.append(var)
merged_names.append(merged_name)
if len(origin_varnames) > 0:
dummys.append(_append_send_op(origin_varnames, merged_names))
dummys.append(_append_send_op())
else:
for merged_name, send in sends.items():
dummys.append(_append_send_op(send.origin_varnames(), merged_name))
dummys.append(_append_send_op())
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
_append_barrier_op(dummys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册