diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index ae6516b2464326eaa6e6731ae446090c8ab36042..f40f25c7573db10b195dcd6c9bbf1b6f30a3851b 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -52,16 +52,12 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
 
   auto *recv_var = scope.FindVar(rpc_ctx.var_name);
 
-  std::vector<framework::Tensor *> recved_tensors;
-
   // recv all vars to local scope
   if (recv_var->IsType<framework::LoDTensor>()) {
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
       auto &recv_var_name = rpc_ctx.splited_var_names[i];
-      framework::Tensor *t =
-          local_scope->Var(recv_var_name)->GetMutable<framework::LoDTensor>();
-      recved_tensors.push_back(t);
+      local_scope->Var(recv_var_name);
       VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
       rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                              *local_scope, recv_var_name,
@@ -81,14 +77,34 @@
         recv_var->GetMutable<framework::LoDTensor>();
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     int64_t recv_numel = 0;
-    for (auto *in : recved_tensors) {
-      recv_numel += in->numel();
-      auto in_stride = framework::stride_numel(in->dims());
-      auto out_stride = framework::stride_numel(recv_tensor->dims());
-      StridedNumelCopyWithAxis<T>(
-          dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
-          in->data<T>(), in_stride, in_stride[0]);
-      output_offset += in_stride[0];
+    for (auto &recv_var_name : rpc_ctx.splited_var_names) {
+      auto *recv_var = local_scope->FindVar(recv_var_name);
+      if (recv_var->IsType<framework::LoDTensor>()) {
+        auto &in = recv_var->Get<framework::LoDTensor>();
+        recv_numel += in.numel();
+        auto in_stride = framework::stride_numel(in.dims());
+        auto out_stride = framework::stride_numel(recv_tensor->dims());
+        StridedNumelCopyWithAxis<T>(
+            dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
+            in.data<T>(), in_stride, in_stride[0]);
+        output_offset += in_stride[0];
+      } else if (recv_var->IsType<framework::SelectedRows>()) {
+        auto &recv_slr = recv_var->Get<framework::SelectedRows>();
+        auto &recv_dims = recv_tensor->dims();
+        int64_t width = recv_dims[1];
+        PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]);
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
+        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
+        VLOG(3) << "recv slr " << recv_var_name << " dims "
+                << recv_slr.value().dims();
+        for (size_t i = 0; i < recv_slr.rows().size(); ++i) {
+          auto row_id = recv_slr.rows()[i];
+          memcpy(recv_tensor->data<T>() + row_id * width,
+                 recv_slr.value().data<T>() + i * width, sizeof(T) * width);
+        }
+      } else {
+        PADDLE_THROW("unsupported received var type");
+      }
     }
     PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
   }
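
The `LoDTensor` branch of the new merge loop keeps the old dense behavior: each recved split is appended into the destination tensor along axis 0, with `output_offset` advanced by `in_stride[0]` (the split's total element count). Below is a minimal standalone sketch of that concat-by-offset pattern, using plain `std::vector<float>` buffers in place of Paddle's `LoDTensor` and `StridedNumelCopyWithAxis<T>`; the `DenseSplit` struct and `ConcatAlongRows` name are hypothetical illustrations, not Paddle APIs:

```cpp
#include <cstring>
#include <iostream>
#include <vector>

// Hypothetical stand-in for one recved split of a dense parameter:
// a flat row-major buffer plus its shape.
struct DenseSplit {
  int64_t rows;
  int64_t cols;
  std::vector<float> data;  // rows * cols elements, row-major
};

// Concatenate splits along axis 0 into one flat output buffer, advancing
// an output offset by each split's element count, as the diff does with
// output_offset += in_stride[0].
std::vector<float> ConcatAlongRows(const std::vector<DenseSplit> &splits,
                                   int64_t total_rows, int64_t cols) {
  std::vector<float> out(total_rows * cols);
  int64_t output_offset = 0;
  for (const auto &in : splits) {
    const int64_t numel = in.rows * in.cols;  // == in_stride[0] for axis 0
    std::memcpy(out.data() + output_offset, in.data.data(),
                sizeof(float) * numel);
    output_offset += numel;
  }
  return out;
}

int main() {
  std::vector<DenseSplit> splits = {{1, 2, {1, 2}}, {2, 2, {3, 4, 5, 6}}};
  auto merged = ConcatAlongRows(splits, 3, 2);
  for (float v : merged) std::cout << v << ' ';  // 1 2 3 4 5 6
  std::cout << '\n';
}
```

Because the copy axis is 0 and the buffers are row-major, the strided copy degenerates to one contiguous copy per split, which is what the sketch exploits.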
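The new `SelectedRows` branch is the substantive change: a sparse split carries only the row indices it actually contains plus a dense value block of shape `[rows.size(), width]`, so it is scattered row-by-row into the dense destination rather than appended. A standalone sketch of that row-scatter, with a hypothetical `SparseRows` stand-in for `framework::SelectedRows` (names and `float` element type are illustrative):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::SelectedRows: the indices of the
// rows that are present, plus their values packed densely.
struct SparseRows {
  int64_t height;             // logical row count of the full tensor
  std::vector<int64_t> rows;  // which rows of the full tensor are present
  std::vector<float> value;   // rows.size() * width elements, row-major
};

// Scatter each present row into a dense [height, width] buffer, mirroring
// the memcpy loop in the diff's SelectedRows branch.
void ScatterRows(const SparseRows &slr, int64_t width, float *dense) {
  for (size_t i = 0; i < slr.rows.size(); ++i) {
    const int64_t row_id = slr.rows[i];
    std::memcpy(dense + row_id * width, slr.value.data() + i * width,
                sizeof(float) * width);
  }
}

int main() {
  SparseRows slr{4, {0, 2}, {1, 1, 3, 3}};  // rows 0 and 2 of a 4x2 tensor
  std::vector<float> dense(4 * 2, 0.f);
  ScatterRows(slr, /*width=*/2, dense.data());
  for (float v : dense) std::cout << v << ' ';  // 1 1 0 0 3 3 0 0
  std::cout << '\n';
}
```

This is why the branch checks `recv_slr.height()` against `recv_dims[0]` and the value block's shape against `width` and `rows().size()` before copying: every scattered row must land inside the destination tensor.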