variable_response.cc 7.3 KB
Newer Older
1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/distributed/variable_response.h"
Y
Yi Wang 已提交
16
#include <vector>
17
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
18

Q
Qiao Longfei 已提交
19 20 21
// gflags string flag `rpc_server_profile_path` (default "./profile_ps"):
// path of the RPC server profile log file. It is not read anywhere in this
// file; presumably the RPC server's profiling code consumes it — confirm.
DEFINE_string(rpc_server_profile_path, "./profile_ps",
              "the profile log file path");

22 23
namespace paddle {
namespace operators {
24
namespace distributed {
25

26 27 28 29
bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
                               const platform::DeviceContext& dev_ctx,
                               platform::Place place, void* dest,
                               int64_t size) {
30 31
  const void* data = NULL;
  int size_to_write = 0;
32
  int64_t length = size;
Y
yi.wu 已提交
33
  int total_written = 0;
34 35 36 37 38 39 40 41

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    auto& gpu_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    platform::CPUPlace cpu;

    char* p = reinterpret_cast<char*>(dest);
Y
yi.wu 已提交
42
    while (total_written < length) {
43 44 45
      if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
        return false;
      }
Y
yi.wu 已提交
46 47 48 49 50 51
      // NOTE: if raw buffer is large and have two neighbor fields of raw
      // buffers GetDirectBufferPointer can get all of them, use length to
      // truncate it.
      if (total_written + size_to_write > length) {
        size_to_write = length - total_written;
      }
G
gongweibao 已提交
52
      // This log is useful to see how long a internal block size is of rpc.
M
minqiyang 已提交
53
      VLOG(7) << "copy " << size_to_write << " data to CUDAPlace";
54 55 56 57
      memory::Copy(boost::get<platform::CUDAPlace>(place),
                   reinterpret_cast<void*>(p), cpu, data, size_to_write,
                   gpu_dev_ctx.stream());
      p += size_to_write;
Y
yi.wu 已提交
58
      total_written += size_to_write;
59 60 61 62 63 64 65 66 67 68 69

      input->Skip(size_to_write);
    }
    gpu_dev_ctx.Wait();
#else
    PADDLE_THROW("Unexpected branch");
#endif
    return true;
  }

  char* p = reinterpret_cast<char*>(dest);
Y
yi.wu 已提交
70
  while (total_written < length) {
71 72 73
    if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
      return false;
    }
Y
yi.wu 已提交
74 75 76 77 78
    // NOTE: if raw buffer is large and have two neighbor fields of raw buffers
    // GetDirectBufferPointer can get all of them, use length to truncate it.
    if (total_written + size_to_write > length) {
      size_to_write = length - total_written;
    }
79 80
    // TODO(gongwb): can we avoid copy?
    platform::CPUPlace cpu;
G
gongweibao 已提交
81
    // This log is useful to see how long a internal block size is of rpc.
M
minqiyang 已提交
82
    VLOG(7) << "copy " << size_to_write << " data to CPUPlace";
83 84 85
    memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);

    p += size_to_write;
Y
yi.wu 已提交
86
    total_written += size_to_write;
87 88 89 90 91 92 93 94 95

    input->Skip(size_to_write);
  }

  return true;
}

// Deserializes a LoDTensor payload of `length` bytes from `input` into the
// variable returned by GetVar(). Steps:
//   - the tensor is resized to `dims`;
//   - its LoD is rebuilt from meta_.lod();
//   - the raw data is read into memory allocated on ctx.GetPlace().
//
// Returns false in three cases: the variable is not present on this server,
// the LoD meta is inconsistent, or `input` runs out of data. Throws (via
// PADDLE_ENFORCE_EQ) if the allocated tensor size does not equal `length`.
bool VariableResponse::CopyLodTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto server_var = GetVar();
  if (!server_var) {
    LOG(ERROR) << "recved var should not on current server: "
               << meta_.varname();
    return false;
  }

  // lod_level arrives over the wire. Never let it exceed the number of LoD
  // levels actually received, or meta_.lod(i) below would index out of range.
  if (meta_.lod_level() > meta_.lod_size()) {
    LOG(ERROR) << "lod_level " << meta_.lod_level()
               << " exceeds received lod levels " << meta_.lod_size()
               << " for var: " << meta_.varname();
    return false;
  }

  auto* tensor = server_var->GetMutable<framework::LoDTensor>();
  tensor->Resize(dims);

  framework::LoD lod;
  for (int i = 0; i < meta_.lod_level(); ++i) {
    const auto& level = meta_.lod(i);
    framework::Vector<size_t> v;
    for (int j = 0; j < level.lod_data_size(); ++j) {
      v.push_back(level.lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);

  void* tensor_data =
      tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type()));

  VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
          << ", Buffer Size = " << length;
  PADDLE_ENFORCE_EQ(tensor->memory_size(), static_cast<size_t>(length));
  return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
}

// Converts the `dims` repeated field of the wire meta into a framework::DDim.
// The extents are kept as int64 rather than narrowed to int, so dimensions
// above INT_MAX survive the conversion unchanged.
inline framework::DDim GetDims(
    const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
  std::vector<int64_t> vecdims(dims.begin(), dims.end());
  return framework::make_ddim(vecdims);
}

// Deserializes the value tensor of a SelectedRows payload of `length` bytes
// from `input` into the variable returned by GetVar(). Steps:
//   - the SelectedRows height is set from meta_.slr_height();
//   - the value tensor is resized to `dims`;
//   - the raw data is read into memory allocated on ctx.GetPlace().
//
// Returns false if the variable is not present on this server or if `input`
// runs out of data. Throws (via PADDLE_ENFORCE_EQ) if `length` does not equal
// the byte size of the resized tensor.
bool VariableResponse::CopySelectRowsTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto server_var = GetVar();
  if (!server_var) {
    LOG(ERROR) << "recved var should not on current server: "
               << meta_.varname();
    return false;
  }

  auto* slr = server_var->GetMutable<framework::SelectedRows>();
  slr->set_height(meta_.slr_height());
  auto* tensor = slr->mutable_value();
  tensor->Resize(dims);

  const auto data_type = ToVarType(meta_.data_type());
  // Require an exact byte match. With integer division, a `length` that is
  // not a multiple of the element size would pass the check, and ReadRaw
  // would then write past the end of the tensor buffer.
  PADDLE_ENFORCE_EQ(
      static_cast<size_t>(tensor->numel()) * framework::SizeOfType(data_type),
      static_cast<size_t>(length));
  void* tensor_data = tensor->mutable_data(ctx.GetPlace(), data_type);

  return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
}

// Deserializes the row indices of a SelectedRows payload of `length` bytes,
// a packed array of int64 values, into the variable returned by GetVar().
// The rows are always read into host memory; GPU data is copied lazily.
//
// Returns false in three cases: the variable is not present on this server,
// `length` is not a whole number of int64 values, or `input` runs out of data.
bool VariableResponse::CopySelectRowsData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, int length) {
  auto server_var = GetVar();
  if (!server_var) {
    LOG(ERROR) << "recved var should not on current server: "
               << meta_.varname();
    return false;
  }

  // A negative length would resize `rows` to a bogus size. A length that is
  // not a multiple of sizeof(int64_t) would make ReadRaw write past the end
  // of the resized buffer.
  if (length < 0 || static_cast<size_t>(length) % sizeof(int64_t) != 0) {
    LOG(ERROR) << "invalid rows buffer size " << length
               << " for var: " << meta_.varname();
    return false;
  }

  auto* slr = server_var->GetMutable<framework::SelectedRows>();
  auto* rows = slr->mutable_rows();
  rows->clear();
  rows->resize(static_cast<size_t>(length) / sizeof(int64_t));  // int64
  int64_t* rows_data = rows->data();

  // copy rows CPU data, GPU data will be copied lazily.
  platform::CPUPlace cpu;
  return ReadRaw(input, ctx, cpu, rows_data, length);
}

174 175 176 177 178 179 180 181
bool VariableResponse::ProcSerializedField(
    int tag, ::google::protobuf::io::CodedInputStream* input,
    int64_t num_bytes) {
  PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                  meta_.type() == sendrecv::LOD_TENSOR ||
                  meta_.type() == sendrecv::NCCL_ID) &&
                     meta_.varname() != "",
                 "meta info should be got first!");
182

183 184 185 186 187 188 189
  if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA
    auto* var = scope_->FindVar(meta_.varname());
    if (var != nullptr) {
      ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
      if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
                   num_bytes)) {
190 191 192
        return false;
      }
    }
193
    return true;
T
typhoonzero 已提交
194
#else
195 196
    PADDLE_THROW("Not compiled with CUDA!");
    return false;
T
typhoonzero 已提交
197
#endif
198
  }
Y
Yancey1989 已提交
199

M
minqiyang 已提交
200 201
  VLOG(7) << "ProcSerializedField:" << meta_.varname()
          << ", type:" << meta_.type() << std::endl;
202 203 204 205 206 207
  framework::DDim dims = GetDims(meta_.dims());
  if (meta_.type() == sendrecv::LOD_TENSOR) {
    PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
    if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }
G
gongweibao 已提交
208

209 210
    return true;
  }
Y
Yancey1989 已提交
211

212 213 214
  if (meta_.type() == sendrecv::SELECTED_ROWS) {
    if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
215
    }
216
    return true;
217 218
  }

G
gongweibao 已提交
219 220 221
  PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type());

  return false;
222 223
}

224
};  // namespace distributed
225 226
};  // namespace operators
};  // namespace paddle