// parameter_recv.cc
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <set>
#include <string>
#include <vector>

#include "paddle/fluid/operators/distributed/parameter_recv.h"

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"

#include "paddle/fluid/operators/strided_memcpy.h"

namespace paddle {
namespace operators {
namespace distributed {

using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;

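// ParameterRecv fetches the split slices of a variable from the remote
// endpoints in rpc_ctx.epmap and reassembles them into the original
// variable in `scope`.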
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
                                  const framework::Scope &scope) {
  VLOG(3) << "ParameterRecv in";
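  // Stage the incoming slices in a temporary child scope so their names do
  // not collide with variables in the training scope.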
  framework::Scope *local_scope = scope.NewTmpScope();

  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &cpu_ctx = *pool.Get(platform::CPUPlace());

  distributed::RPCClient *rpc_client =
      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

  auto *recv_var = scope.FindVar(rpc_ctx.var_name);

  // recv all vars to local scope
  if (recv_var->IsType<framework::LoDTensor>()) {
    std::vector<distributed::VarHandlePtr> rets;
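    // Issue one non-blocking AsyncGetVar per slice; the i-th slice is
    // fetched from the endpoint at the same index in rpc_ctx.epmap.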
    for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
      auto &recv_var_name = rpc_ctx.splited_var_names[i];
      local_scope->Var(recv_var_name);
      VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
      rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
                                             *local_scope, recv_var_name,
                                             recv_var_name));
    }
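    // Wait for all in-flight gets; a failed RPC is surfaced here.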
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
    }
  } else {
    PADDLE_THROW("unsupported var type to recv!");
  }

  // concat the received tensor slices into one var
  {
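    // output_offset tracks how many elements of the flat output buffer have
    // already been filled by dense slices.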
    size_t output_offset = 0;
    framework::Tensor *recv_tensor =
        recv_var->GetMutable<framework::LoDTensor>();
    auto dev_ctx = paddle::platform::CPUDeviceContext();
    int64_t recv_numel = 0;
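    // Merge each slice back into recv_tensor: dense LoDTensor slices are
    // concatenated along axis 0, while SelectedRows slices are scattered
    // row by row into their global positions.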
    for (auto &recv_var_name : rpc_ctx.splited_var_names) {
      auto *recv_var = local_scope->FindVar(recv_var_name);
      if (recv_var->IsType<framework::LoDTensor>()) {
        auto &in = recv_var->Get<framework::LoDTensor>();
        recv_numel += in.numel();
        auto in_stride = framework::stride_numel(in.dims());
        auto out_stride = framework::stride_numel(recv_tensor->dims());
        StridedNumelCopyWithAxis<T>(
            dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
            in.data<T>(), in_stride, in_stride[0]);
        output_offset += in_stride[0];
      } else if (recv_var->IsType<framework::SelectedRows>()) {
        auto &recv_slr = recv_var->Get<framework::SelectedRows>();
        auto &recv_dims = recv_tensor->dims();
        int64_t width = recv_dims[1];
        PADDLE_ENFORCE_EQ(recv_slr.height(), recv_dims[0]);
        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
        PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
        VLOG(3) << "recv slr " << recv_var_name << " dims "
                << recv_slr.value().dims();
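        // Row i of the SelectedRows value holds global row rows()[i] of the
        // dense parameter; copy it into place.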
        for (size_t i = 0; i < recv_slr.rows().size(); ++i) {
          auto row_id = recv_slr.rows()[i];
          memcpy(recv_tensor->data<T>() + row_id * width,
                 recv_slr.value().data<T>() + i * width, sizeof(T) * width);
        }
      } else {
        PADDLE_THROW("unsupported received var type");
      }
    }
    PADDLE_ENFORCE_EQ(recv_numel, recv_tensor->numel());
  }

  delete local_scope;
  VLOG(3) << "ParameterRecv out";
}

template struct ParameterRecv<float>;

}  // namespace distributed
}  // namespace operators
}  // namespace paddle