ps_local_client.cc 11.0 KB
Newer Older
T
Thunderbrook 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/distributed/ps/service/ps_local_client.h"
16

17
#include "paddle/fluid/distributed/ps/table/table.h"
T
Thunderbrook 已提交
18 19 20 21 22

//#define pslib_debug_dense_compress

namespace paddle {
namespace distributed {
Z
zhaocaibei123 已提交
23
int32_t PsLocalClient::Initialize() {
T
Thunderbrook 已提交
24
  const auto& downpour_param = _config.server_param().downpour_server_param();
Z
zhaocaibei123 已提交
25
  TableManager::Instance().Initialize();
Z
zhangchunle 已提交
26
  for (int i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
T
Thunderbrook 已提交
27 28
    auto* table = CREATE_PSCORE_CLASS(
        Table, downpour_param.downpour_table_param(i).table_class());
Z
zhaocaibei123 已提交
29 30
    table->SetShard(0, 1);
    table->Initialize(downpour_param.downpour_table_param(i),
T
Thunderbrook 已提交
31 32 33 34 35 36
                      _config.fs_client_param());
    _table_map[downpour_param.downpour_table_param(i).table_id()].reset(table);
  }
  return 0;
}

Z
zhaocaibei123 已提交
37
// Placeholder: shrinking is not implemented for the local client yet.
// Both arguments are ignored and the result of done() is returned.
::std::future<int32_t> PsLocalClient::Shrink(uint32_t table_id,
                                             const std::string threshold) {
  // TODO: implement shrink for local tables.
  return done();
}

Z
zhaocaibei123 已提交
43
::std::future<int32_t> PsLocalClient::Load(const std::string& epoch,
T
Thunderbrook 已提交
44 45
                                           const std::string& mode) {
  // TODO
T
Thunderbrook 已提交
46
  for (auto& it : _table_map) {
Z
zhaocaibei123 已提交
47
    Load(it.first, epoch, mode);
T
Thunderbrook 已提交
48
  }
T
Thunderbrook 已提交
49 50
  return done();
}
Z
zhaocaibei123 已提交
51
::std::future<int32_t> PsLocalClient::Load(uint32_t table_id,
T
Thunderbrook 已提交
52 53 54
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
Z
zhaocaibei123 已提交
55 56
  auto* table_ptr = GetTable(table_id);
  table_ptr->Load(epoch, mode);
T
Thunderbrook 已提交
57 58 59
  return done();
}

Z
zhaocaibei123 已提交
60
// Saves every table registered in _table_map by forwarding to the per-table
// Save overload with the same epoch and mode.
//
// The futures returned by the per-table calls are discarded; the function
// returns the result of done().
::std::future<int32_t> PsLocalClient::Save(const std::string& epoch,
                                           const std::string& mode) {
  // TODO
  for (auto entry = _table_map.begin(); entry != _table_map.end(); ++entry) {
    Save(entry->first, epoch, mode);
  }
  return done();
}
Z
zhaocaibei123 已提交
68
::std::future<int32_t> PsLocalClient::Save(uint32_t table_id,
T
Thunderbrook 已提交
69 70 71
                                           const std::string& epoch,
                                           const std::string& mode) {
  // TODO
Z
zhaocaibei123 已提交
72 73 74
  auto* table_ptr = GetTable(table_id);
  table_ptr->Flush();
  table_ptr->Save(epoch, mode);
T
Thunderbrook 已提交
75 76 77
  return done();
}

Z
zhaocaibei123 已提交
78
::std::future<int32_t> PsLocalClient::Clear() {
T
Thunderbrook 已提交
79 80 81
  // TODO
  return done();
}
Z
zhaocaibei123 已提交
82
::std::future<int32_t> PsLocalClient::Clear(uint32_t table_id) {
T
Thunderbrook 已提交
83 84 85 86
  // TODO
  return done();
}

Z
zhaocaibei123 已提交
87
::std::future<int32_t> PsLocalClient::Flush() {
T
Thunderbrook 已提交
88 89 90 91
  // no need
  return done();
}

Z
zhaocaibei123 已提交
92
::std::future<int32_t> PsLocalClient::StopServer() {
T
Thunderbrook 已提交
93 94 95 96
  // no need
  return done();
}

Z
zhaocaibei123 已提交
97 98 99 100 101
// Pulls the dense values of a table into a temporary float buffer and then
// scatters that buffer, byte by byte and in order, over the caller's regions.
//
// regions/region_num: destination regions; each region's size is treated as a
//   byte count.
// table_id: id of the dense table to pull from.
//
// Copying stops when either the pulled buffer is exhausted or all regions
// have been visited. Returns the result of done().
::std::future<int32_t> PsLocalClient::PullDense(Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);

  uint32_t num_per_shard =
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1);

  std::vector<float> region_buffer(num_per_shard);

  // Ask the table to fill the whole buffer with its dense values.
  TableContext pull_request;
  pull_request.value_type = Dense;
  pull_request.pull_context.values = region_buffer.data();
  pull_request.num = region_buffer.size();
  table_ptr->Pull(pull_request);

  // Bytes of the pulled buffer that still have to be handed out.
  size_t shard_buffer_remain = static_cast<size_t>(num_per_shard) * sizeof(float);
  PADDLE_ENFORCE_EQ(
      shard_buffer_remain,
      region_buffer.size() * sizeof(float),
      platform::errors::PreconditionNotMet("pull dense size error."));

  const uint8_t* pulled_bytes =
      reinterpret_cast<const uint8_t*>(region_buffer.data());
  size_t consumed = 0;         // byte offset into the pulled buffer
  size_t region_idx = 0;       // region currently being filled
  size_t region_data_idx = 0;  // byte offset inside that region
  while (shard_buffer_remain > 0 && region_idx < region_num) {
    auto& region = regions[region_idx];
    const auto region_room = region.size - region_data_idx;

    // Current region is full (or empty): move on to the next one.
    if (region_room == 0) {
      ++region_idx;
      region_data_idx = 0;
      continue;
    }

    if (region_room >= shard_buffer_remain) {
      // Everything that is left fits into this region; this ends the loop.
      memcpy((void*)(region.data + region_data_idx),
             pulled_bytes + consumed,
             shard_buffer_remain);
      region_data_idx += shard_buffer_remain;
      shard_buffer_remain = 0;
    } else {
      // Fill the rest of this region and continue with the next one.
      memcpy((void*)(region.data + region_data_idx),
             pulled_bytes + consumed,
             region_room);
      shard_buffer_remain -= region_room;
      consumed += region_room;
      ++region_idx;
      region_data_idx = 0;
    }
  }

  return done();
}

Z
zhaocaibei123 已提交
150 151 152 153 154
::std::future<int32_t> PsLocalClient::PushDenseParam(const Region* regions,
                                                     size_t region_num,
                                                     size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
155 156

  std::vector<float> region_buffer;
157 158
  region_buffer.resize(DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1),
                       0);
T
Thunderbrook 已提交
159 160 161 162 163 164
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

165 166 167 168 169 170 171
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.push_context.is_param = true;
  table_context.num = region_buffer.size();

  table_ptr->Push(table_context);
Z
zhaocaibei123 已提交
172
  // table_ptr->PushDenseParam(region_buffer.data(), region_buffer.size());
T
Thunderbrook 已提交
173 174 175 176

  return done();
}

Z
zhaocaibei123 已提交
177
::std::future<int32_t> PsLocalClient::PushDenseRawGradient(
178 179 180
    int table_id,
    float* total_send_data,
    size_t total_send_data_size,
T
Thunderbrook 已提交
181 182 183 184 185
    void* callback) {
  VLOG(1) << "wxx push_dense_raw_gradient";

  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);

Z
zhaocaibei123 已提交
186
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
187

188 189 190 191 192 193 194
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = total_send_data;
  table_context.num = total_send_data_size;
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);

T
Thunderbrook 已提交
195 196 197 198
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
199 200 201 202 203
::std::future<int32_t> PsLocalClient::PushDense(const Region* regions,
                                                size_t region_num,
                                                size_t table_id) {
  auto* accessor = GetTableAccessor(table_id);
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
204 205

  std::vector<float> region_buffer;
206 207
  region_buffer.resize(
      DenseDimPerShard(accessor->GetAccessorInfo().fea_dim, 1));
T
Thunderbrook 已提交
208 209 210 211
  size_t data_size = region_buffer.size();
  for (size_t i = 0, offset = 0; i < region_num; ++i) {
    uint32_t data_num = regions[i].size / sizeof(float);
    PADDLE_ENFORCE_LE(
212 213
        offset + data_num,
        data_size,
T
Thunderbrook 已提交
214
        platform::errors::PreconditionNotMet(
215 216 217 218
            "invalid dense size, cur pos[%d] data_num[%d] size[%d]",
            offset,
            data_num,
            data_size));
T
Thunderbrook 已提交
219 220 221 222
    memcpy(region_buffer.data() + offset, regions[i].data, regions[i].size);
    offset += data_num;
  }

223 224 225 226 227 228
  TableContext table_context;
  table_context.value_type = Dense;
  table_context.push_context.values = region_buffer.data();
  table_context.num = region_buffer.size();
  //  table_ptr->PushDense(total_send_data, total_send_data_size);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
229 230 231 232

  return done();
}

Z
zhaocaibei123 已提交
233
//::std::future<int32_t> PsLocalClient::PullSparse(float** select_values,
T
Thunderbrook 已提交
234 235 236 237 238 239 240 241 242
//                                                  size_t table_id,
//                                                  const uint64_t* keys,
//                                                  size_t num) {
//  // FIXME
//  // auto timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
//  // auto local_timer =
//  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
//  // Split the keys into per-shard requests and record the value pointer
//  // that originally corresponds to each key.
Z
zhaocaibei123 已提交
243 244
//  auto* accessor = GetTableAccessor(table_id);
//  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
245 246
//  size_t value_size = accessor->select_size();
//
Z
zhaocaibei123 已提交
247
//  // table_ptr->PullSparse(keys, num);
T
Thunderbrook 已提交
248 249
//  std::vector<float> res_data;
//  res_data.resize(num * value_size / sizeof(float));
Z
zhaocaibei123 已提交
250
//  table_ptr->PullSparse(res_data.data(), keys, num);
T
Thunderbrook 已提交
251 252 253 254 255 256 257 258 259 260 261 262
//  // memcpy(select_values[0], res_data->data(), res_data->size() *
//  // sizeof(float));
//  size_t offset = 0;
//  for (int i = 0; i < num; ++i) {
//    memcpy(select_values[i], (char*)res_data.data() + offset, value_size);
//    offset += value_size;
//  }
//
//  // return fut;
//  return done();
//}

Z
zhaocaibei123 已提交
263 264 265 266
::std::future<int32_t> PsLocalClient::PullSparsePtr(char** select_values,
                                                    size_t table_id,
                                                    const uint64_t* keys,
                                                    size_t num) {
T
Thunderbrook 已提交
267 268 269 270 271 272
  // FIXME
  // auto timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse");
  // auto local_timer =
  // std::make_shared<CostTimer>("pslib_downpour_client_pull_sparse_local");
  //将key拆分到各shard请求,并记录原始对应value指针
Z
zhaocaibei123 已提交
273
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
274

275 276 277 278 279 280 281 282 283
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.pull_context.keys = keys;
  table_context.pull_context.ptr_values = select_values;
  table_context.use_ptr = true;
  table_context.num = num;

  //  table_ptr->PullSparsePtr(select_values, keys, num);
  table_ptr->Pull(table_context);
T
Thunderbrook 已提交
284 285 286 287

  return done();
}

Z
zhaocaibei123 已提交
288
::std::future<int32_t> PsLocalClient::PushSparseRawGradient(
289 290 291 292 293
    size_t table_id,
    const uint64_t* keys,
    const float** update_values,
    size_t num,
    void* callback) {
T
Thunderbrook 已提交
294
  PSClientClosure* closure = reinterpret_cast<PSClientClosure*>(callback);
Z
zhaocaibei123 已提交
295
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
296

297 298 299 300 301 302 303 304 305
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  // table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
306 307 308 309
  delete closure;
  return done();
}

Z
zhaocaibei123 已提交
310 311 312 313 314
::std::future<int32_t> PsLocalClient::PushSparse(size_t table_id,
                                                 const uint64_t* keys,
                                                 const float** update_values,
                                                 size_t num) {
  auto* table_ptr = GetTable(table_id);
T
Thunderbrook 已提交
315

316 317 318 319 320 321 322 323 324
  TableContext table_context;
  table_context.value_type = Sparse;
  table_context.push_context.keys = keys;
  table_context.push_context.ptr_values = update_values;
  table_context.num = num;
  table_context.use_ptr = true;

  //  table_ptr->PushSparse(keys, update_values, num);
  table_ptr->Push(table_context);
T
Thunderbrook 已提交
325 326
  return done();
}
327 328
}  // namespace distributed
}  // namespace paddle