// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/variable_response.h"
#include <vector>
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"

namespace paddle {
namespace operators {
namespace distributed {

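// Reads `size` bytes of raw data from the coded input stream and copies them
// into `dest`, which resides on `place`. Handles CPUPlace directly, and
// CUDAPlace when built with PADDLE_WITH_CUDA. Returns false if the stream
// runs out of data before `size` bytes have been read.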
bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
                               const platform::DeviceContext& dev_ctx,
                               platform::Place place, void* dest,
                               int64_t size) {
  const void* data = nullptr;
  int size_to_write = 0;
  int64_t length = size;
  int total_written = 0;

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    auto& gpu_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    platform::CPUPlace cpu;

    char* p = reinterpret_cast<char*>(dest);
    while (total_written < length) {
      if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
        return false;
      }
      // NOTE: if the raw buffer is large and two neighboring fields are both
      // raw buffers, GetDirectBufferPointer may return all of them at once;
      // use length to truncate to this field's data.
      if (total_written + size_to_write > length) {
        size_to_write = length - total_written;
      }
      // This log is useful for observing the internal block size of the RPC.
      VLOG(7) << "copy " << size_to_write << " data to CUDAPlace";
      memory::Copy(boost::get<platform::CUDAPlace>(place),
                   reinterpret_cast<void*>(p), cpu, data, size_to_write,
                   gpu_dev_ctx.stream());
      p += size_to_write;
      total_written += size_to_write;

      input->Skip(size_to_write);
    }
    gpu_dev_ctx.Wait();
#else
    PADDLE_THROW("Unexpected branch");
#endif
    return true;
  }

  char* p = reinterpret_cast<char*>(dest);
  while (total_written < length) {
    if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
      return false;
    }
    // NOTE: if the raw buffer is large and two neighboring fields are both
    // raw buffers, GetDirectBufferPointer may return all of them at once;
    // use length to truncate to this field's data.
    if (total_written + size_to_write > length) {
      size_to_write = length - total_written;
    }
    // TODO(gongwb): can we avoid copy?
    platform::CPUPlace cpu;
    // This log is useful for observing the internal block size of the RPC.
    VLOG(7) << "copy " << size_to_write << " data to CPUPlace";
    memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);

    p += size_to_write;
    total_written += size_to_write;

    input->Skip(size_to_write);
  }

  return true;
}

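// Deserializes a LoDTensor: resizes it to `dims`, restores its LoD from the
// received meta info, then reads `length` bytes of tensor data from the
// stream into the tensor's memory.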
bool VariableResponse::CopyLodTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto server_var = GetVar();
  if (!server_var) {
    LOG(ERROR) << "received var does not exist on the current server: "
               << meta_.varname();
    return false;
  }
  auto* tensor = server_var->GetMutable<framework::LoDTensor>();
  tensor->Resize(dims);
  framework::LoD lod;
  for (int i = 0; i < meta_.lod_level(); ++i) {
    framework::Vector<size_t> v;
    for (int j = 0; j < meta_.lod(i).lod_data_size(); ++j) {
      v.push_back(meta_.lod(i).lod_data(j));
    }
    lod.push_back(v);
  }
  tensor->set_lod(lod);

  void* tensor_data =
      tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type()));
  if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
    return false;
  }

  return true;
}

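// Converts the repeated int64 dims field of the protobuf meta info into a
// framework::DDim.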
inline framework::DDim GetDims(
    const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
  std::vector<int> vecdims;
  for (auto& d : dims) {
    vecdims.push_back(d);
  }
  return framework::make_ddim(vecdims);
}

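// Deserializes the value tensor of a SelectedRows variable: restores its
// height and dims, checks that `length` matches the expected tensor size in
// bytes, then reads the tensor data from the stream.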
bool VariableResponse::CopySelectRowsTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
  slr->set_height(meta_.slr_height());
  auto* tensor = slr->mutable_value();
  tensor->Resize(dims);
  PADDLE_ENFORCE_EQ(static_cast<size_t>(tensor->numel()),
                    length / framework::SizeOfType(
                                 paddle::operators::distributed::ToTypeIndex(
                                     meta_.data_type())));
  void* tensor_data = tensor->mutable_data(
      ctx.GetPlace(),
      paddle::operators::distributed::ToTypeIndex(meta_.data_type()));

  if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) {
    return false;
  }

  return true;
}

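// Deserializes the rows index of a SelectedRows variable: `length` bytes of
// int64 row ids, always read into CPU memory (GPU copies happen lazily).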
bool VariableResponse::CopySelectRowsData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, int length) {
  auto* slr = GetVar()->GetMutable<framework::SelectedRows>();
  slr->mutable_rows()->clear();
  slr->mutable_rows()->resize(length /
                              framework::SizeOfType(typeid(int64_t)));  // int64
  int64_t* rows_data = slr->mutable_rows()->data();

  // copy rows CPU data, GPU data will be copied lazily.
  platform::CPUPlace cpu;
  if (!ReadRaw(input, ctx, cpu, rows_data, length)) {
    return false;
  }

  return true;
}

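// Dispatches the serialized payload of `num_bytes` bytes to the matching
// deserializer (LoDTensor, SelectedRows, or ncclUniqueId), based on the
// variable type recorded in the previously parsed meta info.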
bool VariableResponse::ProcSerializedField(
    int tag, ::google::protobuf::io::CodedInputStream* input,
    int64_t num_bytes) {
  PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                  meta_.type() == sendrecv::LOD_TENSOR ||
                  meta_.type() == sendrecv::NCCL_ID) &&
                     meta_.varname() != "",
                 "meta info must be received first!");

  if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA
    auto* var = scope_->FindVar(meta_.varname());
    if (var != nullptr) {
      ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
      if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
                   num_bytes)) {
        return false;
      }
    }
    return true;
#else
    PADDLE_THROW("Not compiled with CUDA!");
    return false;
#endif
  }

  VLOG(7) << "ProcSerializedField:" << meta_.varname()
          << ", type:" << meta_.type();
  framework::DDim dims = GetDims(meta_.dims());
  if (meta_.type() == sendrecv::LOD_TENSOR) {
    PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info must be received first!");
    if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }

    return true;
  }

  if (meta_.type() == sendrecv::SELECTED_ROWS) {
    if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }
    return true;
  }

  PADDLE_THROW("unsupported var type: %d, var name: %s", meta_.type(),
               meta_.varname());

  return false;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle