/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detail/sendrecvop_utils.h"

#include <sys/time.h>
#include <thread>  // NOLINT

#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
namespace detail {

// Shorthand for the sendrecv::VariableMessage protobuf message. It carries a
// variable's metadata (name, type, dims, LoD, data type) in the serialized
// header that precedes the raw tensor bytes.
using VarMsg = sendrecv::VariableMessage;

// Fills the LoDTensor-specific fields of `request` (data type, dims and, if
// present, the LoD) and exposes the tensor's raw data as the payload.
//
// @param var          Variable holding a framework::LoDTensor.
// @param ctx          Device context; selects whether the data must be staged
//                     through host memory.
// @param request      Message whose metadata fields are populated.
// @param payload      Out: pointer to the bytes to send.
// @param payload_size Out: number of bytes at *payload.
//
// On a GPU context the data is copied into a new CPU buffer obtained from
// memory::Alloc(CPUPlace). The caller owns that buffer and must release it
// with memory::Free. On other contexts *payload points directly into the
// tensor's storage and is only valid while that storage is alive.
// Throws if a GPU context is used in a build without PADDLE_WITH_CUDA.
// Before this check, such a build returned a null payload with a non-zero
// size.
void GetTensorPayload(framework::Variable* var,
                      const platform::DeviceContext& ctx, VarMsg* request,
                      void** payload, size_t* payload_size) {
  const auto& tensor = var->Get<framework::LoDTensor>();
  // FIXME(wuyi): data types in send_recv.proto is copied from
  // framework.proto
  request->set_data_type(
      static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
  for (auto dim : framework::vectorize(tensor.dims())) {
    request->add_dims(dim);
  }
  // Bind by reference: a copy would duplicate every LoD offset vector.
  const framework::LoD& lod = tensor.lod();
  if (!lod.empty()) {
    request->set_lod_level(lod.size());
    for (const auto& level : lod) {
      VarMsg::LodData* lod_inner = request->add_lod();
      for (auto offset : level) {
        lod_inner->add_lod_data(offset);
      }
    }
  }

  const size_t data_size =
      tensor.numel() * framework::SizeOfType(tensor.type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    PADDLE_ENFORCE(platform::is_gpu_place(tensor.place()));
    platform::CPUPlace cpu;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cpu, data_size);
    memory::Copy(cpu, *payload, boost::get<platform::CUDAPlace>(tensor.place()),
                 tensor.data<void>(), data_size, gpu_dev_ctx.stream());
    // The copy is issued on the device stream; wait so the host buffer is
    // complete before it is handed out.
    ctx.Wait();
#else
    PADDLE_THROW(
        "GetTensorPayload: GPU place is not supported without PADDLE_WITH_CUDA");
#endif
  } else {
    // A const tensor yields `const void*`. The payload is only read by the
    // sender, so dropping const here never mutates the tensor.
    *payload = const_cast<void*>(tensor.data<void>());
  }
  *payload_size = data_size;
}

// Fills the SelectedRows-specific fields of `request` (data type, lod_level 0,
// height, dims of the value tensor) and exposes the value tensor's raw data as
// the payload. The row indices are not handled here; the caller serializes
// them separately.
//
// @param var          Variable holding a framework::SelectedRows.
// @param ctx          Device context; selects whether the data must be staged
//                     through host memory.
// @param request      Message whose metadata fields are populated.
// @param payload      Out: pointer to the value bytes to send.
// @param payload_size Out: number of bytes at *payload.
//
// On a GPU context the value data is copied into a new CPU buffer from
// memory::Alloc(CPUPlace), which the caller must release with memory::Free.
// On other contexts *payload points into the value tensor's storage.
// This mirrors GetTensorPayload: it enforces that the value tensor really
// lives on the GPU, and it throws if a GPU context is used without
// PADDLE_WITH_CUDA.
void GetSelectedRowsPayload(framework::Variable* var,
                            const platform::DeviceContext& ctx, VarMsg* request,
                            void** payload, size_t* payload_size) {
  auto* slr = var->GetMutable<framework::SelectedRows>();
  auto* tensor = slr->mutable_value();
  request->set_data_type(
      static_cast<VarMsg::Type>(framework::ToDataType(tensor->type())));
  request->set_lod_level(0);
  request->set_slr_height(slr->height());

  for (auto dim : framework::vectorize(tensor->dims())) {
    request->add_dims(dim);
  }

  const size_t data_size =
      tensor->numel() * framework::SizeOfType(tensor->type());
  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
    // Same precondition as the LoDTensor path; otherwise boost::get below
    // would fail with an opaque bad_get.
    PADDLE_ENFORCE(platform::is_gpu_place(tensor->place()));
    platform::CPUPlace cpu;
    auto& gpu_dev_ctx = static_cast<const platform::CUDADeviceContext&>(ctx);
    *payload = memory::Alloc(cpu, data_size);
    memory::Copy(cpu, *payload,
                 boost::get<platform::CUDAPlace>(tensor->place()),
                 tensor->data<void>(), data_size, gpu_dev_ctx.stream());
    // Wait for the stream copy before exposing the host buffer.
    ctx.Wait();
#else
    PADDLE_THROW(
        "GetSelectedRowsPayload: GPU place is not supported without "
        "PADDLE_WITH_CUDA");
#endif
  } else {
    *payload = tensor->data<void>();
  }
  *payload_size = data_size;
}

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                           const platform::DeviceContext& ctx,
Y
Yancey1989 已提交
108 109
                           ::grpc::ByteBuffer* msg,
                           const std::string& out_name) {
T
typhoonzero 已提交
110 111
  // Default DestroyCallback does nothing, When using GPU
  // the CPU buffer need to be freed.
112
  DestroyCallback destroy_callback = [](void* backing) {};
T
typhoonzero 已提交
113
  VarMsg request;
Y
Yancey 已提交
114
  void* payload = nullptr;
115
  size_t payload_size;
T
typhoonzero 已提交
116 117

  request.set_varname(name);
118 119 120 121
  // Note: normally the profiler is enabled in 1 trainer, hence only
  // 1 trainer returns true for ShouldSendProfileState(). It tells PS
  // servers the trainer's profiling state so that PS can follow the
  // trainer.
T
typhoonzero 已提交
122 123 124
  request.set_profile(platform::IsProfileEnabled());
  if (!out_name.empty()) {
    request.set_out_varname(out_name);
125
  }
126
  if (var->IsType<framework::LoDTensor>()) {
T
typhoonzero 已提交
127 128
    request.set_type(::sendrecv::LOD_TENSOR);
    GetTensorPayload(var, ctx, &request, &payload, &payload_size);
129
  } else if (var->IsType<framework::SelectedRows>()) {
T
typhoonzero 已提交
130 131 132 133 134
    request.set_type(::sendrecv::SELECTED_ROWS);
    GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
  } else {
    PADDLE_THROW("Serialize does not support type: %s",
                 typeid(var->Type()).name());
135 136
  }

T
typhoonzero 已提交
137 138 139 140 141 142 143
  if (platform::is_gpu_place(ctx.GetPlace())) {
    // GPU data is copied to CPU buffer when sending,
    // free the buffer when possible.
    destroy_callback = [](void* backing) {
      platform::CPUPlace cpu;
      memory::Free(cpu, backing);
    };
Y
Yancey1989 已提交
144
  }
145

T
typhoonzero 已提交
146 147 148 149 150 151 152
  std::string header;
  request.AppendToString(&header);
  auto buffer = std::unique_ptr<char[]>(new char[1024]);
  void* buf = buffer.get();
  ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
  e.WriteRawBytes(std::string(header.data(), header.size()));
  e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
153 154 155 156 157 158 159 160 161 162
  // steal reference of tensor data
  ::grpc::Slice slices[4];  // metadata, tensor, rows meta, rows
  int num_slices = 2;       // only SelectedRows have rows buffer
  slices[0] = ::grpc::Slice(e.size());
  memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
  slices[1] = ::grpc::Slice(
      grpc_slice_new_with_user_data(payload, payload_size, destroy_callback,
                                    static_cast<char*>(payload)),
      ::grpc::Slice::STEAL_REF);

T
typhoonzero 已提交
163
  if (var->IsType<framework::SelectedRows>()) {
164
    auto* slr = var->GetMutable<framework::SelectedRows>();
Y
Yi Wang 已提交
165
    ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
166
    size_t rows_memory_size =
T
typhoonzero 已提交
167
        slr->rows().size() * framework::SizeOfType(typeid(int64_t));
168 169 170 171 172 173 174 175
    e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
    slices[2] = ::grpc::Slice(e2.size());
    memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());

    slices[3] = ::grpc::Slice(
        grpc_slice_new_with_user_data(
            const_cast<void*>(
                reinterpret_cast<const void*>(slr->rows().data())),
T
typhoonzero 已提交
176
            rows_memory_size, [](void* backing) {},
177 178 179 180 181 182 183 184 185 186 187 188
            const_cast<char*>(
                reinterpret_cast<const char*>(slr->rows().data()))),
        ::grpc::Slice::STEAL_REF);
    num_slices = 4;
  }

  ::grpc::ByteBuffer tmp(&slices[0], num_slices);
  msg->Swap(&tmp);
}

// Decodes a variable from `msg` and returns it through `*var`.
//
// Parsing is delegated to a VariableResponse built from `scope` and `ctx`.
// A non-zero status from Parse() raises an error through PADDLE_ENFORCE.
// On success, *var receives the variable produced by the parser.
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                               const platform::DeviceContext& ctx,
                               const framework::Scope* scope,
                               framework::Variable** var) {
  operators::detail::VariableResponse parser(scope, &ctx);
  const auto parse_status = parser.Parse(msg);
  PADDLE_ENFORCE(parse_status == 0, "parse bytebuffer to tensor error!");
  *var = parser.GetVar();
}

}  // namespace detail
}  // namespace operators
}  // namespace paddle