// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/variable_response.h"
#include <vector>
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"

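// Output path of the RPC server's profiler log (defaults to "./profile_ps").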
DEFINE_string(rpc_server_profile_path, "./profile_ps",
              "the profile log file path");

namespace paddle {
namespace operators {
namespace distributed {

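// Reads `size` bytes of raw tensor data from the coded input stream and
// copies them into `dest`, which resides on `place`. The stream may expose
// the data in several chunks, so the copy loops until all bytes are written;
// returns false if the stream runs out of data first.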
bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
                               const platform::DeviceContext& dev_ctx,
                               platform::Place place, void* dest,
                               int64_t size) {
  const void* data = NULL;
  int size_to_write = 0;
  int64_t length = size;
  int total_written = 0;

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
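    // The host-to-device copies below are issued asynchronously on the device
    // context's stream; gpu_dev_ctx.Wait() after the loop ensures they have
    // all completed before the function returns.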
    auto& gpu_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    platform::CPUPlace cpu;

    char* p = reinterpret_cast<char*>(dest);
    while (total_written < length) {
      if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
        return false;
      }
      // NOTE: if the raw buffer is large and two neighboring fields both hold
      // raw buffers, GetDirectBufferPointer may return data spanning both of
      // them, so use length to truncate the copy.
      if (total_written + size_to_write > length) {
        size_to_write = length - total_written;
      }
      // This log helps show the size of each internal block transferred over RPC.
      VLOG(70) << "copy " << size_to_write << " data to CUDAPlace";
      memory::Copy(boost::get<platform::CUDAPlace>(place),
                   reinterpret_cast<void*>(p), cpu, data, size_to_write,
                   gpu_dev_ctx.stream());
      p += size_to_write;
      total_written += size_to_write;

      input->Skip(size_to_write);
    }
    gpu_dev_ctx.Wait();
#else
    PADDLE_THROW("Unexpected branch: GPU place is used but PaddlePaddle is not compiled with CUDA!");
#endif
    return true;
  }

  char* p = reinterpret_cast<char*>(dest);
  while (total_written < length) {
    if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
      return false;
    }
    // NOTE: if the raw buffer is large and two neighboring fields both hold
    // raw buffers, GetDirectBufferPointer may return data spanning both of
    // them, so use length to truncate the copy.
    if (total_written + size_to_write > length) {
      size_to_write = length - total_written;
    }
    // TODO(gongwb): can we avoid copy?
    platform::CPUPlace cpu;
    // This log helps show the size of each internal block transferred over RPC.
    VLOG(70) << "copy " << size_to_write << " data to CPUPlace";
    memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);

    p += size_to_write;
    total_written += size_to_write;

    input->Skip(size_to_write);
  }

  return true;
}

bool VariableResponse::CopyLodTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
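  // The destination variable must already exist in the server-side scope;
  // otherwise there is nowhere to put the received data.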
  auto server_var = GetVar();
  if (!server_var) {
    LOG(ERROR) << "received var is not found on the current server: "
               << meta_.varname();
    return false;
  }
  auto* tensor = GetVar()->GetMutable<framework::LoDTensor>();
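  // Restore the tensor shape and LoD information from the request metadata,
  // then read the raw tensor bytes directly into the tensor's memory.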
  tensor->Resize(dims);
  framework::LoD lod;
  for (int i = 0; i < meta_.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) {
      v.push_back(meta_.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);

  void* tensor_data =
      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
  if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
    return false;
  }

  return true;
}

inline framework::DDim GetDims(
    const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
  std::vector<int> vecdims;
  for (auto& d : dims) {
    vecdims.push_back(d);
  }
  return framework::make_ddim(vecdims);
}

bool VariableResponse::CopySelectRowsTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
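  // A SelectedRows variable carries a height and a dense value tensor;
  // restore both from the metadata before reading the raw bytes.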
  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
  slr->set_height(meta_.slr_height());
  auto* tensor = slr->mutable_value();
  tensor->Resize(dims);
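  // Sanity check: the element count implied by `dims` must match the number
  // of received bytes divided by the element size.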
  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
                    length / framework::SizeOfType(
                                 paddle::operators::distributed::ToTypeIndex(
                                     meta_.data_type())));
  void* tensor_data = tensor->mutable_data(
      ctx.GetPlace(),
      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));

  if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
    return false;
  }

  return true;
}

bool VariableResponse::CopySelectRowsData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, int length) {
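  // The serialized rows field is a flat array of int64_t indices; `length`
  // is its size in bytes.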
  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
  slr->mutable_rows()->clear();
  slr->mutable_rows()->resize(length /
                              framework::SizeOfType(typeid(int64_t)));  // int64
  int64_t* rows_data = slr->mutable_rows()->data();

  // Copy the row indices into CPU memory; GPU data will be copied lazily.
  platform::CPUPlace cpu;
  if (!ReadRaw(input, ctx, cpu, rows_data, length)) {
    return false;
  }

  return true;
}

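// Deserializes the payload of a serialized field into the variable described
// by the previously parsed metadata: an NCCL unique id, a LoD tensor, or a
// SelectedRows.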
bool VariableResponse::ProcSerializedField(
    int tag, ::google::protobuf::io::CodedInputStream* input,
    int64_t num_bytes) {
  PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                  meta_.type() == sendrecv::LOD_TENSOR ||
                  meta_.type() == sendrecv::NCCL_ID) &&
                     meta_.varname() != "",
                 "meta info should be got first!");

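  // An NCCL unique id is a small fixed-size blob and is always copied into
  // CPU memory.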
  if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA
    auto* var = scope_->FindVar(meta_.varname());
    if (var != nullptr) {
      ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
      if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
                   num_bytes)) {
        return false;
      }
    }
    return true;
#else
    PADDLE_THROW("Not compiled with CUDA!");
    return false;
#endif
  }

  VLOG(70) << "ProcSerializedField:" << meta_.varname()
           << ", type:" << meta_.type() << std::endl;
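  // The remaining variable types carry a dense payload whose shape is
  // recovered from the metadata.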
  framework::DDim dims = GetDims(meta_.dims());
  if (meta_.type() == sendrecv::LOD_TENSOR) {
    PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info must be received first!");
    if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }

    return true;
  }

  if (meta_.type() == sendrecv::SELECTED_ROWS) {
    if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }
    return true;
  }

  PADDLE_THROW("unsupported var type: %d, var name: %s",
               static_cast<int>(meta_.type()), meta_.varname());

  return false;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle