diff --git a/paddle/fluid/operators/distributed/parameter_send.cc b/paddle/fluid/operators/distributed/parameter_send.cc
index 01e7341f15f9868b9a049dc1418f557b47ac95f5..d79ea8cdb98718543d20ddaa03baa6cfc482ba30 100644
--- a/paddle/fluid/operators/distributed/parameter_send.cc
+++ b/paddle/fluid/operators/distributed/parameter_send.cc
@@ -47,6 +47,15 @@ static size_t GetSectionIndex(int64_t id,
   return abs_sections.size() - 1;
 }
 
+static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
+  for (size_t i = 1; i < abs_sections.size(); ++i) {
+    if (row < abs_sections[i]) {
+      return i - 1;
+    }
+  }
+  return abs_sections.size() - 1;
+}
+
 static std::vector<int64_t> ToAbsoluteSection(
     const std::vector<int64_t>& height_sections) {
   std::vector<int64_t> abs_sections;
@@ -97,21 +106,22 @@ static void SplitIdsIntoMultipleVarsBySection(
   }
 }
 
+template <typename T>
 void send(const std::string& var_name,
           const std::vector<std::string>& send_varnames,
           const std::vector<std::string>& epmap,
           const std::vector<int64_t>& height_sections,
-          const framework::ExecutionContext& context,
-          const framework::Scope& scope, bool sync) {
+          const framework::ExecutionContext& ctx, const framework::Scope& scope,
+          bool sync) {
   framework::Scope* local_scope = scope.NewTmpScope();
 
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& cpu_ctx = *pool.Get(platform::CPUPlace());
-  auto& actual_ctx = *pool.Get(context.GetPlace());
+  auto& actual_ctx = *pool.Get(ctx.GetPlace());
 
   distributed::RPCClient* rpc_client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(
-          context.Attr<int>("trainer_id"));
+          ctx.Attr<int>("trainer_id"));
 
   auto* send_var = scope.FindVar(var_name);
   size_t out_num = send_varnames.size();
@@ -122,7 +132,7 @@ void send(const std::string& var_name,
     outs_dims.reserve(out_num);
 
     // infer output shape
-    int num = context.Attr<int>("num");
+    int num = ctx.Attr<int>("num");
     if (num > 0) {
       int64_t in_axis_dim = send_tensor_dims[0];
       PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
@@ -153,13 +163,71 @@ void send(const std::string& var_name,
       *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
       row_offset += outs_dims[i][0];
     }
-  } else if (send_var->IsType<framework::SelectedRows>()) {
+  } else if (send_var->IsType<framework::SelectedRows>()) {
+    auto& send_slr = send_var->Get<framework::SelectedRows>();
+    auto abs_sections = ToAbsoluteSection(height_sections);
+
+    auto send_rows = send_slr.rows();
+    std::vector<std::vector<int>> outs_rows_idx;
+    std::vector<std::vector<int>> outs_dense_idx;
+
+    outs_rows_idx.resize(out_num);
+    outs_dense_idx.resize(out_num);
+
+    auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
+    auto src = send_slr.value().data<T>();
+
     // create output var in local scope
+    std::vector<framework::SelectedRows*> outs;
     for (auto& name : send_varnames) {
-      local_scope->Var(name)->GetMutable<framework::SelectedRows>();
+      auto* out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
+      outs.push_back(out);
+    }
+
+    // split rows index into output sparse vars
+    for (size_t i = 0; i < send_rows.size(); ++i) {
+      int out_idx = FindOutIdx(send_rows[i], abs_sections);
+      outs_rows_idx[out_idx].push_back(send_rows[i]);
+      outs_dense_idx[out_idx].push_back(i);
     }
+    auto place = ctx.GetPlace();
+
+    for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
+      auto rows_idx = outs_rows_idx[i];
+      outs[i]->set_height(height_sections[i]);
+      auto dims = send_slr.GetCompleteDims();
+      dims[0] = rows_idx.size();
+      outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
+      outs[i]->mutable_rows()->clear();
+      if (rows_idx.size() > 0) {
+        for (auto idx : rows_idx) {
+          outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
+        }
+        auto dst = outs[i]->mutable_value()->mutable_data<T>(ctx.GetPlace());
+        for (size_t j = 0; j < rows_idx.size(); j++) {
+          if (platform::is_cpu_place(place)) {
+            memory::Copy(
+                platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
+                src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
+          } else {
+#ifdef PADDLE_WITH_CUDA
+            auto stream = ctx.cuda_device_context().stream();
+            memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
+                         platform::CUDAPlace(),
+                         src + outs_dense_idx[i][j] * row_numel,
+                         sizeof(T) * row_numel, stream);
+#else
+            PADDLE_THROW("Paddle is not compiled with GPU");
+#endif
+          }
+        }
+      }
+      PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
+                        "rows should has the same size with tensor dim 0");
+    }
+
   } else {
-    PADDLE_THROW("unsupported var type");
+    PADDLE_THROW("unsupported var type to send!");
   }
 
   std::vector<distributed::VarHandlePtr> rets;
diff --git a/paddle/fluid/operators/distributed/parameter_send.h b/paddle/fluid/operators/distributed/parameter_send.h
index ee4da997b73c9c1bc734c43bb9fe5fa6fdc12624..e337649cf23805841dc0c174940e1dd22bff06f9 100644
--- a/paddle/fluid/operators/distributed/parameter_send.h
+++ b/paddle/fluid/operators/distributed/parameter_send.h
@@ -23,6 +23,7 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
+template <typename T>
 void send(const std::string& var_name,
           const std::vector<std::string>& send_varnames,
           const std::vector<std::string>& epmap,
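Note on the template change above: send() is now declared as a template in parameter_send.h while its definition stays in parameter_send.cc, so other translation units can only link against element types that parameter_send.cc explicitly instantiates. The diff shown here is truncated and the full patch may already add such instantiations further down; if not, a minimal sketch of what one would look like, with float chosen purely as an assumed example, mirroring the declared signature:

// Assumed explicit instantiation at the bottom of parameter_send.cc;
// without this (or moving the definition into the header), a call to
// send<float> from another translation unit would fail to link.
template void send<float>(const std::string& var_name,
                          const std::vector<std::string>& send_varnames,
                          const std::vector<std::string>& epmap,
                          const std::vector<int64_t>& height_sections,
                          const framework::ExecutionContext& ctx,
                          const framework::Scope& scope, bool sync);

Each additional element type used by callers (double, etc.) would need its own instantiation line.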
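Separately, the core of the new SelectedRows path is the row-to-shard arithmetic: ToAbsoluteSection() turns per-shard heights into cumulative start offsets, FindOutIdx() picks the shard whose row range contains a given global row, and idx - abs_sections[i] converts a global row into a shard-local row. A self-contained sketch of that arithmetic using only standard-library types; the body of ToAbsoluteSection is reconstructed here from how FindOutIdx and the offset subtraction consume it, since the diff truncates the original:

#include <cassert>
#include <cstdint>
#include <vector>

// Per-shard heights {4, 6, 5} -> cumulative shard start rows {0, 4, 10}.
static std::vector<int64_t> ToAbsoluteSection(
    const std::vector<int64_t>& height_sections) {
  std::vector<int64_t> abs_sections(height_sections.size(), 0);
  for (size_t i = 1; i < height_sections.size(); ++i) {
    abs_sections[i] = abs_sections[i - 1] + height_sections[i - 1];
  }
  return abs_sections;
}

// Shard owning `row`: the last shard whose start offset is <= row.
static int FindOutIdx(int row, const std::vector<int64_t>& abs_sections) {
  for (size_t i = 1; i < abs_sections.size(); ++i) {
    if (row < abs_sections[i]) {
      return i - 1;
    }
  }
  return static_cast<int>(abs_sections.size()) - 1;
}

int main() {
  auto abs_sections = ToAbsoluteSection({4, 6, 5});
  assert(FindOutIdx(3, abs_sections) == 0);   // rows 0..3   -> shard 0
  assert(FindOutIdx(4, abs_sections) == 1);   // rows 4..9   -> shard 1
  assert(FindOutIdx(12, abs_sections) == 2);  // rows 10..14 -> shard 2
  // Shard-local row index, as computed by idx - abs_sections[i] in the patch:
  assert(12 - abs_sections[2] == 2);
  return 0;
}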