variable_response.cc 7.0 KB
Newer Older
1
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/operators/distributed/variable_response.h"
Y
Yi Wang 已提交
16
#include <vector>
17
#include "paddle/fluid/operators/distributed/sendrecvop_utils.h"
18 19 20

namespace paddle {
namespace operators {
21
namespace distributed {
22

23 24 25 26
// Copies `size` bytes of raw payload from `input` into `dest`.
//
// The stream is consumed chunk by chunk through GetDirectBufferPointer(); each
// chunk is copied with memory::Copy and then skipped in the stream.
//
// input:   coded stream positioned at the first payload byte.
// dev_ctx: device context; only used when `place` is a GPU place, in which
//          case it is cast to platform::CUDADeviceContext.
// place:   where `dest` lives. A GPU place takes the CUDA path; every other
//          place is copied as CPU memory.
// dest:    destination buffer; the caller must provide at least `size` bytes.
// size:    number of bytes to read. Nothing is read when size <= 0.
//
// Returns true once `size` bytes were copied, false if the stream could not
// supply (or skip over) the requested bytes.
bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
                               const platform::DeviceContext& dev_ctx,
                               platform::Place place, void* dest,
                               int64_t size) {
  const void* data = nullptr;
  int size_to_write = 0;
  const int64_t length = size;
  // Keep the running total in 64 bits: `length` is 64-bit, and a 32-bit
  // counter would overflow for payloads larger than 2 GiB.
  int64_t total_written = 0;

  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    auto& gpu_dev_ctx =
        static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    platform::CPUPlace cpu;

    char* p = reinterpret_cast<char*>(dest);
    while (total_written < length) {
      if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
        return false;
      }
      // NOTE: if the raw buffer is large and two raw-buffer fields are
      // adjacent, GetDirectBufferPointer can expose both of them; truncate
      // the chunk so that no more than `length` bytes are consumed.
      if (total_written + size_to_write > length) {
        // The remainder is smaller than size_to_write, so it fits in an int.
        size_to_write = static_cast<int>(length - total_written);
      }
      // This log is useful to see how large an internal rpc block is.
      VLOG(7) << "copy " << size_to_write << " data to CUDAPlace";
      memory::Copy(boost::get<platform::CUDAPlace>(place),
                   reinterpret_cast<void*>(p), cpu, data, size_to_write,
                   gpu_dev_ctx.stream());
      p += size_to_write;
      total_written += size_to_write;

      if (!input->Skip(size_to_write)) {
        return false;
      }
    }
    // The copies above were issued on the context's stream; wait for them
    // before reporting success.
    gpu_dev_ctx.Wait();
#else
    PADDLE_THROW("Unexpected branch");
#endif
    return true;
  }

  char* p = reinterpret_cast<char*>(dest);
  while (total_written < length) {
    if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
      return false;
    }
    // NOTE: if the raw buffer is large and two raw-buffer fields are adjacent,
    // GetDirectBufferPointer can expose both of them; truncate the chunk so
    // that no more than `length` bytes are consumed.
    if (total_written + size_to_write > length) {
      // The remainder is smaller than size_to_write, so it fits in an int.
      size_to_write = static_cast<int>(length - total_written);
    }
    // TODO(gongwb): can we avoid copy?
    platform::CPUPlace cpu;
    // This log is useful to see how large an internal rpc block is.
    VLOG(7) << "copy " << size_to_write << " data to CPUPlace";
    memory::Copy(cpu, reinterpret_cast<void*>(p), cpu, data, size_to_write);

    p += size_to_write;
    total_written += size_to_write;

    if (!input->Skip(size_to_write)) {
      return false;
    }
  }

  return true;
}

// Fills the LoDTensor held by this response's variable.
//
// The tensor is resized to `dims`, its LoD is rebuilt from the already-parsed
// `meta_` message, storage of the meta data type is obtained on the place of
// `ctx`, and `length` bytes of payload are then read from `input` into it.
//
// Returns the result of ReadRaw: false when the payload could not be read.
bool VariableResponse::CopyLodTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto* lod_tensor = GetVar()->GetMutable<framework::LoDTensor>();
  lod_tensor->Resize(dims);

  // Rebuild the LoD one level at a time from the meta message.
  framework::LoD lod_info;
  const auto level_count = meta_.lod_level();
  for (int level = 0; level < level_count; ++level) {
    const auto& src_level = meta_.lod(level);
    const auto entry_count = src_level.lod_data_size();
    framework::Vector<size_t> offsets;
    for (int k = 0; k < entry_count; ++k) {
      offsets.push_back(src_level.lod_data(k));
    }
    lod_info.push_back(offsets);
  }
  lod_tensor->set_lod(lod_info);

  void* dst = lod_tensor->mutable_data(ctx.GetPlace(),
                                       ToTypeIndex(meta_.data_type()));

  return ReadRaw(input, ctx, lod_tensor->place(), dst, length);
}

// Converts the repeated int64 `dims` field of a message into a framework::DDim.
//
// The extents are kept as 64-bit values; collecting them in a vector of `int`
// would silently truncate any extent that does not fit in 32 bits.
inline framework::DDim GetDims(
    const ::google::protobuf::RepeatedField<::google::protobuf::int64>& dims) {
  std::vector<int64_t> vecdims;
  vecdims.reserve(dims.size());
  for (const auto& d : dims) {
    vecdims.push_back(static_cast<int64_t>(d));
  }
  return framework::make_ddim(vecdims);
}

// Fills the value tensor of the SelectedRows held by this response's variable.
//
// Sets the SelectedRows height from `meta_`, resizes the value tensor to
// `dims`, enforces that the element count equals `length` divided by the
// element size of the meta data type, and then reads `length` bytes of payload
// from `input` into the tensor's storage on the place of `ctx`.
//
// Returns the result of ReadRaw: false when the payload could not be read.
bool VariableResponse::CopySelectRowsTensorData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, const framework::DDim& dims,
    int length) {
  auto* selected_rows = GetVar()->GetMutable<framework::SelectedRows>();
  selected_rows->set_height(meta_.slr_height());

  auto* value = selected_rows->mutable_value();
  value->Resize(dims);

  // The element type is taken from the meta message; it is used both for the
  // size check and for the allocation below.
  const auto dtype =
      paddle::operators::distributed::ToTypeIndex(meta_.data_type());
  PADDLE_ENFORCE_EQ(static_cast<size_t>(value->numel()),
                    length / framework::SizeOfType(dtype));

  void* dst = value->mutable_data(ctx.GetPlace(), dtype);

  return ReadRaw(input, ctx, value->place(), dst, length);
}

// Reads the row indices of the SelectedRows held by this response's variable.
//
// The rows container is resized to `length` divided by the size of int64_t,
// and `length` bytes are then read from `input` into it as CPU memory.
//
// Returns the result of ReadRaw: false when the payload could not be read.
bool VariableResponse::CopySelectRowsData(
    ::google::protobuf::io::CodedInputStream* input,
    const platform::DeviceContext& ctx, int length) {
  auto* selected_rows = GetVar()->GetMutable<framework::SelectedRows>();
  auto* rows = selected_rows->mutable_rows();
  rows->resize(length / framework::SizeOfType(typeid(int64_t)));  // int64
  int64_t* rows_buf = rows->data();

  // Only the CPU copy of the rows is filled here; GPU data will be copied
  // lazily.
  platform::CPUPlace host;
  return ReadRaw(input, ctx, host, rows_buf, length);
}

167 168 169 170 171 172 173 174
bool VariableResponse::ProcSerializedField(
    int tag, ::google::protobuf::io::CodedInputStream* input,
    int64_t num_bytes) {
  PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
                  meta_.type() == sendrecv::LOD_TENSOR ||
                  meta_.type() == sendrecv::NCCL_ID) &&
                     meta_.varname() != "",
                 "meta info should be got first!");
175

176 177 178 179 180 181 182
  if (meta_.type() == sendrecv::NCCL_ID) {
#ifdef PADDLE_WITH_CUDA
    auto* var = scope_->FindVar(meta_.varname());
    if (var != nullptr) {
      ncclUniqueId* id = var->GetMutable<ncclUniqueId>();
      if (!ReadRaw(input, *dev_ctx_, platform::CPUPlace(), id->internal,
                   num_bytes)) {
183 184 185
        return false;
      }
    }
186
    return true;
T
typhoonzero 已提交
187
#else
188 189
    PADDLE_THROW("Not compiled with CUDA!");
    return false;
T
typhoonzero 已提交
190
#endif
191
  }
Y
Yancey1989 已提交
192

G
gongweibao 已提交
193 194
  VLOG(7) << "ProcSerializedField:" << meta_.varname()
          << ", type:" << meta_.type() << std::endl;
195 196 197 198 199 200
  framework::DDim dims = GetDims(meta_.dims());
  if (meta_.type() == sendrecv::LOD_TENSOR) {
    PADDLE_ENFORCE(meta_.lod_size() >= 0, "lod info should be got first!");
    if (!CopyLodTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
    }
G
gongweibao 已提交
201

202 203
    return true;
  }
Y
Yancey1989 已提交
204

205 206 207
  if (meta_.type() == sendrecv::SELECTED_ROWS) {
    if (!CopySelectRowsTensorData(input, *dev_ctx_, dims, num_bytes)) {
      return false;
208
    }
209
    return true;
210 211
  }

G
gongweibao 已提交
212 213 214
  PADDLE_ENFORCE("not supported var types:", meta_.varname(), meta_.type());

  return false;
215 216
}

217
};  // namespace distributed
218 219
};  // namespace operators
};  // namespace paddle