提交 fb6cc3a1 编写于 作者: Q Qiao Longfei

follow commnet, optimize code and add comment test=develop

上级 adf272bc
......@@ -132,8 +132,11 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// in async_mode, we do not need remote prefetch, because communicator
// will do async parameter recv.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
}
return false;
}
......
......@@ -52,6 +52,10 @@ class Scope {
/// Mark it to const because that new kid scope cannot change parent scope.
Scope& NewScope() const;
/// Create a sub-scope for current scope but do not record it in the kids to
/// avoid performance problems.
/// Note!!! You should delete the result pointer yourself to avoid memory
/// leak!
Scope* NewTmpScope() const;
/// Create a variable with given name if it doesn't exist.
......
......@@ -81,8 +81,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows();
std::vector<std::vector<int>> outs_rows_idx;
std::vector<std::vector<int>> outs_dense_idx;
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
outs_rows_idx.resize(out_num);
outs_dense_idx.resize(out_num);
......@@ -99,7 +99,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
int out_idx = GetSectionIndex(send_rows[i], abs_sections);
size_t out_idx = GetSectionIndex(send_rows[i], abs_sections);
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
......@@ -160,10 +160,9 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
}
}
// note!! only support sync send now
if (true || sync) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
if (sync) {
for (auto &handle : rets) {
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
}
}
......
......@@ -52,7 +52,7 @@ class SendOp : public framework::OperatorBase {
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections);
send_functor(rpc_ctx, scope, static_cast<bool>(sync_send));
send_functor(rpc_ctx, scope, true);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册