sendrecvop_utils.cc 8.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
Y
Yi Wang 已提交
16

T
typhoonzero 已提交
17
#ifdef PADDLE_WITH_CUDA
T
fix ci  
typhoonzero 已提交
18
#include <nccl.h>
T
typhoonzero 已提交
19
#endif
20
#include <sys/time.h>
Y
Yi Wang 已提交
21 22
#include <thread>  // NOLINT

23 24 25 26 27
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
28
#include "paddle/fluid/operators/detail/variable_response.h"
29
#include "paddle/fluid/platform/profiler.h"
G
gongweibao 已提交
30 31 32 33 34

namespace paddle {
namespace operators {
namespace detail {

T
typhoonzero 已提交
35 36 37 38 39 40
// Shorthand for the protobuf message (sendrecv::VariableMessage) that the
// serialization helpers in this file fill in.
using VarMsg = sendrecv::VariableMessage;

// Fills the tensor metadata of `request` (data type, dims and, when present,
// the LoD) from the LoDTensor held by `var`, and exposes the tensor's raw bytes
// through `*payload` / `*payload_size`.
//
// When `ctx` is a GPU context the tensor data is copied into a newly allocated
// CUDA pinned host buffer; the caller owns that buffer and must release it with
// memory::Free(platform::CUDAPinnedPlace, ...). Otherwise `*payload` points
// directly at the tensor's memory (no copy), so `var` must outlive every use of
// the payload.
//
// Throws (PADDLE_THROW) if `ctx` is a GPU context but this binary was built
// without CUDA support, instead of handing back a null payload.
void GetTensorPayload(framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      void** payload, size_t* payload_size) {
  // Bind by reference: copying the LoDTensor would also copy its whole LoD.
  const auto& tensor = var->Get<framework::LoDTensor>();
  // FIXME(wuyi): data types in send_recv.proto is copied from
  // framework.proto
  request->set_data_type(
      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
  for (const auto& dim : framework::vectorize(tensor.dims())) {
    request->add_dims(dim);
  }
  const framework::LoD& lod = tensor.lod();
  if (lod.size() > 0) {
    request->set_lod_level(lod.size());
    for (const auto& each : lod) {
      VarMsg::LodData* lod_inner = request->add_lod();
      for (const auto& d : each) {
        lod_inner->add_lod_data(d);
      }
    }
  }

  const size_t data_size =
      tensor.numel() * framework::SizeOfType(tensor.type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
    platform::CUDAPinnedPlace cuda_pinned;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cuda_pinned, data_size);
    memory::Copy(cuda_pinned, *payload,
                 boost::get<platform::CUDAPlace>(tensor.place()),
                 tensor.data<void>(), data_size, gpu_dev_ctx.stream());
    // The copy is issued on the context's stream; wait so the pinned buffer
    // is fully populated before it is handed to the caller.
    ctx.Wait();
#else
    PADDLE_THROW(
        "Cannot serialize a tensor from a GPU context: PaddlePaddle was "
        "built without CUDA support.");
#endif
  } else {
    // The payload is only read, but the transport API takes a mutable
    // pointer, hence the const_cast.
    *payload = const_cast<void*>(tensor.data<void>());
  }
  *payload_size = data_size;
}

// Fills the SelectedRows metadata of `request` (data type of the value tensor,
// lod_level 0, height and the value tensor's dims) from the SelectedRows held
// by `var`, and exposes the value tensor's raw bytes through `*payload` /
// `*payload_size`. The rows themselves are not part of this payload.
//
// When `ctx` is a GPU context the value data is copied into a newly allocated
// CUDA pinned host buffer that the caller owns and must release with
// memory::Free(platform::CUDAPinnedPlace, ...). Otherwise `*payload` points
// directly at the value tensor's memory (no copy).
//
// Throws (PADDLE_THROW) if `ctx` is a GPU context but this binary was built
// without CUDA support, instead of handing back a null payload.
void GetSelectedRowsPayload(framework::Variable* var,
                            const platform::DeviceContext& ctx, VarMsg* request,
                            void** payload, size_t* payload_size) {
  auto* slr = var->GetMutable<framework::SelectedRows>();
  auto* tensor = slr->mutable_value();
  request->set_data_type(
      static_cast<VarMsg::Type>(framework::ToDataType(tensor->type())));
  request->set_lod_level(0);
  request->set_slr_height(slr->height());

  for (const auto& dim : framework::vectorize(tensor->dims())) {
    request->add_dims(dim);
  }

  const size_t data_size =
      tensor->numel() * framework::SizeOfType(tensor->type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    // Same precondition as the LoDTensor path: fail with a clear message
    // rather than a boost::bad_get from the place conversion below.
    PADDLE_ENFORCE(platform::is_gpu_place(tensor->place()));
    platform::CUDAPinnedPlace cuda_pinned;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cuda_pinned, data_size);
    memory::Copy(cuda_pinned, *payload,
                 boost::get<platform::CUDAPlace>(tensor->place()),
                 reinterpret_cast<const void*>(tensor->data<void>()), data_size,
                 gpu_dev_ctx.stream());
    // The copy is issued on the context's stream; wait so the pinned buffer
    // is fully populated before it is handed to the caller.
    ctx.Wait();
#else
    PADDLE_THROW(
        "Cannot serialize SelectedRows from a GPU context: PaddlePaddle was "
        "built without CUDA support.");
#endif
  } else {
    *payload = tensor->data<void>();
  }
  *payload_size = data_size;
}

110 111
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                           const platform::DeviceContext& ctx,
Y
Yancey1989 已提交
112 113
                           ::grpc::ByteBuffer* msg,
                           const std::string& out_name) {
T
typhoonzero 已提交
114 115
  // Default DestroyCallback does nothing, When using GPU
  // the CPU buffer need to be freed.
116
  DestroyCallback destroy_callback = [](void* backing) {};
T
typhoonzero 已提交
117
  VarMsg request;
Y
Yancey 已提交
118
  void* payload = nullptr;
119
  size_t payload_size;
T
typhoonzero 已提交
120 121

  request.set_varname(name);
122 123 124 125
  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
X
Xin Pan 已提交
126 127
  if (platform::ShouldSendProfileState()) {
    if (platform::IsProfileEnabled()) {
X
Xin Pan 已提交
128
      request.set_profile(platform::kEnableProfiler);
X
Xin Pan 已提交
129
    } else {
X
Xin Pan 已提交
130
      request.set_profile(platform::kDisableProfiler);
X
Xin Pan 已提交
131 132
    }
  }
T
typhoonzero 已提交
133 134
  if (!out_name.empty()) {
    request.set_out_varname(out_name);
135
  }
136
  if (var->IsType<framework::LoDTensor>()) {
T
typhoonzero 已提交
137 138
    request.set_type(::sendrecv::LOD_TENSOR);
    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
139
  } else if (var->IsType<framework::SelectedRows>()) {
T
typhoonzero 已提交
140 141
    request.set_type(::sendrecv::SELECTED_ROWS);
    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
T
typhoonzero 已提交
142
#ifdef PADDLE_WITH_CUDA
T
typhoonzero 已提交
143
  } else if (var->IsType<ncclUniqueId>()) {
144
    request.set_type(::sendrecv::NCCL_ID);
T
typhoonzero 已提交
145
#endif
T
typhoonzero 已提交
146 147 148
  } else {
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
149 150
  }

T
typhoonzero 已提交
151
  if (platform::is_gpu_place(ctx.GetPlace())) {
152
#ifdef PADDLE_WITH_CUDA
T
typhoonzero 已提交
153 154 155
    // GPU data is copied to CPU buffer when sending,
    // free the buffer when possible.
    destroy_callback = [](void* backing) {
Y
yi.wu 已提交
156 157
      platform::CUDAPinnedPlace cuda_pinned;
      memory::Free(cuda_pinned, backing);
T
typhoonzero 已提交
158
    };
159
#endif
Y
Yancey1989 已提交
160
  }
161

T
typhoonzero 已提交
162 163 164 165 166 167
  std::string header;
  request.AppendToString(&header);
  auto buffer = std::unique_ptr<char[]>(new char[1024]);
  void* buf = buffer.get();
  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
  e.WriteRawBytes(std::string(header.data(), header.size()));
168 169
// NCCLID is copied directly to the message, return bytebuffer
// with only one slice if serializing NCCLID.
T
typhoonzero 已提交
170
#ifdef PADDLE_WITH_CUDA
171
  if (var->IsType<ncclUniqueId>()) {
T
typhoonzero 已提交
172 173
    e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
                              NCCL_UNIQUE_ID_BYTES);
Y
update  
yi.wu 已提交
174
    const ncclUniqueId& uid = var->Get<ncclUniqueId>();
T
typhoonzero 已提交
175
    e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
176

T
typhoonzero 已提交
177 178 179 180 181 182 183
    // for serialize NCCL_ID
    ::grpc::Slice slices(e.size());
    memcpy(const_cast<uint8_t*>(slices.begin()), e.data(), e.size());
    ::grpc::ByteBuffer tmp(&slices, 1);
    msg->Swap(&tmp);
    return;
  }
T
typhoonzero 已提交
184
#endif
185

T
typhoonzero 已提交
186
  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
187 188 189 190 191 192 193 194 195 196
  // steal reference of tensor data
  ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
  int num_slices = 2;       // only SelectedRows have rows buffer
  slices[0] = ::grpc::Slice(e.size());
  memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
  slices[1] = ::grpc::Slice(
      grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
                                    static_cast<char*>(payload)),
      ::grpc::Slice::STEAL_REF);

T
typhoonzero 已提交
197
  if (var->IsType<framework::SelectedRows>()) {
198
    auto* slr = var->GetMutable<framework::SelectedRows>();
Y
Yi Wang 已提交
199
    ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
200
    size_t rows_memory_size =
T
typhoonzero 已提交
201
        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
202 203 204 205 206 207 208 209
    e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
    slices[2] = ::grpc::Slice(e2.size());
    memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());

    slices[3] = ::grpc::Slice(
        grpc_slice_new_with_user_data(
            const_cast<void*>(
                reinterpret_cast<const void*>(slr->rows().data())),
T
typhoonzero 已提交
210
            rows_memory_size, [](void* backing) {},
211 212 213 214 215 216 217 218 219 220 221 222
            const_cast<char*>(
                reinterpret_cast<const char*>(slr->rows().data()))),
        ::grpc::Slice::STEAL_REF);
    num_slices = 4;
  }

  ::grpc::ByteBuffer tmp(&slices[0], num_slices);
  msg->Swap(&tmp);
}

// Decodes the gRPC payload `msg` into a framework variable and returns it via
// `*var`. `scope` and `ctx` are forwarded to VariableResponse, which does the
// actual parsing. Enforces (throws) if parsing reports a non-zero status.
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                               const platform::DeviceContext& ctx,
                               const framework::Scope* scope,
                               framework::Variable** var) {
  operators::detail::VariableResponse response(scope, &ctx);
  const auto parse_status = response.Parse(msg);
  PADDLE_ENFORCE(parse_status == 0, "parse bytebuffer to tensor error!");
  *var = response.GetVar();
}

G
gongweibao 已提交
230 231
}  // namespace detail
}  // namespace operators
Y
Yancey 已提交
232
}  // namespace paddle