parameter_send.cc 8.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/distributed/parameter_send.h"
Q
Qiao Longfei 已提交
16
#include <memory>
Q
Qiao Longfei 已提交
17 18
#include <set>
#include <string>
19
#include <utility>
Q
Qiao Longfei 已提交
20 21 22 23 24 25 26 27 28 29 30
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
31
#include "paddle/fluid/string/printf.h"
Q
Qiao Longfei 已提交
32 33 34 35 36 37 38 39 40 41

namespace paddle {
namespace operators {
namespace distributed {

// Short local names for the framework types used in this file.
// (A duplicated `LoDTensor` alias was removed; one declaration suffices.)
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
// Ordered list of (endpoint, table/variable name) pairs describing where each
// piece of a split variable is sent.
using EP_SPLIT_TABLE_PAIRS = std::vector<std::pair<std::string, std::string>>;

// Builds the (endpoint, table name) pairs a sparse variable is sent to.
//
// Split `i` of `rpc_ctx.var_name` goes to `rpc_ctx.epmap[i]`. When
// `multi_parts` == 1 the table name is the split name itself. Otherwise each
// split is further divided into `multi_parts` pieces named
// "<split>@<x>@PIECE", all addressed to that split's endpoint.
//
// Pairs are ordered split-major: split `i`, piece `x` is at index
// `i * multi_parts + x`. ParameterSend indexes its outputs with exactly this
// formula, so the ordering must not change.
//
// Throws (PADDLE_ENFORCE / PADDLE_THROW) when the variable is missing from
// `scope`, when `multi_parts` < 1, or when the variable is not SelectedRows.
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext(
    const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) {
  EP_SPLIT_TABLE_PAIRS table_pairs;

  auto *send_var = scope.FindVar(rpc_ctx.var_name);
  PADDLE_ENFORCE_NOT_NULL(send_var, "Can not find variable '%s' in the scope",
                          rpc_ctx.var_name);

  if (send_var->IsType<framework::SelectedRows>()) {
    PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must >=1");

    const size_t split_num = rpc_ctx.splited_var_names.size();
    table_pairs.reserve(split_num * static_cast<size_t>(multi_parts));

    for (size_t i = 0; i < split_num; ++i) {
      const auto &endpoint = rpc_ctx.epmap[i];
      const auto &split_name = rpc_ctx.splited_var_names[i];
      if (multi_parts == 1) {
        // A single piece keeps the split's own name (no "@PIECE" suffix).
        table_pairs.emplace_back(endpoint, split_name);
      } else {
        for (int x = 0; x < multi_parts; ++x) {
          table_pairs.emplace_back(
              endpoint, string::Sprintf("%s@%d@PIECE", split_name, x));
        }
      }
    }
  } else if (send_var->IsType<framework::LoDTensor>()) {
    PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
  } else {
    PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!");
  }

  return table_pairs;
}

76
template <typename T>
77
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
78 79
                                  const framework::Scope &scope, bool sync,
                                  int multi_parts) {
80
  std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
Q
Qiao Longfei 已提交
81 82 83 84 85

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
Q
Qiao Longfei 已提交
86
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
Q
Qiao Longfei 已提交
87

88 89
  std::vector<distributed::VarHandlePtr> rets;

90
  auto *send_var = scope.FindVar(rpc_ctx.var_name);
91

Q
Qiao Longfei 已提交
92
  if (send_var->IsType<framework::LoDTensor>()) {
93
    size_t out_num = rpc_ctx.splited_var_names.size();
Q
Qiao Longfei 已提交
94 95 96 97 98
    if (out_num > 1) {
      auto &send_tensor = send_var->Get<framework::LoDTensor>();
      auto &send_tensor_dims = send_tensor.dims();
      std::vector<framework::DDim> outs_dims;
      outs_dims.reserve(out_num);
Q
Qiao Longfei 已提交
99

Q
Qiao Longfei 已提交
100
      // infer output shape
101
      PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num,
Q
Qiao Longfei 已提交
102 103 104 105
                        "tensor split sections size"
                        "should be equal to output size.");
      for (size_t i = 0; i < out_num; ++i) {
        auto dim = send_tensor_dims;
106
        dim[0] = rpc_ctx.height_sections[i];
Q
Qiao Longfei 已提交
107 108 109
        outs_dims.push_back(dim);
      }

Q
Qiao Longfei 已提交
110 111 112
      // create output var in local scope
      size_t row_offset = 0;
      for (auto i = 0; i < out_num; ++i) {
113
        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
Q
Qiao Longfei 已提交
114
                                     ->GetMutable<framework::LoDTensor>();
Q
Qiao Longfei 已提交
115 116 117
        *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
        row_offset += outs_dims[i][0];
      }
Q
Qiao Longfei 已提交
118
    }
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136

    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
      auto &send_var_name = rpc_ctx.splited_var_names[i];
      VLOG(4) << "send var name: " << send_var_name;
      auto &endpoint = rpc_ctx.epmap[i];
      VLOG(4) << "send var endpoint: " << endpoint;
      VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
      if (NeedSend(*local_scope.get(), send_var_name)) {
        VLOG(3) << "sending " << send_var_name << " to " << endpoint;
        rets.push_back(rpc_client->AsyncSendVar(
            endpoint, cpu_ctx, *local_scope.get(), send_var_name));
        VLOG(4) << "send var " << send_var_name << " async handle done";
      } else {
        VLOG(3) << "don't send non-initialized variable: "
                << rpc_ctx.splited_var_names[i];
      }
    }

137
  } else if (send_var->IsType<framework::SelectedRows>()) {
Q
Qiao Longfei 已提交
138
    auto &send_slr = send_var->Get<framework::SelectedRows>();
139
    auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
140

Q
Qiao Longfei 已提交
141
    auto &send_rows = send_slr.rows();
142 143
    std::vector<std::vector<size_t>> outs_rows_idx;
    std::vector<std::vector<size_t>> outs_dense_idx;
144

145 146 147 148
    auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts);

    outs_rows_idx.resize(table_pairs.size());
    outs_dense_idx.resize(table_pairs.size());
149 150

    auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
Q
Qiao Longfei 已提交
151
    auto *src = send_slr.value().data<T>();
152

Q
Qiao Longfei 已提交
153
    // create output var in local scope
Q
Qiao Longfei 已提交
154
    std::vector<framework::SelectedRows *> outs;
155 156 157
    for (auto &table : table_pairs) {
      auto *out =
          local_scope->Var(table.second)->GetMutable<framework::SelectedRows>();
158 159 160 161 162
      outs.push_back(out);
    }

    // split rows index into output sparse vars
    for (size_t i = 0; i < send_rows.size(); ++i) {
163 164 165
      auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
      auto table_idx = send_rows[i] % multi_parts;
      auto out_idx = ep_idx * multi_parts + table_idx;
166 167
      outs_rows_idx[out_idx].push_back(send_rows[i]);
      outs_dense_idx[out_idx].push_back(i);
Q
Qiao Longfei 已提交
168
    }
169

Q
Qiao Longfei 已提交
170
    auto place = platform::CPUPlace();
171

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
    for (int ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) {
      for (int part = 0; part < multi_parts; part++) {
        auto out_idx = ctx * multi_parts + part;
        auto rows_idx = outs_rows_idx[out_idx];

        auto dims = send_slr.GetCompleteDims();
        dims[0] = rows_idx.size();

        outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]);
        outs[out_idx]->mutable_rows()->clear();
        outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());

        if (rows_idx.size() > 0) {
          for (auto idx : rows_idx) {
            outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]);
          }
          auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
          for (size_t j = 0; j < rows_idx.size(); j++) {
            if (platform::is_cpu_place(place)) {
              memory::Copy(platform::CPUPlace(), dst + j * row_numel,
                           platform::CPUPlace(),
                           src + outs_dense_idx[out_idx][j] * row_numel,
                           sizeof(T) * row_numel);
            } else {
              PADDLE_THROW("do not support GPU now");
            }
198 199
          }
        }
200 201
        PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
                          "rows should has the same size with tensor dim 0");
202 203 204
      }
    }

205 206 207 208
    for (size_t i = 0; i < table_pairs.size(); i++) {
      auto &send_var_name = table_pairs[i].second;
      auto &endpoint = table_pairs[i].first;
      auto need_send = NeedSend(*local_scope.get(), send_var_name);
Q
Qiao Longfei 已提交
209

210 211 212 213 214 215 216 217 218 219 220 221 222 223
      VLOG(4) << "send var name: " << send_var_name
              << "send var endpoint: " << endpoint
              << "need send: " << need_send;

      if (need_send) {
        VLOG(4) << "sending " << send_var_name << " to " << endpoint;

        rets.push_back(rpc_client->AsyncSendVar(
            endpoint, cpu_ctx, *local_scope.get(), send_var_name));
        VLOG(4) << "send var " << send_var_name << " async handle done";
      } else {
        VLOG(4) << "don't send non-initialized variable: "
                << rpc_ctx.splited_var_names[i];
      }
Q
Qiao Longfei 已提交
224
    }
225 226
  } else {
    PADDLE_THROW("unsupported var type to send!");
Q
Qiao Longfei 已提交
227 228
  }

229
  VLOG(4) << "Prepare to send var " << rpc_ctx.var_name;
230 231
  if (sync) {
    for (auto &handle : rets) {
232
      VLOG(4) << "Wait send var to pserver handle: " << handle;
233
      PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
Q
Qiao Longfei 已提交
234 235 236 237
    }
  }
}

Q
Qiao Longfei 已提交
238 239
// Explicit instantiation: this translation unit compiles ParameterSend only
// for element type float.
template struct ParameterSend<float>;

Q
Qiao Longfei 已提交
240 241 242
};  // namespace distributed
};  // namespace operators
};  // namespace paddle