提交 2ee51619 编写于 作者: S seiriosPlus

optimize init from pserver

上级 7e70802b
...@@ -449,7 +449,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx, ...@@ -449,7 +449,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
old_scope_.reset(new Scope()); old_scope_.reset(new Scope());
pserver_scope_.reset(new Scope()); pserver_scope_.reset(new Scope());
Init(); InitParams();
} }
void GeoCommunicator::Send(const std::vector<std::string> &var_names, void GeoCommunicator::Send(const std::vector<std::string> &var_names,
...@@ -822,7 +822,7 @@ void GeoCommunicator::InitSparse() { ...@@ -822,7 +822,7 @@ void GeoCommunicator::InitSparse() {
paddle::platform::CPUDeviceContext()); paddle::platform::CPUDeviceContext());
for (auto &id : ids) { for (auto &id : ids) {
blas.VCOPY(dim1, global_value.data<float>() + k * width, blas.VCOPY(dim1, global_value.data<float>() + id * dim1,
values[id][0]->data()); values[id][0]->data());
} }
} }
......
...@@ -60,7 +60,7 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx, ...@@ -60,7 +60,7 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx,
// sparse param in recv_scope is LoDTensor // sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVarNoBarrier( rets.push_back(rpc_client->AsyncGetVarNoBarrier(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name, rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name)); recv_var_name));
const auto *value = local_var->Get<framework::LoDTensor>().data<float>(); const auto *value = local_var->Get<framework::LoDTensor>().data<float>();
tensors.push_back(value); tensors.push_back(value);
...@@ -88,7 +88,7 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx, ...@@ -88,7 +88,7 @@ void RecvSparseLodTensor(const CommContext &rpc_ctx,
"recved var must has same dims with local var"); "recved var must has same dims with local var");
auto *merged_t = merged_var->GetMutable<framework::LoDTensor>(); auto *merged_t = merged_var->GetMutable<framework::LoDTensor>();
auto *merged_d = merged_t->mutable_data<float>(place)(); auto *merged_d = merged_t->mutable_data<float>(cpu_place);
auto pserver_num = rpc_ctx.splited_varnames.size(); auto pserver_num = rpc_ctx.splited_varnames.size();
for (int x = 0; x < height; ++x) { for (int x = 0; x < height; ++x) {
...@@ -220,9 +220,9 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx, ...@@ -220,9 +220,9 @@ void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
if (rpc_ctx.is_sparse) { if (rpc_ctx.is_sparse) {
if (geo_records) { if (geo_records) {
RecvGeoSparseRecords()<T>(rpc_ctx, scope); RecvGeoSparseRecords<T>(rpc_ctx, scope);
} else { } else {
RecvSparseLodTensor()<T>(rpc_ctx, scope); RecvSparseLodTensor<T>(rpc_ctx, scope);
} }
} else { } else {
RecvLodTensor<T>(rpc_ctx, scope); RecvLodTensor<T>(rpc_ctx, scope);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册