Commit b20f528b authored by malin10

test=develop, bug fix

Parent b3526fb4
@@ -452,6 +452,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
 void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                            const std::vector<std::string> &var_tables,
                            const framework::Scope &scope) {
+  return;
   waiting_ = false;
   // PADDLE_ENFORCE_EQ(
@@ -475,6 +476,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
       //         << queue->Size();
       // queue->Push(tmp_var);
     } else {
+      auto p1 = GetCurrentUS();
       auto splited_var_nums =
           recv_varname_to_ctx_[table_name].splited_varnames.size();
       if (ids_table->find(table_name) == ids_table->end()) {
@@ -484,18 +486,28 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
                 table_name,
                 std::vector<std::unordered_set<int64_t>>{splited_var_nums}));
       }
+      auto p2 = GetCurrentUS();
       auto *var = scope.FindVar(var_names[i]);
-      auto &rows = var->Get<framework::SelectedRows>().rows();
-      // split rows index into output sparse vars
-      for (size_t i = 0; i < rows.size(); ++i) {
-        auto ep_idx = rows[i] % splited_var_nums;
-        ids_table->at(table_name)[ep_idx].insert(rows[i]);
+      auto var_tensor = var->Get<framework::LoDTensor>();
+      int element_number = var_tensor.numel();
+      int *var_mutable_data = var_tensor.mutable_data<int>(var_tensor.place());
+      auto p3 = GetCurrentUS();
+      // insert ids which has not been record
+      for (int j = 0; j < element_number; j++) {
+        auto ep_idx = var_mutable_data[j] % splited_var_nums;
+        ids_table->at(table_name)[ep_idx].insert(var_mutable_data[j]);
       }
+      auto p4 = GetCurrentUS();
+      VLOG(1) << "table_name: " << table_name << "; p1-2: " << (p2 - p1)
+              << "; p2-3: " << (p3 - p2) << "; p3-4: " << (p4 - p3);
     }
   }
+  auto before_push = GetCurrentUS();
   need_push_queue_->Push(ids_table);
   auto after_send = GetCurrentUS();
-  VLOG(0) << "run send_op finish. using " << (after_send - before_send);
+  VLOG(1) << "run send_op finish. using " << (before_push - before_send) << "; "
+          << (after_send - before_push);
 }

 void GeoCommunicator::MainThread() {
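Note: with this change Send() no longer takes the ids from a SelectedRows' rows(); it reads them straight out of the input LoDTensor and buckets every id by id % splited_var_nums, with the per-endpoint std::unordered_set dropping duplicates. A minimal Python sketch of just that bucketing step, using made-up values (splited_var_nums and the id list are assumptions):

    splited_var_nums = 4          # number of pserver shards for this table (assumed)
    ids = [3, 7, 8, 3, 12]        # ids read from the LoDTensor (assumed)
    buckets = [set() for _ in range(splited_var_nums)]
    for i in ids:
        # ep_idx picks the shard; the set keeps each id at most once per shard
        buckets[i % splited_var_nums].add(i)
    # buckets == [{8, 12}, set(), set(), {3, 7}]

The per-table map of these sets is what gets pushed into need_push_queue_ as one SparseIdsMap entry.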
@@ -532,15 +544,15 @@ void GeoCommunicator::MainThread() {
     if (ids_send_vec_.size() >= static_cast<size_t>(max_merge_var_num_)) {
       auto before_send_global_step = GetCurrentUS();
-      VLOG(0) << "finish ins_send_vec using time "
+      VLOG(1) << "finish ins_send_vec using time "
               << (before_send_global_step - before_send_by_communicator)
               << "; send_var_nums_ = " << send_var_nums_;
       SendGlobalStep(max_merge_var_num_);
       auto after_send_global_step = GetCurrentUS();
-      VLOG(0) << "finish send global_step using "
+      VLOG(1) << "finish send global_step using "
               << (after_send_global_step - before_send_global_step);
       for (auto &iter : send_varname_to_ctx_) {
-        VLOG(1) << "debug " << iter.first;
+        VLOG(2) << "debug " << iter.first;
         auto &var_name = iter.first;
         auto &send_ctx = iter.second;
         int pserver_num = static_cast<int>(send_ctx.epmap.size());
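Note: the MainThread hunks around here show the merge trigger: SparseIdsMap batches accumulate in ids_send_vec_, and only once it holds max_merge_var_num_ of them does the communicator send the global step and then walk every sparse var and pserver shard. A toy sketch of that trigger, with assumed names and threshold:

    max_merge_var_num = 20        # assumed value of max_merge_var_num_
    ids_send_vec = []

    def on_ids_batch(ids_map, flush):
        ids_send_vec.append(ids_map)
        if len(ids_send_vec) >= max_merge_var_num:
            flush(ids_send_vec)   # SendGlobalStep + per-shard send/recv in the real code
            ids_send_vec.clear()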
@@ -556,11 +568,11 @@ void GeoCommunicator::MainThread() {
           if (var_name == STEP_COUNTER) {
             return;
           }
-          SendSparse(var_name, ep_idx);
+          // SendSparse(var_name, ep_idx, ids_send_vec_);
           auto after_send_sparse = GetCurrentUS();
           RecvSparse(var_name, ep_idx);
           auto after_recv_sparse = GetCurrentUS();
-          VLOG(0)
+          VLOG(1)
               << "send recv "
               << send_varname_to_ctx_.at(var_name).splited_varnames[ep_idx]
               << " finish, using "
@@ -596,57 +608,60 @@ void GeoCommunicator::MainThread() {
       ids_send_vec_.clear();
       auto finish_one_comm = GetCurrentUS();
-      VLOG(0) << "Finish SendByCommunicator "
+      VLOG(1) << "Finish SendByCommunicator "
              << (finish_one_comm - after_send_global_step);
     }
   }
 }

-void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
-  std::vector<int64_t> ids;
+void GeoCommunicator::SendSparse(
+    const std::string &varname, int ep_idx,
+    const std::vector<SparseIdsMap> &ids_send_vec) {
+  std::unordered_set<int64_t> ids_set;
+  auto debug1 = GetCurrentUS();
   auto &rpc_ctx = send_varname_to_ctx_.at(varname);
-  VLOG(1) << rpc_ctx.print();
+  VLOG(2) << rpc_ctx.print();
   auto send_varname = rpc_ctx.splited_varnames[ep_idx];
   auto trainer_id = rpc_ctx.trainer_id;
   auto endpoint = rpc_ctx.epmap[ep_idx];
   auto pserver_num = rpc_ctx.epmap.size();

-  for (auto ids_map : ids_send_vec_) {
-    std::copy(ids_map[varname][ep_idx].begin(), ids_map[varname][ep_idx].end(),
-              back_inserter(ids));
+  int64_t vector_size = 0;
+  for (auto ids_map : ids_send_vec) {
+    for (auto id : ids_map[varname][ep_idx]) {
+      ids_set.insert(id);
+      vector_size += 1;
+      if (vector_size > 10) {
+        break;
+      }
+    }
+    if (vector_size > 10) {
+      break;
+    }
   }
-  VLOG(1) << "ids_vector_size: " << ids.size();
-  auto size = ids.size();
-  std::set<int64_t> st(ids.begin(), ids.end());
-  ids.assign(st.begin(), st.end());
+  auto debug2 = GetCurrentUS();
+  VLOG(1) << "vector_size: " << vector_size
+          << "; ids_set_size: " << ids_set.size() << "; using time "
+          << (debug2 - debug1);

-  std::stringstream list_str;
-  for (uint64_t i = 0; i < ids.size(); i++) {
-    list_str << ids[i] << ",";
-  }
-  VLOG(1) << "SendSparse receive var: " << send_varname << " unset: " << size
-          << " set: " << ids.size() << ": " << list_str.str();
+  auto size = ids_set.size();

-  if (ids.empty()) {
+  if (size == 0) {
     LOG(WARNING) << "WARNING: GEO has nothing to send, return directly ";
     return;
   }

-  std::vector<size_t> outs_rows_idx;
-  if (!rpc_ctx.is_distributed) {
-    for (size_t i = 0; i < ids.size(); ++i) {
-      auto id = ids[i] / pserver_num;
-      outs_rows_idx.push_back(id);
-    }
-  } else {
-    for (size_t i = 0; i < ids.size(); ++i) {
-      outs_rows_idx.push_back(ids[i]);
-    }
-  }
+  std::vector<int64_t> new_rows;
+  new_rows.insert(new_rows.begin(), ids_set.begin(), ids_set.end());
+  // std::stringstream list_str;
+  // for (uint64_t i = 0; i < ids.size(); i++) {
+  //   list_str << ids[i] << ",";
+  // }
+  auto debug3 = GetCurrentUS();
+  VLOG(1) << "SendSparse receive var: " << send_varname
+          << " set: " << ids_set.size() << ", using time " << (debug3 - debug1);

   auto *var_latest = recv_scope_->FindVar(varname);
@@ -661,30 +676,35 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
   auto *var_delta = delta_scope_->Var(send_varname);
   auto *t_delta = var_delta->GetMutable<framework::SelectedRows>();
-  t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
-  t_delta->mutable_rows()->assign(outs_rows_idx.begin(), outs_rows_idx.end());

   auto *t_value = t_delta->mutable_value();
   t_value->mutable_data<float>(
-      framework::make_ddim({static_cast<int64_t>(ids.size()), dims1}),
+      framework::make_ddim({static_cast<int64_t>(new_rows.size()), dims1}),
       cpu_ctx.GetPlace());

   std::vector<std::vector<std::vector<float> *>> values;
   auto *ins = distributed::LargeScaleKV::GetInstance();
-  ins->Get(varname)->Get(ids, {"Param"}, &values);
+  ins->Get(varname)->Get(new_rows, {"Param"}, &values);

   auto blas = math::GetBlas<platform::CPUDeviceContext, float>(cpu_ctx);
   float coefficient = 1.0 / static_cast<float>(trainers_);

-  for (auto j = 0; j < static_cast<int>(ids.size()); ++j) {
-    blas.VSUB(dims1, t_latest.data<float>() + ids[j] * dims1,
+  for (auto j = 0; j < static_cast<int>(new_rows.size()); ++j) {
+    blas.VSUB(dims1, t_latest.data<float>() + new_rows[j] * dims1,
              values[j][0]->data(), t_value->data<float>() + j * dims1);
     blas.SCAL(dims1, coefficient, t_value->data<float>() + j * dims1);
     blas.VADD(dims1, values[j][0]->data(), t_value->data<float>() + j * dims1,
              values[j][0]->data());
   }

-  VLOG(1) << "begin to real send " << send_varname;
+  std::vector<int64_t> send_rows;
+  send_rows.reserve(new_rows.size());
+  for (auto idx : new_rows) {
+    send_rows.push_back(idx / pserver_num);
+  }
+  t_delta->set_height(rpc_ctx.height_sections[ep_idx]);
+  t_delta->set_rows(send_rows);
+
+  VLOG(2) << "begin to real send " << send_varname;
   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
   distributed::RPCClient *rpc_client =
@@ -692,9 +712,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
   auto ret = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
                                       *delta_scope_.get(), send_varname);
-  VLOG(1) << "need to wait for send " << send_varname;
+  VLOG(2) << "need to wait for send " << send_varname;
   ret->Wait();
-  VLOG(1) << "finish to send " << send_varname;
+  VLOG(2) << "finish to send " << send_varname;
 }

 void GeoCommunicator::SendDense(const std::string &varname) {
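Note: the arithmetic in SendSparse above is the GEO delta rule: for each selected row, the value pushed to the pserver is (latest - old) / trainers, the locally kept snapshot ("Param" in LargeScaleKV) is advanced by that same delta, and the row index actually sent is the shard-local one, id / pserver_num. A rough numpy sketch of that per-row update; the array values and sizes are invented for illustration:

    import numpy as np

    trainers = 2
    pserver_num = 2
    latest = np.array([1.0, 2.0, 3.0])   # current trainer-side row of the parameter (assumed)
    old = np.array([0.4, 1.0, 2.0])      # snapshot stored in LargeScaleKV (assumed)

    delta = (latest - old) / trainers    # what blas.VSUB followed by blas.SCAL computes
    old += delta                         # what blas.VADD writes back into the snapshot

    global_id = 7
    send_row = global_id // pserver_num  # shard-local row index sent with the delta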
@@ -740,7 +760,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
       recv_varname_to_ctx_.at(varname).splited_varnames[ep_idx];
   auto pserver_num = recv_varname_to_ctx_.at(varname).epmap.size();

-  VLOG(1) << "Begin to RecvSparse receive var: " << splited_var_name;
+  VLOG(2) << "Begin to RecvSparse receive var: " << splited_var_name;

   platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
   auto &cpu_ctx_recv = *pool.Get(platform::CPUPlace());
@@ -753,7 +773,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
                                splited_var_name, splited_var_name);
   handle->Wait();

-  VLOG(1) << "Finish to RecvSparse receive var: " << splited_var_name;
+  VLOG(2) << "Finish to RecvSparse receive var: " << splited_var_name;

   auto *var_latest = recv_scope_->FindVar(varname);
@@ -766,7 +786,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
   ids.assign(var_psrever->Get<framework::SelectedRows>().rows().begin(),
              var_psrever->Get<framework::SelectedRows>().rows().end());

-  VLOG(1) << "RecvSparse receive var: " << splited_var_name
+  VLOG(2) << "RecvSparse receive var: " << splited_var_name
           << " ids Size: " << ids.size();

   auto t_psrever = var_psrever->Get<framework::SelectedRows>().value();
@@ -796,7 +816,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
     blas.VCOPY(dims1, t_psrever.data<float>() + j * dims1,
               old_values[j][0]->data());
   }

-  VLOG(1) << "receive finish";
+  VLOG(2) << "receive finish";
 }

 void GeoCommunicator::RecvDense(const std::string &varname) {
...

@@ -426,7 +426,8 @@ class GeoCommunicator : public AsyncCommunicator {
   // void SendByCommunicator(int batches) override;

-  void SendSparse(const std::string &varname, int ep_idx);
+  void SendSparse(const std::string &varname, int ep_idx,
+                  const std::vector<SparseIdsMap> &ids_send_vec);

   void SendDense(const std::string &varname);

...
@@ -169,16 +169,24 @@ def append_send_ops_pass(program, config, merge=False):
     trainer_id = config.get_role_id()
     pserver_endpoints = config.get_ps_endpoints()

-    def _append_send_op(union_vars, queue):
-        send_input_vars = []
-        assert (len(queue) == len(union_vars))
-        for i in range(len(queue)):
-            if queue[i] == STEP_COUNTER:
-                send_input_vars.append("")
-            else:
-                send_input_vars.append(program.global_block().vars[union_vars[
-                    i]])
+    def _append_send_op():
+        sparse_var = []
+        sparse_tables = []
+        unique_sparse_var = {}
+        for op in program.global_block().ops:
+            if "is_sparse" in op.all_attrs():
+                if op.type == "lookup_table":
+                    op._set_attr('remote_prefetch', False)
+                for input_var_name, sparse_var_name in zip(
+                        op.input("Ids"), op.input("W")):
+                    if input_var_name in unique_sparse_var:
+                        if unique_sparse_var[input_var_name] == sparse_var_name:
+                            continue
+                    input_var = program.global_block().var(input_var_name)
+                    sparse_var.append(input_var)
+                    sparse_tables.append(sparse_var_name)
+                    unique_sparse_var[input_var_name] = sparse_var_name

         dummy_output = []
         if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
@@ -187,10 +195,10 @@ def append_send_ops_pass(program, config, merge=False):
         program.global_block().append_op(
             type="send",
-            inputs={"X": send_input_vars},
+            inputs={"X": sparse_var},
             outputs={"Out": dummy_output},
             attrs={
-                "send_varnames": queue,
+                "send_varnames": sparse_tables,
                 "merge_add": True,
                 "use_send_handler": False,
                 "endpoints": pserver_endpoints,
@@ -216,17 +224,10 @@ def append_send_ops_pass(program, config, merge=False):
     sends = config.get_trainer_send_context()

     if merge:
-        origin_varnames = []
-        merged_names = []
-        for merged_name, send in sends.items():
-            for var in send.origin_varnames():
-                origin_varnames.append(var)
-            merged_names.append(merged_name)
-        if len(origin_varnames) > 0:
-            dummys.append(_append_send_op(origin_varnames, merged_names))
+        dummys.append(_append_send_op())
     else:
         for merged_name, send in sends.items():
-            dummys.append(_append_send_op(send.origin_varnames(), merged_name))
+            dummys.append(_append_send_op())

     if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
         _append_barrier_op(dummys)

...
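Note: the rewritten _append_send_op no longer takes the send queues as arguments; it builds the send op's inputs by scanning the block for sparse lookup ops and recording each (Ids input, W table) pair once, via the unique_sparse_var dict. A small sketch of just that de-duplication, over a hypothetical list of (ids_name, table_name) pairs standing in for the op inputs:

    pairs = [("ids_a", "emb_a"), ("ids_b", "emb_b"), ("ids_a", "emb_a")]  # assumed op inputs
    sparse_var, sparse_tables, seen = [], [], {}
    for ids_name, table_name in pairs:
        if seen.get(ids_name) == table_name:
            continue                      # this Ids/W pair was already recorded
        sparse_var.append(ids_name)       # in the pass this is the actual Variable object
        sparse_tables.append(table_name)
        seen[ids_name] = table_name
    # sparse_var == ["ids_a", "ids_b"]; sparse_tables == ["emb_a", "emb_b"]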