// parameter_recv.cc (blame annotation: committed by Qiao Longfei)
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/parameter_recv.h"

#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/operators/strided_memcpy.h"

namespace paddle {
namespace operators {
namespace distributed {

using LoDTensor = framework::LoDTensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

template <typename T>
42
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
Q
Qiao Longfei 已提交
43
                                  const framework::Scope &scope) {
Q
Qiao Longfei 已提交
44
  VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
Q
Qiao Longfei 已提交
45 46 47 48 49 50
  framework::Scope *local_scope = scope.NewTmpScope();

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
Q
Qiao Longfei 已提交
51
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
Q
Qiao Longfei 已提交
52

53
  auto *recv_var = scope.FindVar(rpc_ctx.var_name);
Q
Qiao Longfei 已提交
54 55 56 57

  // recv all vars to local scope
  if (recv_var->IsType<framework::LoDTensor>()) {
    std::vector<distributed::VarHandlePtr> rets;
58 59
    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
      auto &recv_var_name = rpc_ctx.splited_var_names[i];
Q
Qiao Longfei 已提交
60
      local_scope->Var(recv_var_name);
61 62 63
      VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
      rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                             *local_scope, recv_var_name,
Q
Qiao Longfei 已提交
64
                                             recv_var_name, recv_var_name));
Q
Qiao Longfei 已提交
65
    }
Q
Qiao Longfei 已提交
66 67
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
Q
Qiao Longfei 已提交
68 69
    }
  } else {
Q
Qiao Longfei 已提交
70
    PADDLE_THROW("unsupported var type to recv!");
Q
Qiao Longfei 已提交
71 72
  }

Q
Qiao Longfei 已提交
73 74 75
  // concat recved tensor into one var
  {
    size_t output_offset = 0;
Q
Qiao Longfei 已提交
76
    size_t row_offset = 0;
Q
Qiao Longfei 已提交
77 78
    framework::Tensor *recv_tensor =
        recv_var->GetMutable<framework::LoDTensor>();
Q
Qiao Longfei 已提交
79
    auto dev_ctx = paddle::platform::CPUDeviceContext();
Q
Qiao Longfei 已提交
80
    int64_t recv_numel = 0;
Q
Qiao Longfei 已提交
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
    for (auto &recv_var_name : rpc_ctx.splited_var_names) {
      auto *recv_var = local_scope->FindVar(recv_var_name);
      if (recv_var->IsType<framework::LoDTensor>()) {
        auto &in = recv_var->Get<framework::LoDTensor>();
        recv_numel += in.numel();
        auto in_stride = framework::stride_numel(in.dims());
        auto out_stride = framework::stride_numel(recv_tensor->dims());
        StridedNumelCopyWithAxis<T>(
            dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
            in.data<T>(), in_stride, in_stride[0]);
        output_offset += in_stride[0];
      } else if (recv_var->IsType<framework::SelectedRows>()) {
        auto &recv_slr = recv_var->Get<framework::SelectedRows>();
        auto &recv_dims = recv_tensor->dims();
        int64_t width = recv_dims[1];
Q
Qiao Longfei 已提交
96
        recv_numel += recv_slr.height() * width;
Q
Qiao Longfei 已提交
97 98 99 100
        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
        VLOG(3) << "recv slr " << recv_var_name << " dims "
                << recv_slr.value().dims();
Q
Qiao Longfei 已提交
101 102 103 104 105 106 107 108 109 110
        if (VLOG_IS_ON(3)) {
          std::ostringstream sstream;
          sstream << "[";
          for (auto &row_id : recv_slr.rows()) {
            sstream << row_id << ", ";
          }
          sstream << "]";
          VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
                  << sstream.str();
        }
Q
Qiao Longfei 已提交
111 112 113

        // FIXME(qiao): use a trick to avoid the bug of recv an selected rows
        for (auto i = 1; i < recv_slr.rows().size(); ++i) {
Q
Qiao Longfei 已提交
114
          auto row_id = recv_slr.rows()[i] + row_offset;
Q
Qiao Longfei 已提交
115
          PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
Q
Qiao Longfei 已提交
116 117 118
          memcpy(recv_tensor->data<T>() + row_id * width,
                 recv_slr.value().data<T>() + i * width, sizeof(T) * width);
        }
Q
Qiao Longfei 已提交
119
        row_offset += recv_slr.height();
Q
Qiao Longfei 已提交
120 121 122
      } else {
        PADDLE_THROW("unsupported recieved var type");
      }
Q
Qiao Longfei 已提交
123
    }
Q
Qiao Longfei 已提交
124 125 126 127 128
    auto numel = recv_tensor->numel();
    if (recv_numel != numel) {
      LOG(FATAL) << "recv_numel: " << recv_numel << " acture numel: " << numel;
    }
    PADDLE_ENFORCE_EQ(recv_numel, numel);
Q
Qiao Longfei 已提交
129 130 131
  }

  delete local_scope;
Q
Qiao Longfei 已提交
132
  VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
Q
Qiao Longfei 已提交
133 134 135 136 137 138 139
}

template struct ParameterRecv<float>;

};  // namespace distributed
};  // namespace operators
};  // namespace paddle