//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/parameter_send.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {
namespace distributed {

using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
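// (endpoint, table_name) pairs describing where each piece of the variable is
// sent. GetMultiFieldRpcContext below fills this list: with multi_parts == 1
// every split variable maps to its own endpoint; with multi_parts > 1 each
// split SelectedRows variable is further divided into "<name>@<part>@PIECE"
// slices, one pair per slice.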

inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext(
    const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) {
  EP_SPLIT_TABLE_PAIRS table_pairs;

  auto *send_var = scope.FindVar(rpc_ctx.var_name);
  if (send_var->IsType<framework::SelectedRows>()) {
    PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must be greater than 0");

    if (multi_parts == 1) {
      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
        table_pairs.push_back(
            std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i]));
      }
    } else {
      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
        for (int x = 0; x < multi_parts; x++) {
          auto table =
              string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
          table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
        }
      }
    }

  } else if (send_var->IsType<framework::LoDTensor>()) {
    PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
  } else {
    PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!");
  }

  return table_pairs;
}

template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
                                  const framework::Scope &scope, bool sync,
                                  int multi_parts) {
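  // All split pieces are created in a temporary scope so the caller's scope is
  // left untouched.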
  std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);

  std::vector<distributed::VarHandlePtr> rets;

  auto *send_var = scope.FindVar(rpc_ctx.var_name);

  if (send_var->IsType<framework::LoDTensor>()) {
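    // Dense case: slice the tensor by rows according to height_sections and
    // send one slice per split variable.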
    size_t out_num = rpc_ctx.splited_var_names.size();
    if (out_num > 1) {
      auto &send_tensor = send_var->Get<framework::LoDTensor>();
      auto &send_tensor_dims = send_tensor.dims();
      std::vector<framework::DDim> outs_dims;
      outs_dims.reserve(out_num);

      // infer output shape
      PADDLE_ENFORCE_EQ(rpc_ctx.height_sections.size(), out_num,
                        "tensor split sections size"
                        "should be equal to output size.");
      for (size_t i = 0; i < out_num; ++i) {
        auto dim = send_tensor_dims;
        dim[0] = rpc_ctx.height_sections[i];
        outs_dims.push_back(dim);
      }

      // create output var in local scope
      size_t row_offset = 0;
      for (size_t i = 0; i < out_num; ++i) {
        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
                                     ->GetMutable<framework::LoDTensor>();
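        // Slice produces a view over rows [row_offset, row_offset + dim0) of
        // send_tensor, so splitting the dense tensor copies no data here.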
        *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
        row_offset += outs_dims[i][0];
      }
    }
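    // Send each split piece asynchronously. With use_send_handler every piece
    // goes to its own endpoint; otherwise every piece is pushed to all
    // endpoints through AsyncDistributeNotify.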
    if (rpc_ctx.use_send_handler) {
      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
        auto &send_var_name = rpc_ctx.splited_var_names[i];
        VLOG(4) << "send var name: " << send_var_name;
        auto &endpoint = rpc_ctx.epmap[i];
        VLOG(4) << "send var endpoint: " << endpoint;
        VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
        if (NeedSend(*local_scope.get(), send_var_name)) {
          VLOG(3) << "sending " << send_var_name << " to " << endpoint;
          rets.push_back(rpc_client->AsyncSendVar(
              endpoint, cpu_ctx, *local_scope.get(), send_var_name));
          VLOG(4) << "send var " << send_var_name << " async handle done";
        } else {
          VLOG(3) << "don't send non-initialized variable: "
                  << rpc_ctx.splited_var_names[i];
        }
      }
    } else {
      for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
        for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
          auto &send_var_name = rpc_ctx.splited_var_names[i];
          VLOG(4) << "send var name: " << send_var_name;
          auto &endpoint = rpc_ctx.epmap[j];
          VLOG(4) << "send var endpoint: " << endpoint;
          VLOG(4) << "need send: "
                  << NeedSend(*local_scope.get(), send_var_name);
          if (NeedSend(*local_scope.get(), send_var_name)) {
            VLOG(3) << "sending " << send_var_name << " to " << endpoint;
            rets.push_back(rpc_client->AsyncDistributeNotify(
                endpoint, cpu_ctx, *local_scope.get(), send_var_name));
            VLOG(4) << "send var " << send_var_name << " async handle done";
          } else {
            VLOG(3) << "don't send non-initialized variable: "
                    << rpc_ctx.splited_var_names[i];
          }
        }
      }
    }
  } else if (send_var->IsType<framework::SelectedRows>()) {
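    // Sparse case: bucket every row of the SelectedRows variable into the
    // (endpoint, piece) output it belongs to, then send each piece.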
    auto &send_slr = send_var->Get<framework::SelectedRows>();
    auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);

    auto &send_rows = send_slr.rows();
    if (send_rows.size() == 0) {
      LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
                      "may cause an unknown error. Please check the state of "
                      "use_double_buffer in pyreader async mode; it needs to "
                      "be set to false.";
    }

    std::vector<std::vector<size_t>> outs_rows_idx;
    std::vector<std::vector<size_t>> outs_dense_idx;

    auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts);

    outs_rows_idx.resize(table_pairs.size());
    outs_dense_idx.resize(table_pairs.size());

    auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
    auto *src = send_slr.value().data<T>();

    // create output var in local scope
    std::vector<framework::SelectedRows *> outs;
    for (auto &table : table_pairs) {
      auto *out =
          local_scope->Var(table.second)->GetMutable<framework::SelectedRows>();
      outs.push_back(out);
    }

    // split row indices into the output sparse vars
    for (size_t i = 0; i < send_rows.size(); ++i) {
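      // ep_idx picks the endpoint by the absolute row range, table_idx picks
      // the piece within that endpoint; outs is laid out endpoint-major.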
      auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
      auto table_idx = send_rows[i] % multi_parts;
      auto out_idx = ep_idx * multi_parts + table_idx;
      outs_rows_idx[out_idx].push_back(send_rows[i]);
      outs_dense_idx[out_idx].push_back(i);
    }

    auto place = platform::CPUPlace();

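    // Fill each output piece: translate absolute row ids into piece-local row
    // offsets and copy the corresponding rows of the source value tensor.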
    for (size_t ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) {
      for (int part = 0; part < multi_parts; part++) {
        auto out_idx = ctx * multi_parts + part;
        auto rows_idx = outs_rows_idx[out_idx];

        auto dims = send_slr.GetCompleteDims();
        dims[0] = rows_idx.size();

        outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]);
        outs[out_idx]->mutable_rows()->clear();
        outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());

        if (rows_idx.size() > 0) {
          for (auto idx : rows_idx) {
            outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]);
          }
          auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
          for (size_t j = 0; j < rows_idx.size(); j++) {
            if (platform::is_cpu_place(place)) {
              memory::Copy(platform::CPUPlace(), dst + j * row_numel,
                           platform::CPUPlace(),
                           src + outs_dense_idx[out_idx][j] * row_numel,
                           sizeof(T) * row_numel);
            } else {
              PADDLE_THROW("do not support GPU now");
            }
          }
        }
        PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
                          "rows should has the same size with tensor dim 0");
      }
    }

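    // Issue one asynchronous send per (endpoint, table) pair; the returned
    // handles are waited on below when sync is true.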
    for (size_t i = 0; i < table_pairs.size(); i++) {
      auto &send_var_name = table_pairs[i].second;
      auto &endpoint = table_pairs[i].first;
      auto need_send = NeedSend(*local_scope.get(), send_var_name);

      VLOG(4) << "send var name: " << send_var_name
              << "send var endpoint: " << endpoint
              << "need send: " << need_send;

      if (need_send) {
        VLOG(4) << "sending " << send_var_name << " to " << endpoint;

        rets.push_back(rpc_client->AsyncSendVar(
            endpoint, cpu_ctx, *local_scope.get(), send_var_name));
        VLOG(4) << "send var " << send_var_name << " async handle done";
      } else {
        VLOG(4) << "don't send non-initialized variable: "
                << rpc_ctx.splited_var_names[i];
      }
    }
  } else {
    PADDLE_THROW("unsupported var type to send!");
  }

  VLOG(4) << "Prepare to send var " << rpc_ctx.var_name;
  if (sync) {
    for (auto &handle : rets) {
      VLOG(4) << "Wait send var to pserver handle: " << handle;
      PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
    }
  }
}

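// Explicit instantiation for float parameters. A minimal call-site sketch
// (the surrounding setup is illustrative, not taken from this file):
//   distributed::ParameterSend<float> sender;
//   sender(rpc_ctx, scope, /*sync=*/true, /*multi_parts=*/1);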
template struct ParameterSend<float>;

}  // namespace distributed
}  // namespace operators
}  // namespace paddle