Commit e8fe5186 authored by Qiao Longfei

complete parameter_recv

Parent d5c78982
@@ -52,16 +52,12 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
   auto *recv_var = scope.FindVar(rpc_ctx.var_name);
 
-  std::vector<framework::Tensor *> recved_tensors;
-
   // recv all vars to local scope
   if (recv_var->IsType<framework::LoDTensor>()) {
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
       auto &recv_var_name = rpc_ctx.splited_var_names[i];
-      framework::Tensor *t =
-          local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
-      recved_tensors.push_back(t);
+      local_scope->Var(recv_var_name);
       VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
       rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                              *local_scope, recv_var_name,
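This first hunk stops pre-typing the per-slice variables: instead of calling GetMutable<framework::LoDTensor>() up front, it now only creates an empty variable via local_scope->Var(recv_var_name), leaving the concrete type open until the RPC reply is deserialized, which is what lets the second hunk receive either a dense LoDTensor or a sparse SelectedRows. A minimal standalone sketch of this lazy-typing idea, assuming the usual holder pattern; this is an illustration, not Paddle's actual Variable implementation:

#include <cassert>
#include <memory>
#include <typeindex>

// A Variable fixes its payload type on the first GetMutable<T>(); until then
// it is "typeless", so whoever fills it decides what it becomes.
class Variable {
 public:
  template <typename T>
  T *GetMutable() {
    if (holder_ == nullptr) {
      holder_ = std::make_unique<Holder<T>>();
      type_ = std::type_index(typeid(T));
    }
    assert(type_ == std::type_index(typeid(T)));  // must match the first use
    return &static_cast<Holder<T> *>(holder_.get())->obj;
  }

  template <typename T>
  bool IsType() const {
    return holder_ != nullptr && type_ == std::type_index(typeid(T));
  }

 private:
  struct HolderBase {
    virtual ~HolderBase() = default;
  };
  template <typename T>
  struct Holder : HolderBase {
    T obj{};
  };
  std::unique_ptr<HolderBase> holder_;
  std::type_index type_ = std::type_index(typeid(void));
};

int main() {
  Variable v;                    // like local_scope->Var(name): type still open
  assert(!v.IsType<float>());
  *v.GetMutable<float>() = 1.f;  // first access fixes the concrete type
  assert(v.IsType<float>());
}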
@@ -81,14 +77,34 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
         recv_var->GetMutable<framework::LoDTensor>();
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     int64_t recv_numel = 0;
-    for (auto *in : recved_tensors) {
-      recv_numel += in->numel();
-      auto in_stride = framework::stride_numel(in->dims());
-      auto out_stride = framework::stride_numel(recv_tensor->dims());
-      StridedNumelCopyWithAxis<T>(
-          dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
-          in->data<T>(), in_stride, in_stride[0]);
-      output_offset += in_stride[0];
+    for (auto &recv_var_name : rpc_ctx.splited_var_names) {
+      auto *recv_var = local_scope->FindVar(recv_var_name);
+      if (recv_var->IsType<framework::LoDTensor>()) {
+        auto &in = recv_var->Get<framework::LoDTensor>();
+        recv_numel += in.numel();
+        auto in_stride = framework::stride_numel(in.dims());
+        auto out_stride = framework::stride_numel(recv_tensor->dims());
+        StridedNumelCopyWithAxis<T>(
+            dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
+            in.data<T>(), in_stride, in_stride[0]);
+        output_offset += in_stride[0];
+      } else if (recv_var->IsType<framework::SelectedRows>()) {
+        auto &recv_slr = recv_var->Get<framework::SelectedRows>();
+        auto &recv_dims = recv_tensor->dims();
+        int64_t width = recv_dims[1];
+        PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]);
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
+        VLOG(3) << "recv slr " << recv_var_name << " dims "
+                << recv_slr.value().dims();
+        for (auto i = 0; i < recv_slr.rows().size(); ++i) {
+          auto row_id = recv_slr.rows()[i];
+          memcpy(recv_tensor->data<T>() + row_id * width,
+                 recv_slr.value().data<T>() + i * width, sizeof(T) * width);
+        }
+      } else {
+        PADDLE_THROW("unsupported recieved var type");
+      }
     }
     PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
   }
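The rewritten merge loop handles two wire formats when reassembling the full parameter: dense slices are concatenated along axis 0 at a running output_offset, while SelectedRows slices carry explicit row indices and are scattered row-by-row into their global positions. A self-contained sketch of the two copy patterns using plain vectors, with hypothetical helpers that mirror but do not use the Paddle API; width plays the role of recv_dims[1]:

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Dense slice: a contiguous block of rows appended at the running offset,
// like the StridedNumelCopyWithAxis call along axis 0 in the diff above.
void AppendDense(const std::vector<float> &in, std::vector<float> *out,
                 size_t *output_offset) {
  assert(*output_offset + in.size() <= out->size());
  std::memcpy(out->data() + *output_offset, in.data(),
              in.size() * sizeof(float));
  *output_offset += in.size();  // advance by the slice's element count
}

// Sparse slice: value row i lands at global row rows[i], like the
// SelectedRows memcpy loop in the diff above.
void ScatterRows(const std::vector<int64_t> &rows,
                 const std::vector<float> &value, int64_t width,
                 std::vector<float> *out) {
  assert(value.size() == rows.size() * static_cast<size_t>(width));
  for (size_t i = 0; i < rows.size(); ++i) {
    std::memcpy(out->data() + rows[i] * width, value.data() + i * width,
                sizeof(float) * width);
  }
}

int main() {
  const int64_t width = 2;              // recv_dims[1]
  std::vector<float> param(6 * width);  // the full recv_tensor, height 6
  size_t output_offset = 0;
  AppendDense({1, 1, 2, 2}, &param, &output_offset);  // fills rows 0-1
  ScatterRows({4, 5}, {9, 9, 8, 8}, width, &param);   // fills rows 4 and 5
  assert(param[4 * width] == 9.f && param[5 * width] == 8.f);
}

The row-wise memcpy assumes the output tensor is 2-D with row width recv_dims[1]; the three PADDLE_ENFORCE_EQ checks in the diff guard exactly that shape agreement before any copy happens.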